//===- GridwiseGemmToBlockwise - MLIR Rock ops lowering passes -----===//
//
// Copyright 2020 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================
//
// This pass converts rock.gridwise_gemm[_v2] into block- and threadwise ops
//
//===-----------------------------------------------------===//
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Rock/IR/AmdArchDb.h"
#include "mlir/Dialect/Rock/IR/GetRockInfo.h"
#include "mlir/Dialect/Rock/IR/Rock.h"
#include "mlir/Dialect/Rock/IR/TransformMapBuilder.h"
#include "mlir/Dialect/Rock/Passes.h"
#include "mlir/Dialect/Rock/Tuning/GeneralGemmBlockStructure.h"
#include "mlir/Dialect/Rock/Tuning/GridwiseGemmParams.h"
#include "mlir/Dialect/Rock/utility/builderUtils.h"
#include "mlir/Dialect/Rock/utility/loweringUtils.h"
#include "mlir/Dialect/Rock/utility/math.h"
#include "mlir/Dialect/Rock/utility/transformMapUtils.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"

#include "GridLayoutEmitter.h"
#include "mlir/Dialect/Rock/IR/AccelEmitter.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>

namespace mlir {
namespace rock {
#define GEN_PASS_DEF_ROCKGRIDWISEGEMMTOBLOCKWISEPASS
#include "mlir/Dialect/Rock/Passes.h.inc"
} // namespace rock
} // namespace mlir

#define DEBUG_TYPE "rock-gridwise-to-blockwise"

using namespace mlir;
using namespace mlir::arith;
using namespace mlir::rock;
using mlir::gpu::AddressSpace;

namespace {
/// Pass that converts gridwise GEMM ops into block- and threadwise ops. All
/// pass boilerplate comes from the generated base class; only the entry point
/// is declared here.
struct RockGridwiseGemmToBlockwisePass
    : rock::impl::RockGridwiseGemmToBlockwisePassBase<
          RockGridwiseGemmToBlockwisePass> {
  void runOnOperation() override;
};

} // end anonymous namespace

/// Chooses the dimension (K or M/N) along which global reads of `matrix`
/// should be vectorized, together with the vector length, for a per-thread
/// copy layout of <copyDPerThread, copyKPerThread> elements. For example, with
/// <D,K> = <2,16> and K contiguous, the result vectorizes by 16 along K and
/// the other dimension is looped over. When both dimensions allow the same
/// vector length, `tiebreaker` selects the dimension.
static std::pair<GemmDimension, int64_t>
bestGlobalVectorization(OpBuilder &b, Value matrix, int64_t copyDPerThread,
                        int64_t copyKPerThread, GemmDimension tiebreaker,
                        int64_t kPerBlock, int64_t dPerBlock) {
  // A future commit will account for the underlying buffer's vectorization
  // here.
  const int64_t elemsPerThread = copyKPerThread * copyDPerThread;

  // Query the maximum vectorization along `dim`, passing
  // gcd(elemsPerThread, perBlock) as the input dimension length.
  auto maxVectorAlong = [&](GemmDimension dim, int64_t perBlock) -> int64_t {
    VectorizationResult res = getMaxVectorization(
        matrix, static_cast<uint32_t>(dim), /*inputDimLen=*/
        math_util::gcd(elemsPerThread, perBlock), matrix.getDefiningOp());
    return res.max;
  };
  const int64_t kMax = maxVectorAlong(GemmDimension::K, kPerBlock);
  const int64_t dMax = maxVectorAlong(GemmDimension::MorN, dPerBlock);

  // Same length either way: the caller's preference wins and the length is
  // returned as computed.
  if (kMax == dMax)
    return {tiebreaker, kMax};

  // Otherwise the longer vector wins; its length is reduced to a divisor of
  // what the thread copies along that dimension.
  if (kMax > dMax) {
    const int64_t kLen = math_util::gcd(kMax, copyKPerThread);
    return {GemmDimension::K, kLen};
  }
  const int64_t dLen = math_util::gcd(dMax, copyDPerThread);
  return {GemmDimension::MorN, dLen};
}

/// Splits the `copyPerThread` elements a single thread (workitem) copies
/// between the K and D (M or N) dimensions, independently of how the reads
/// are later vectorized. Returns the pair (copyKPerThread, copyDPerThread);
/// emits an error at `loc` if either share is zero or exceeds the
/// corresponding per-block extent.
static FailureOr<std::pair<int64_t, int64_t>>
computeCopyPerThread(Type elementType, int64_t copyPerThread, int64_t kPerBlock,
                     int64_t dPerBlock, int64_t kpack, Location loc) {
  // The goal is to maximize LDS store vectorization: read as many elements as
  // possible (bounded by a 128-bit vector) along the dimension that is
  // contiguous in LDS, and `copyPerThread / elements` along the other one.
  const int64_t maxVlen = 128 / elementType.getIntOrFloatBitWidth();
  const int64_t alongContiguous = math_util::gcd(maxVlen, copyPerThread);
  const int64_t alongOther = copyPerThread / alongContiguous;

  // With kpack == 1 the D dimension receives the vectorized share; otherwise
  // K does.
  const bool dTakesVector = (kpack == 1);
  const int64_t copyKPerThread = dTakesVector ? alongOther : alongContiguous;
  const int64_t copyDPerThread = dTakesVector ? alongContiguous : alongOther;

  const bool emptyShare = (copyKPerThread == 0) || (copyDPerThread == 0);
  if (emptyShare) {
    return emitError(loc) << "gemmA copy size too small,"
                          << " copyKPerThread: " << copyKPerThread
                          << " copyDPerThread: " << copyDPerThread << "\n";
  }

  const bool exceedsBlock =
      (kPerBlock < copyKPerThread) || (dPerBlock < copyDPerThread);
  if (exceedsBlock) {
    return mlir::emitError(loc)
           << "gemmA per thread copy smaller than per"
           << " block copy, incohereant tuning parameters\n";
  }
  return std::make_pair(copyKPerThread, copyDPerThread);
}

/// Builds a (k, d) matrix view over the flat LDS byte buffer `buffer`, which
/// must hold exactly kOuter * d * kpack * sizeof(T) bytes, where T and kpack
/// are the element type and length of `ldsReadType` (kpack == 1 when
/// `ldsReadType` is a scalar). `buffer` is first reinterpreted as a buffer of
/// vector<kpackPerThread x T> (just T when kpackPerThread == 1), with
/// kpackPerThread = min(kPerThread, kpack); consequently the resulting view
/// must be iterated over with a stride of no less than min(kPerThread, kpack).
/// If `rotateDWithK` is set, the `d` dimension is additionally rotated with
/// k_outer (a transformation similar to `d=(d+kOuter)%D`) to reduce bank
/// conflicts. Returns failure, after emitting an error, if `buffer` is not
/// flat or has an unexpected size.
static FailureOr<Value> wrapLDSBufferForStore(OpBuilder &b, Location loc,
                                              Value buffer, Type ldsReadType,
                                              int64_t kOuter, StringRef dName,
                                              int64_t d, int64_t kPerThread,
                                              int64_t dPerThread,
                                              bool rotateDWithK = false) {
  MemRefType flatType = cast<MemRefType>(buffer.getType());
  ArrayRef<int64_t> flatShape = flatType.getShape();
  if (flatShape.size() != 1)
    return emitError(loc, "Expected a flat buffer");

  // Separate the read type into its scalar element type and kpack length.
  Type scalarType = ldsReadType;
  int64_t kpack = 1;
  if (auto readVecType = dyn_cast<VectorType>(ldsReadType)) {
    kpack = readVecType.getNumElements();
    scalarType = readVecType.getElementType();
  }

  const auto expectedBytes = kOuter * d * kpack * getByteWidth(scalarType);
  if (flatShape[0] != expectedBytes) {
    return emitError(loc, "LDS buffer should have ")
           << expectedBytes << " elements but has " << flatShape[0];
  }

  // A thread writes kpackPerThread contiguous elements of each kpack group,
  // so threadsPerKpack writes are needed to fill one group.
  const int64_t kpackPerThread = std::min(kPerThread, kpack);
  assert(kpack % kpackPerThread == 0);
  const int64_t threadsPerKpack = kpack / kpackPerThread;

  Type ldsWriteType = vectorTypeOrSelf(scalarType, kpackPerThread);
  auto typedBuffer = viewBufferAs(b, buffer, ldsWriteType);

  // (k, d) -> (k_outer, d, kpack_idx, kpack_vec)
  TopDownTMBuilder mergeKpack{
      b, {"k", "d"}, {kOuter * threadsPerKpack * kpackPerThread, d}};
  mergeKpack.merge({"k_outer", "kpack_idx", "kpack_vec"}, {0, 2, 3}, "k",
                   {kOuter, threadsPerKpack, kpackPerThread});
  mergeKpack.merge({dName}, {1}, "d", {d});

  TransformMapAttr mergeKpackAttr = mergeKpack.get();
  SmallVector<Attribute> transformAttrs{mergeKpackAttr};

  // Optionally rotate the buffer. Rotation reduces bank conflicts when the
  // matrix is transposed on its way from global memory to LDS: rather than
  // storing different items in position (0,0), (1,0), (2,0), ... they are
  // stored in (0,0), (1,1), (2,2), ...
  int64_t rotationStride = 1;
  if (kpack == 1)
    rotationStride = dPerThread;
  TopDownTMBuilder reshapeBuf =
      rotateIf(rotateDWithK, mergeKpack, mergeKpackAttr, rotationStride, dName,
               d, 1, "k_outer", kOuter, {"k_outer"},
               {"kpack_idx", "kpack_vec"}, transformAttrs);

  // Flatten to the raw index of the typed buffer; kpack_vec is covered by the
  // vector element type and therefore ignored.
  reshapeBuf.unmerge("raw", 0, {"k_outer", dName, "kpack_idx"},
                     {kOuter, d, threadsPerKpack});
  reshapeBuf.ignore("kpack_vec");
  TransformMapAttr reshapeBufAttr = reshapeBuf.get();
  transformAttrs.push_back(reshapeBufAttr);

  return transform(b, typedBuffer, b.getArrayAttr(transformAttrs));
}

/// Checks that the LDS buffers for A and B together do not exceed the maximum
/// shared memory per workgroup of the architecture associated with `op`. If
/// no architecture can be obtained, the check is skipped and success is
/// returned.
static LogicalResult checkLDSSize(Operation *op, int64_t aBufferBytes,
                                  int64_t bBufferBytes) {
  FailureOr<StringAttr> maybeArch = getArch(op);
  // No known architecture means there is no limit to compare against.
  if (failed(maybeArch))
    return success();

  // Check for arch limitations exceeded
  const int64_t capacity = rock::lookupArchInfo(*maybeArch).maxSharedMemPerWG;
  const int64_t requested = aBufferBytes + bBufferBytes;
  return success(requested <= capacity);
}

// The following structure holds the knobs used to tweak the LDS layout of one
// gemm operand (one instance per non-K dimension) for gemm/attention ops.
struct LDSLayoutConfigDim {
  // Forwarded as `rotateDWithK` when wrapping the LDS buffer for stores, i.e.
  // requests rotating the non-K dimension with k_outer.
  bool doRotateWithK;
  // Forwarded to the packed-register tile views and to
  // swapThreadIdAndIteration to request swapping thread/iteration
  // sub-dimensions.
  bool doSwapThreadIterSubDims;
};

// This is a helper struct that aggregates the information derived about
// load vectorization for one gemm operand.
struct VectorDimInfo {
  // Dimension (K or M/N) chosen for vectorizing the global loads.
  GemmDimension vectorDim;
  // Vector length chosen along `vectorDim`.
  int64_t vectorLen;
  // Number of elements a single thread copies along K.
  int64_t inKPerThread;
  // Number of elements a single thread copies along the non-K (M or N)
  // dimension.
  int64_t inDPerThread;
  // Dimension preferred when both dimensions allow the same vector length.
  GemmDimension vectorTiebreaker;
};

/// Derives the load-vectorization information for one gemm operand `matrix`:
/// the per-thread copy shape along K and D, and the dimension/length used to
/// vectorize global loads. Returns failure if no valid per-thread copy shape
/// exists for the given tuning parameters (an error has been emitted at `loc`
/// in that case).
static FailureOr<VectorDimInfo> getVectorDim(PatternRewriter &rewriter,
                                             Location loc, Value matrix,
                                             Type elemType, int64_t blockSize,
                                             int64_t kPerBlock,
                                             int64_t dPerBlock, int64_t kpack) {
  // Each thread copies an equal share of the kPerBlock x dPerBlock tile.
  const int64_t copyPerThread = (kPerBlock * dPerBlock) / blockSize;
  FailureOr<std::pair<int64_t, int64_t>> maybeCopyShape = computeCopyPerThread(
      elemType, copyPerThread, kPerBlock, dPerBlock, kpack, loc);
  if (failed(maybeCopyShape))
    return failure();
  const auto [copyKPerThread, copyDPerThread] = *maybeCopyShape;

  // Find the best way of vectorizing the layout; on a tie, prefer K when
  // kpack > 1 and the non-K dimension otherwise.
  const GemmDimension vectorTiebreaker =
      (kpack > 1) ? GemmDimension::K : GemmDimension::MorN;
  const auto [vectorDim, vectorLen] =
      bestGlobalVectorization(rewriter, matrix, copyDPerThread, copyKPerThread,
                              vectorTiebreaker, kPerBlock, dPerBlock);

  return VectorDimInfo{vectorDim, vectorLen, copyKPerThread, copyDPerThread,
                       vectorTiebreaker};
}

/// Selects the LDS layout knobs for one gemm operand from its element type,
/// kpack and load-vectorization info. At most one of the two knobs is set:
/// rotation when loads are vectorized along K, thread/iteration sub-dimension
/// swapping otherwise; neither is set when vectorizing along D is possible.
static LDSLayoutConfigDim
getLDSLayoutConfigDim(Type elementType, int64_t kpack,
                      const VectorDimInfo &vecDimInfo) {
  const int64_t maxVlen = 128 / elementType.getIntOrFloatBitWidth();
  // If kpack is less than the hardware max vector length, and we are
  // writing more contiguous kpack elements, there is a possibility to
  // vectorize that we want to preserve (i.e., we favour vectorization over
  // bank conflicts resolution)
  const bool canVectorizeD = (kpack < maxVlen) && (vecDimInfo.inDPerThread > 1);
  const bool loadsVectorizedAlongK = vecDimInfo.vectorDim == GemmDimension::K;

  LDSLayoutConfigDim cfg{/*doRotateWithK=*/false,
                         /*doSwapThreadIterSubDims=*/false};
  if (!canVectorizeD) {
    if (loadsVectorizedAlongK)
      cfg.doRotateWithK = true;
    else
      cfg.doSwapThreadIterSubDims = true;
  }
  LLVM_DEBUG(llvm::dbgs() << "rotateWithK: " << cfg.doRotateWithK << "\n"
                          << "doSwapThreadIterSubDimsForM: "
                          << cfg.doSwapThreadIterSubDims << "\n");
  return cfg;
}

//===----------------------------------------------------------------------===//
// GridwiseGemm lowering.
//===----------------------------------------------------------------------===//

namespace {
/// Rewrites a GridwiseGemmOp into LDS/register allocations, a loop over K
/// tiles made of three stages ("GlobalRead", "LDSWrite", "MMA") that ends in a
/// BlockwiseGemmOp, and a final ThreadwiseWriteAllOp that stores the register
/// accumulator into C. The original op is erased on success.
struct GridwiseGemmRewritePattern : public OpRewritePattern<GridwiseGemmOp> {
  using OpRewritePattern<GridwiseGemmOp>::OpRewritePattern;

  /// Performs the rewrite described on the struct. The shapes read below are
  /// A = [G, K, M], B = [G, K, N] and C = [G, M, N]; mismatching shapes,
  /// excessive LDS usage, invalid tuning parameters or a missing arch
  /// attribute make the rewrite fail with a diagnostic.
  LogicalResult matchAndRewrite(GridwiseGemmOp op,
                                PatternRewriter &b) const override {
    Location loc = op.getLoc();

    // Obtain data type.
    Type elementTypeA = op.getA().getType().getElementType();
    Type elementTypeB = op.getB().getType().getElementType();
    Type destType = op.getC().getType().getElementType();

    // Prepare some useful constants.
    Value zeroConstantFloatOp = createZeroConstantOp(b, loc, destType);
    auto zeroConstantOp = ConstantIndexOp::create(b, loc, 0);

    ArrayRef<int64_t> aShape, bShape, cShape;
    aShape = op.getA().getType().getShape();
    bShape = op.getB().getType().getShape();
    cShape = op.getC().getType().getShape();
    // Obtain critical matrix dimensions.
    int64_t G = aShape[0];
    int64_t K = aShape[1];
    int64_t M = aShape[2];
    int64_t N = bShape[2];

    // Cross-check the shapes of A, B and C against each other.
    if (bShape[0] != G || cShape[0] != G) {
      return op.emitOpError("Mismatched G dimensions in matrix multiply;")
             << " A[0] = " << G << " b[0] = " << bShape[0]
             << " C[0] = " << cShape[0];
    }
    if (cShape[1] != M) {
      return op.emitOpError("Mismatched M dimensions in matrix multiply:")
             << " A[2] = " << M << " C[1] = " << cShape[1];
    }
    if (bShape[1] != K) {
      return op.emitOpError("Mismatched K dimensions in matrix multiply:")
             << " A[1] = " << K << " B[1] = " << bShape[1];
    }

    if (cShape[2] != N) {
      return op.emitOpError("Mismatched N dimensions in matrix multiply:")
             << " B[2] = " << N << " C[2] = " << cShape[2];
    }

    // Obtain critical tuning parameters.
    uint32_t gridSize = op.getGridSize();
    GeneralGemmParamsAttr tuningParams = op.getParams();
    int64_t kpack = tuningParams.getKpack();
    // TODO: kPerBlock, as defined in parameter selection etc,
    // is in units of kPack, not individual k. This should be changed
    // at some future point, but it'll be worked around for now.
    uint32_t blockSize = tuningParams.getBlockSize();
    int64_t kpacksPerBlock = tuningParams.getKPerBlock();
    int64_t mPerBlock = tuningParams.getMPerBlock();
    int64_t nPerBlock = tuningParams.getNPerBlock();
    int64_t mPerThread = tuningParams.getMPerThread();
    int64_t nPerThread = tuningParams.getNPerThread();

    // NOTE(review): the result of deriveGeneralGemmBlockStructure is
    // dereferenced without checking it first; presumably blockSize has been
    // validated before this pattern runs -- confirm.
    GeneralGemmBlockStructure blockStructure =
        *deriveGeneralGemmBlockStructure(blockSize);
    int64_t mThreadsPerCuwave = blockStructure.mThreadsPerCuwave;
    int64_t nThreadsPerCuwave = blockStructure.nThreadsPerCuwave;
    int64_t mCuwavesPerBlock = blockStructure.mCuwavesPerBlock;
    int64_t nCuwavesPerBlock = blockStructure.nCuwavesPerBlock;

    // Convert the K tile size from units of kpack to individual k elements.
    int64_t kPerBlock = kpacksPerBlock * kpack;

    bool useIndexDiffs = true;

    // NOTE(review): integer division; assumes M and N are multiples of the
    // per-block tile sizes -- confirm this is guaranteed by the caller.
    int64_t mBlocks = M / mPerBlock;
    int64_t nBlocks = N / nPerBlock;

    LLVM_DEBUG(llvm::dbgs() << "\ngridwise_gemm op:\n");
    LLVM_DEBUG(op.print(llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << "\n");

    LLVM_DEBUG(llvm::dbgs()
               << "M: " << M << "\n"
               << "N: " << N << "\n"
               << "K: " << K << "\n"
               << "G: " << G << "\n"
               << "blockSize: " << blockSize << "\n"
               << "mPerBlock: " << mPerBlock << "\n"
               << "mBlocks = M / mPerBlock: " << mBlocks << "\n"
               << "nPerBlock: " << nPerBlock << "\n"
               << "nBlocks = N / nPerBlock: " << nBlocks << "\n"
               << "kPerBlock: " << kPerBlock << "\n"
               << "kpack: " << kpack << "\n"
               << "mPerThread: " << mPerThread << "\n"
               << "nPerThread: " << nPerThread << "\n"
               << "mThreadsPerCuwave: " << mThreadsPerCuwave << "\n"
               << "mCuwavesPerBlock: " << mCuwavesPerBlock << "\n"
               << "nThreadsPerCuwave: " << nThreadsPerCuwave << "\n"
               << "nCuwavesPerBlock: " << nCuwavesPerBlock << "\n");

    // Compute required LDS sizes.
    int64_t ldsBlockASize =
        kpacksPerBlock * mPerBlock * kpack * getByteWidth(elementTypeA);
    int64_t ldsBlockBSize =
        kpacksPerBlock * nPerBlock * kpack * getByteWidth(elementTypeB);
    LLVM_DEBUG(llvm::dbgs() << "LDS block size (in bytes):" << ldsBlockASize
                            << " " << ldsBlockBSize << "\n");
    if (failed(checkLDSSize(op, ldsBlockASize, ldsBlockBSize)))
      return op.emitOpError("requires too much LDS");

    // Allocate LDS.
    auto workgroupMemoryAddressSpace = b.getAttr<gpu::AddressSpaceAttr>(
        gpu::GPUDialect::getWorkgroupAddressSpace());
    auto ldsMemRefAType =
        MemRefType::get({ldsBlockASize}, b.getI8Type(), AffineMap{},
                        workgroupMemoryAddressSpace);
    auto ldsByteBufferA = GpuAllocOp::create(b, loc, ldsMemRefAType);
    auto ldsMemRefBType =
        MemRefType::get({ldsBlockBSize}, b.getI8Type(), AffineMap{},
                        workgroupMemoryAddressSpace);
    auto ldsByteBufferB = GpuAllocOp::create(b, loc, ldsMemRefBType);

    // Alloc for Matrix C on registers.
    // Compute register size from attributes.

    int64_t gemmMRepeat =
        mPerBlock / (mPerThread * mThreadsPerCuwave * mCuwavesPerBlock);
    int64_t gemmNRepeat =
        nPerBlock / (nPerThread * nThreadsPerCuwave * nCuwavesPerBlock);

    LLVM_DEBUG(llvm::dbgs() << "GemmMRepeat: " << gemmMRepeat << "\n");
    LLVM_DEBUG(llvm::dbgs() << "GemmNRepeat: " << gemmNRepeat << "\n");

    int64_t threadCNumM = gemmMRepeat * mPerThread;
    int64_t threadCNumN = gemmNRepeat * nPerThread;
    int64_t threadCNumRegisters = threadCNumM * threadCNumN;
    auto privateMemoryAddressSpace = b.getAttr<gpu::AddressSpaceAttr>(
        gpu::GPUDialect::getPrivateAddressSpace());
    auto threadCRegisterMemRefType =
        MemRefType::get({threadCNumRegisters}, destType, AffineMap{},
                        privateMemoryAddressSpace);
    Value registerMatrixCAllocOp =
        GpuAllocOp::create(b, loc, threadCRegisterMemRefType);
    // (m, n) view of the flat accumulator, consumed by the blockwise gemm.
    Value registerMatrixCViewOp = reshapeBuffer(
        b, loc, registerMatrixCAllocOp, {"m", "n"}, {threadCNumM, threadCNumN});

    // Zero init Matrix C on registers.
    FillOp::create(b, loc, registerMatrixCAllocOp, zeroConstantFloatOp);

    // Get current workgroup ID.
    auto bid = WorkgroupIdOp::create(b, loc, b.getIndexType());
    // Get current workitem ID.
    auto tid = WorkitemIdOp::create(b, loc, b.getIndexType());

    // NOTE(review): this and the later failure returns happen after ops have
    // already been created through the rewriter; confirm the pattern driver
    // in use tolerates (rolls back) such partial rewrites.
    if (!isValidBlockSize(blockSize, kPerBlock, mPerBlock, nPerBlock)) {
      return emitError(loc) << "Block size too large, rejecting as invalid.\n";
    }

    // Number of elements of the A and B tiles each thread copies.
    int64_t aCopyPerThread = (kPerBlock * mPerBlock) / blockSize;
    int64_t bCopyPerThread = (kPerBlock * nPerBlock) / blockSize;

    FailureOr<VectorDimInfo> maybeVecDimInfoA =
        getVectorDim(b, loc, op.getA(), elementTypeA, blockSize, kPerBlock,
                     mPerBlock, kpack);
    if (failed(maybeVecDimInfoA)) {
      return failure();
    }
    FailureOr<VectorDimInfo> maybeVecDimInfoB =
        getVectorDim(b, loc, op.getB(), elementTypeB, blockSize, kPerBlock,
                     nPerBlock, kpack);
    if (failed(maybeVecDimInfoB)) {
      return failure();
    }
    LLVM_DEBUG(llvm::dbgs()
               << "aCopyPerThread: " << aCopyPerThread << "\n"
               << "bCopyPerThread: " << bCopyPerThread << "\n"
               << "aVectorDim: " << maybeVecDimInfoA->vectorDim << "\n"
               << "aVectorLen: " << maybeVecDimInfoA->vectorLen << "\n"
               << "bVectorDim: " << maybeVecDimInfoB->vectorDim << "\n"
               << "bVectorLen: " << maybeVecDimInfoB->vectorLen << "\n"
               << "vectorTiebreaker: " << maybeVecDimInfoA->vectorTiebreaker
               << "\n");
    SmallVector<int64_t, 3> bidGridLengths = {G, mBlocks, nBlocks};
    SmallVector<StringRef, 3> bidGridOrder = {"g_block", "m_block", "n_block"};
    // Views used to load the A and B tiles from global memory into registers.
    FailureOr<RegsAsMatrixSubTiles> maybeABufferViews = getLoadRegsAsTileViews(
        b, loc, op.getA(), "m", bidGridOrder, bidGridLengths, blockSize,
        kPerBlock, mPerBlock, maybeVecDimInfoA->inKPerThread,
        maybeVecDimInfoA->inDPerThread,
        maybeVecDimInfoA->vectorDim == GemmDimension::K);
    if (failed(maybeABufferViews)) {
      return failure();
    }
    Value wrappedA = transform(b, op.getA(), maybeABufferViews->gridSubTile);
    FailureOr<RegsAsMatrixSubTiles> maybeBBufferViews = getLoadRegsAsTileViews(
        b, loc, op.getB(), "n", bidGridOrder, bidGridLengths, blockSize,
        kPerBlock, nPerBlock, maybeVecDimInfoB->inKPerThread,
        maybeVecDimInfoB->inDPerThread,
        maybeVecDimInfoB->vectorDim == GemmDimension::K);
    if (failed(maybeBBufferViews)) {
      return failure();
    }
    Value wrappedB = transform(b, op.getB(), maybeBBufferViews->gridSubTile);

    // Allocates a flat buffer of `len` elements in the private address space.
    auto makeRegs = [&](int64_t len, Type elementType) -> GpuAllocOp {
      Type allocType = MemRefType::get({len}, elementType, AffineMap{},
                                       privateMemoryAddressSpace);
      return GpuAllocOp::create(b, loc, allocType);
    };
    GpuAllocOp loadBufferA = makeRegs(aCopyPerThread, elementTypeA);
    GpuAllocOp loadBufferB = makeRegs(bCopyPerThread, elementTypeB);

    // Compute grid coordinates
    FailureOr<mlir::StringAttr> maybeArch = getArch(op);
    if (failed(maybeArch)) {
      return op.emitError("arch needs to be set.");
    }
    auto gridCoords = layout::makeGroupedGridLayout(
        b, loc, bid,
        {G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType},
        maybeArch->getValue());

    // Separate buffers of the same type as the load buffers, holding the data
    // re-arranged for the LDS store.
    Value storeBufferA = GpuAllocOp::create(b, loc, loadBufferA.getType());
    Value storeBufferB = GpuAllocOp::create(b, loc, loadBufferB.getType());

    LDSLayoutConfigDim ldsLayoutConfigA =
        getLDSLayoutConfigDim(elementTypeA, kpack, maybeVecDimInfoA.value());
    LDSLayoutConfigDim ldsLayoutConfigB =
        getLDSLayoutConfigDim(elementTypeB, kpack, maybeVecDimInfoB.value());

    // We invert the transforms that are iter --> K x D slice of the tensor
    // so that we can view loadBuffer as a K x D tensor
    ArrayAttr loadBufferAViews =
        invertTransforms(b, loc, maybeABufferViews->threadSubTile);
    Value viewLoadBufferA = transform(b, loadBufferA, loadBufferAViews);
    // Prior to the LDS store, we need to re-arrange the register buffer to
    // maximize LDS vectorization. Hence, we create the view w.r.t. global that
    // corresponds to such a re-arranged register buffer.
    FailureOr<RegsAsMatrixSubTiles> maybeALdsStoreViews =
        getPackedRegsAsTileViews(
            b, loc, op.getA(), "m", bidGridOrder, bidGridLengths, blockSize,
            kPerBlock, mPerBlock, maybeVecDimInfoA->inKPerThread,
            maybeVecDimInfoA->inDPerThread, kpack,
            maybeVecDimInfoA->vectorDim == GemmDimension::K,
            ldsLayoutConfigA.doSwapThreadIterSubDims);
    if (failed(maybeALdsStoreViews)) {
      return failure();
    }
    ArrayAttr storeBufferAViews =
        invertTransforms(b, loc, maybeALdsStoreViews->threadSubTile);
    Value viewStoreBufferA = transform(b, storeBufferA, storeBufferAViews);
    ArrayAttr loadBufferBViews =
        invertTransforms(b, loc, maybeBBufferViews->threadSubTile);
    Value viewLoadBufferB = transform(b, loadBufferB, loadBufferBViews);
    // Prior to the LDS store, we need to re-arrange the register buffer to
    // maximize LDS vectorization. Hence, we create the view w.r.t. global that
    // corresponds to such a re-arranged register buffer.
    FailureOr<RegsAsMatrixSubTiles> maybeBLdsStoreViews =
        getPackedRegsAsTileViews(
            b, loc, op.getB(), "n", bidGridOrder, bidGridLengths, blockSize,
            kPerBlock, nPerBlock, maybeVecDimInfoB->inKPerThread,
            maybeVecDimInfoB->inDPerThread, kpack,
            maybeVecDimInfoB->vectorDim == GemmDimension::K,
            ldsLayoutConfigB.doSwapThreadIterSubDims);
    if (failed(maybeBLdsStoreViews)) {
      return failure();
    }
    ArrayAttr storeBufferBViews =
        invertTransforms(b, loc, maybeBLdsStoreViews->threadSubTile);
    Value viewStoreBufferB = transform(b, storeBufferB, storeBufferBViews);

    Type ldsReadTypeA = vectorTypeOrSelf(elementTypeA, kpack);
    FailureOr<Value> maybeWrappedLdsA = wrapLDSBufferForStore(
        b, loc, ldsByteBufferA, ldsReadTypeA, kpacksPerBlock, "m", mPerBlock,
        maybeVecDimInfoA->inKPerThread, maybeVecDimInfoA->inDPerThread,
        ldsLayoutConfigA.doRotateWithK);
    if (failed(maybeWrappedLdsA))
      return maybeWrappedLdsA;
    // This is KxD view of the flat LDS buffer
    Value wrappedLdsA = std::move(*maybeWrappedLdsA);
    // This will produce a (tid, iter) --> flat LDS view
    wrappedLdsA = transform(b, wrappedLdsA, maybeALdsStoreViews->blockSubTile);

    Type ldsReadTypeB = vectorTypeOrSelf(elementTypeB, kpack);
    FailureOr<Value> maybeWrappedLdsB = wrapLDSBufferForStore(
        b, loc, ldsByteBufferB, ldsReadTypeB, kpacksPerBlock, "n", nPerBlock,
        maybeVecDimInfoB->inKPerThread, maybeVecDimInfoB->inDPerThread,
        ldsLayoutConfigB.doRotateWithK);
    if (failed(maybeWrappedLdsB))
      return maybeWrappedLdsB;
    // This is KxD view of the flat LDS buffer
    Value wrappedLdsB = std::move(*maybeWrappedLdsB);
    // This will produce a (tid, iter) --> flat LDS view
    wrappedLdsB = transform(b, wrappedLdsB, maybeBLdsStoreViews->blockSubTile);

    // The blockwise gemm isn't set up for vector-of-kpack loads and so expects
    // a scalar kpacksPerBlock x dPerBlock x kpack x T buffer unconditionally.
    Value ldsMatrixA = viewBufferAs(b, ldsByteBufferA, elementTypeA);
    ldsMatrixA = reshapeBuffer(b, loc, ldsMatrixA, {"k", "m", "kpack"},
                               {kpacksPerBlock, mPerBlock, kpack});
    Value ldsMatrixB = viewBufferAs(b, ldsByteBufferB, elementTypeB);
    ldsMatrixB = reshapeBuffer(b, loc, ldsMatrixB, {"k", "n", "kpack"},
                               {kpacksPerBlock, nPerBlock, kpack});

    // Emit loop.
    // NOTE(review): K / kPerBlock truncates; assumes K is a multiple of
    // kPerBlock -- confirm this is guaranteed before this pattern runs.
    Value nIterations = ConstantIndexOp::create(b, loc, K / kPerBlock);
    Value step = ConstantIndexOp::create(b, loc, 1);
    BlockwiseGemmOp blockwiseGemmOp;

    auto loopOp = scf::ForOp::create(b, loc, zeroConstantOp, nIterations, step);
    // Tag the loop with a PipelineAttr of 2 (presumably consumed by a later
    // pipelining pass -- confirm).
    loopOp->setAttr(PipelineAttr::getMnemonic(),
                    rock::PipelineAttr::get(b.getContext(), 2));
    {
      // inside the loop.
      PatternRewriter::InsertionGuard guard(b);
      b.setInsertionPointToStart(loopOp.getBody());

      Value iv = loopOp.getInductionVar();

      // Stage 0: read this iteration's A and B tiles into the load buffers.
      auto stage0 = StageOp::create(b, loc, "GlobalRead");
      {
        PatternRewriter::InsertionGuard guard(b);
        b.setInsertionPointToStart(&stage0.getRegion().emplaceBlock());

        ThreadwiseReadIntoOp::create(
            b, loc, vectorOfBoolShapedLike(loadBufferA), wrappedA, loadBufferA,
            /*dynamicValidities=*/ValueRange{},
            /*extraViews=*/b.getArrayAttr({}),
            /*extraIndices=*/
            ValueRange{/*kIter=*/iv, gridCoords.g_block, gridCoords.m_block,
                       gridCoords.n_block, tid},
            true, true);
        ThreadwiseReadIntoOp::create(
            b, loc, vectorOfBoolShapedLike(loadBufferB), wrappedB, loadBufferB,
            /*dynamicValidities=*/ValueRange{},
            /*extraViews=*/b.getArrayAttr({}),
            /*extraIndices=*/
            ValueRange{/*kIter=*/iv, gridCoords.g_block, gridCoords.m_block,
                       gridCoords.n_block, tid},
            true, true);
        rock::YieldOp::create(b, loc);
      }

      // Stage 1: re-arrange the loaded registers into the store buffers and
      // write them to LDS.
      auto stage1 = StageOp::create(b, loc, "LDSWrite");
      {
        PatternRewriter::InsertionGuard guard(b);
        b.setInsertionPointToStart(&stage1.getRegion().emplaceBlock());

        ThreadwiseCopyOp::create(b, loc, viewLoadBufferA, ValueRange{},
                                 viewStoreBufferA, ValueRange{}, useIndexDiffs,
                                 true);
        ThreadwiseCopyOp::create(b, loc, viewLoadBufferB, ValueRange{},
                                 viewStoreBufferB, ValueRange{}, useIndexDiffs,
                                 true);

        ThreadwiseWriteAllOp::create(b, loc, storeBufferA, wrappedLdsA,
                                     /*extraViews=*/b.getArrayAttr({}),
                                     /*extraIndices=*/ValueRange{tid},
                                     StoreMethod::Set,
                                     /*forceUnroll=*/true,
                                     /*useIndexDiffs=*/true);
        ThreadwiseWriteAllOp::create(b, loc, storeBufferB, wrappedLdsB,
                                     /*extraViews=*/b.getArrayAttr({}),
                                     /*extraIndices=*/ValueRange{tid},
                                     StoreMethod::Set,
                                     /*forceUnroll=*/true,
                                     /*useIndexDiffs=*/true);

        rock::YieldOp::create(b, loc);
      }

      // Stage 2: multiply the LDS tiles, accumulating into the C registers.
      auto stage2 = StageOp::create(b, loc, "MMA");
      {
        PatternRewriter::InsertionGuard guard(b);
        b.setInsertionPointToStart(&stage2.getRegion().emplaceBlock());

        // Emit blockwise GEMM.
        blockwiseGemmOp = BlockwiseGemmOp::create(
            b, loc, ldsMatrixA, ldsMatrixB, registerMatrixCViewOp,
            b.getI32IntegerAttr(maybeVecDimInfoA->inDPerThread),
            b.getI32IntegerAttr(maybeVecDimInfoB->inDPerThread),
            ldsLayoutConfigA.doRotateWithK ? b.getUnitAttr() : nullptr,
            ldsLayoutConfigB.doRotateWithK ? b.getUnitAttr() : nullptr,
            op.getParamsAttr());
        rock::YieldOp::create(b, loc);
      }
    }

    // the LDS allocated to load A and B matrices won't be used anymore
    GpuDeallocOp::create(b, loc, ldsByteBufferA);
    GpuDeallocOp::create(b, loc, ldsByteBufferB);

    SmallVector<Attribute> transformAttrs;

    // Threadwise copy from register (naive tensor) to global (generic tensor).
    // NOTE(review): the "g_block" length below is gridSize, whereas
    // bidGridLengths uses G for the same coordinate -- confirm this is
    // intended.
    TopDownTMBuilder splitMemoryCoords(
        b, {"g_block", "m_block", "n_block", "tid", "iter"},
        {gridSize, mBlocks, nBlocks, blockSize, threadCNumRegisters}, loc);
    splitMemoryCoords.passThrough({"g_block", "m_block", "n_block"});
    // Split tid into its cuwave coordinates and iter into the per-thread
    // repeat/thread coordinates.
    splitMemoryCoords.merge({"m_cuwaves", "n_cuwaves", "m_cuwave", "n_cuwave"},
                            {3, 4, 5, 6}, "tid",
                            {mCuwavesPerBlock, nCuwavesPerBlock,
                             mThreadsPerCuwave, nThreadsPerCuwave});
    splitMemoryCoords.merge({"m_repeat", "m_thread", "n_repeat", "n_thread"},
                            {7, 8, 9, 10}, "iter",
                            {gemmMRepeat, mPerThread, gemmNRepeat, nPerThread});
    TransformMapAttr splitMemoryCoordsAttr = splitMemoryCoords.get();
    transformAttrs.push_back(splitMemoryCoordsAttr);

    // Recombine the split coordinates into the M and N positions within the
    // block.
    auto toMatrixC =
        TopDownTMBuilder::below(splitMemoryCoords, splitMemoryCoordsAttr);
    toMatrixC.passThrough({"g_block", "m_block", "n_block"});
    toMatrixC.unmerge(
        "gemmBlockM", 3, {"m_repeat", "m_cuwaves", "m_cuwave", "m_thread"},
        {gemmMRepeat, mCuwavesPerBlock, mThreadsPerCuwave, mPerThread});
    toMatrixC.unmerge(
        "gemmBlockN", 4, {"n_repeat", "n_cuwaves", "n_cuwave", "n_thread"},
        {gemmNRepeat, nCuwavesPerBlock, nThreadsPerCuwave, nPerThread});

    swapThreadIdAndIteration(
        toMatrixC, /*mBlocks=*/bidGridLengths[1],
        /*nBlocks=*/bidGridLengths[2], maybeVecDimInfoA->inDPerThread,
        maybeVecDimInfoB->inDPerThread, mPerBlock, nPerBlock,
        ldsLayoutConfigA.doSwapThreadIterSubDims,
        ldsLayoutConfigB.doSwapThreadIterSubDims,
        /*isBlockwise=*/false, transformAttrs);

    // Store the flat accumulator registers to C through the maps built above,
    // using the store method requested on the original op.
    Value registerC = registerMatrixCAllocOp;
    ArrayAttr idToMatrixCMaps = b.getArrayAttr(transformAttrs);
    ThreadwiseWriteAllOp::create(b, loc, registerC, op.getC(), idToMatrixCMaps,
                                 /*extraIndices=*/
                                 ValueRange{gridCoords.g_block,
                                            gridCoords.m_block,
                                            gridCoords.n_block, tid},
                                 op.getStoreMethod(),
                                 /*forceUnroll=*/true, useIndexDiffs);
    b.eraseOp(op);

    return success();
  }
};

//===----------------------------------------------------------------------===//
// GridwiseAttentionAccel lowering.
//===----------------------------------------------------------------------===//
// Tag type that names the arith multiplication ops: `Int` is the integer
// multiply (arith::MulIOp) and `Float` the floating-point multiply
// (arith::MulFOp).
struct ElementwiseMultOp {
  using Int = arith::MulIOp;
  using Float = arith::MulFOp;
};

// Tag type that names the arith addition ops: `Int` is the integer add
// (arith::AddIOp) and `Float` the floating-point add (arith::AddFOp).
struct ElementwiseAddOp {
  using Int = arith::AddIOp;
  using Float = arith::AddFOp;
};

struct GridwiseAttentionAccelRewritePattern
    : public OpRewritePattern<GridwiseAttentionAccelOp> {
  using OpRewritePattern<GridwiseAttentionAccelOp>::OpRewritePattern;

  // Repacks a gemm input tile held in registers and writes it out to LDS.
  //
  // Steps, in emission order:
  //  1. `regBuffer` is copied into `storeBuffer`, which is addressed through
  //     the inverse of `toLDSViews.threadSubTile`.
  //  2. `ldsTileByteBuffer` is wrapped by wrapLDSBufferForStore() using the
  //     type returned by vectorTypeOrSelf(<register element type>, kpack),
  //     then viewed through `toLDSViews.blockSubTile`.
  //  3. If `barrierBeforeWrite` is set, an LDS barrier is emitted so that the
  //     write does not start before loads of the previous iteration are done.
  //  4. `storeBuffer` is written to that LDS view with the workitem id passed
  //     as the extra index.
  //
  // Returns failure() when wrapLDSBufferForStore() fails. `kPerBlock` is part
  // of the signature but is not read here.
  LogicalResult storeGemmInputTile(
      PatternRewriter &rewriter, Location loc, int64_t kpack, Value regBuffer,
      RegsAsMatrixSubTiles toLDSViews, Value storeBuffer,
      Value ldsTileByteBuffer, int64_t kpacksPerBlock, StringRef nonKDimName,
      int64_t kPerBlock, int64_t dPerBlock, int64_t copyKPerThread,
      int64_t copyDPerThread, bool forceUnroll, bool rotateDWithK,
      bool barrierBeforeWrite) const {
    auto regBufferType = cast<MemRefType>(regBuffer.getType());

    // Register-to-register repack. This is fine for software pipelining since
    // it can be considered "compute"; a single reg->reg op could eventually
    // replace it to keep the IR at this level less verbose.
    Value repackTarget =
        transform(rewriter, storeBuffer,
                  invertTransforms(rewriter, loc, toLDSViews.threadSubTile));
    ThreadwiseCopyOp::create(rewriter, loc, regBuffer, ValueRange{},
                             repackTarget, ValueRange{}, false, false);

    // K x D view of the flat LDS byte buffer.
    Type ldsAccessType =
        vectorTypeOrSelf(regBufferType.getElementType(), kpack);
    FailureOr<Value> kByDLds = wrapLDSBufferForStore(
        rewriter, loc, ldsTileByteBuffer, ldsAccessType, kpacksPerBlock,
        nonKDimName, dPerBlock, copyKPerThread, copyDPerThread, rotateDWithK);
    if (failed(kByDLds))
      return failure();

    // (tid, iter) --> flat LDS view.
    Value ldsByThreadIter =
        transform(rewriter, *kByDLds, toLDSViews.blockSubTile);
    auto workitemId =
        WorkitemIdOp::create(rewriter, loc, rewriter.getIndexType());

    // Keep this write from overtaking the loads of the previous iteration.
    if (barrierBeforeWrite)
      LDSBarrierOp::create(rewriter, loc);

    ThreadwiseWriteAllOp::create(rewriter, loc, storeBuffer, ldsByThreadIter,
                                 /*extraViews=*/rewriter.getArrayAttr({}),
                                 /*extraIndices=*/ValueRange{workitemId},
                                 StoreMethod::Set, forceUnroll, true);
    return success();
  }

  // Loads one tile of a gemm input from `in` into registers and then moves it
  // into `destBuffer`, which must live either in LDS (workgroup address space)
  // or in registers (private address space).
  //
  //  - LDS destination: the tile is read into `fromGlobalRegBuffer` and handed
  //    to storeGemmInputTile() together with `toLDSRegBuffer`, so that it ends
  //    up in LDS in a way it could be fed to a blockwise_gemm_accel op.
  //  - Private destination: the tile is read into `fromGlobalRegBuffer` and
  //    then transferred to `destBuffer` inside a loop over the D repeats
  //    reported by `accelEmitter`.
  //
  // `kIter`, the block coordinates in `gridCoords` and the workitem id are
  // passed as the extra indices of the global read. `gridSize` is part of the
  // signature but is not read here.
  //
  // Returns failure() if the destination address space is unsupported, if no
  // element would be copied per thread, or if any of the helper calls that
  // build the views fails.
  LogicalResult loadAndStoreGemmInputTile(
      Location loc, Value in, Value kIter, Type elemType,
      rock::layout::GridCoordinates gridCoords, Value fromGlobalRegBuffer,
      Value toLDSRegBuffer, Value destBuffer, StringRef nonKDimName,
      int64_t kpack, int64_t kpacksPerBlock, int64_t dPerBlock,
      uint32_t blockSize, uint32_t gridSize, ArrayRef<StringRef> bidGridOrder,
      ArrayRef<int64_t> bidGridLengths, bool forceUnroll,
      PatternRewriter &rewriter, const accel::AccelEmitter &accelEmitter,
      LDSLayoutConfigDim ldsLayoutCfg, bool barrierBeforeWrite) const {

    // Only workgroup (LDS) and private (register) destinations are handled.
    MemRefType destBufferType = cast<MemRefType>(destBuffer.getType());
    mlir::gpu::AddressSpace destBufferAddrSpace =
        cast<gpu::AddressSpaceAttr>(destBufferType.getMemorySpace()).getValue();
    bool isDestBufferLDS = destBufferAddrSpace == gpu::AddressSpace::Workgroup;
    if (!isDestBufferLDS && destBufferAddrSpace != gpu::AddressSpace::Private) {
      return emitError(loc) << "the destination buffer to load global input "
                               "tile should either be LDS or Regs.\n";
    }

    // Per-block K extent and the number of elements each thread copies.
    int64_t kPerBlock = kpacksPerBlock * kpack;
    int64_t copyPerThread = (kPerBlock * dPerBlock) / blockSize;
    // The K length is read from dim 1 of `in`.
    int64_t kGlobal = cast<MemRefType>(in.getType()).getShape()[1];
    int64_t kIters = kGlobal / kPerBlock;
    if (copyPerThread == 0) {
      return emitError(loc) << "Block size too large, rejecting as invalid.\n";
    }
    FailureOr<VectorDimInfo> maybeVectorDimInfo = getVectorDim(
        rewriter, loc, in, elemType, blockSize, kPerBlock, dPerBlock, kpack);
    if (failed(maybeVectorDimInfo)) {
      return failure();
    }
    GemmDimension vectorDim = maybeVectorDimInfo->vectorDim;
    // The views used for the global read depend on the destination: the accel
    // emitter supplies them for a register destination, while
    // getLoadRegsAsTileViews() supplies them for an LDS destination.
    FailureOr<RegsAsMatrixSubTiles> maybeInBufferViews;
    if (!isDestBufferLDS) {
      maybeInBufferViews = accelEmitter.createAccelGemmOperandTransforms(
          rewriter, loc, kIters, bidGridLengths, blockSize,
          maybeVectorDimInfo->inDPerThread, nonKDimName,
          vectorDim == GemmDimension::K, false);
    } else {
      maybeInBufferViews = getLoadRegsAsTileViews(
          rewriter, loc, in, nonKDimName, bidGridOrder, bidGridLengths,
          blockSize, kPerBlock, dPerBlock, maybeVectorDimInfo->inKPerThread,
          maybeVectorDimInfo->inDPerThread, vectorDim == GemmDimension::K);
    }
    if (failed(maybeInBufferViews)) {
      return failure();
    }
    // Global read: `in`, viewed through the grid-level sub-tile maps, is read
    // into `fromGlobalRegBuffer`.
    Value viewIn = transform(rewriter, in, maybeInBufferViews->gridSubTile);
    auto tid = WorkitemIdOp::create(rewriter, loc, rewriter.getIndexType());
    ThreadwiseReadIntoOp::create(
        rewriter, loc, vectorOfBoolShapedLike(fromGlobalRegBuffer), viewIn,
        fromGlobalRegBuffer,
        /*dynamicValidities=*/ValueRange{},
        /*extraViews=*/rewriter.getArrayAttr({}),
        ValueRange{kIter, gridCoords.g_block, gridCoords.m_block,
                   gridCoords.n_block, tid},
        forceUnroll, true);
    if (isDestBufferLDS) {
      // threadwiseView is iter --> K,D
      // Hence we invert to create the reg buffer to be viewed
      // as K x D memref
      ArrayAttr loadBufferViews =
          invertTransforms(rewriter, loc, maybeInBufferViews->threadSubTile);
      Value viewLoadBuffer =
          transform(rewriter, fromGlobalRegBuffer, loadBufferViews);

      // Views describing how the packed registers map onto the LDS tile.
      FailureOr<RegsAsMatrixSubTiles> maybeLdsStoreViews =
          getPackedRegsAsTileViews(rewriter, loc, in, nonKDimName, bidGridOrder,
                                   bidGridLengths, blockSize, kPerBlock,
                                   dPerBlock, maybeVectorDimInfo->inKPerThread,
                                   maybeVectorDimInfo->inDPerThread, kpack,
                                   vectorDim == GemmDimension::K,
                                   ldsLayoutCfg.doSwapThreadIterSubDims);
      if (failed(maybeLdsStoreViews)) {
        return failure();
      }

      // Repack into `toLDSRegBuffer` and write to `destBuffer` (LDS).
      LogicalResult storeGemmTileStatus = storeGemmInputTile(
          rewriter, loc, kpack, viewLoadBuffer, maybeLdsStoreViews.value(),
          toLDSRegBuffer, destBuffer, kpacksPerBlock, nonKDimName, kPerBlock,
          dPerBlock, maybeVectorDimInfo->inKPerThread,
          maybeVectorDimInfo->inDPerThread, forceUnroll,
          ldsLayoutCfg.doRotateWithK, barrierBeforeWrite);
      if (failed(storeGemmTileStatus)) {
        return failure();
      }
    } else {
      // Register destination: the LDS-only options must be disabled.
      assert(!ldsLayoutCfg.doSwapThreadIterSubDims &&
             "doSwapThreadIterSubDims must be false if the destination buffer "
             "is private memory");
      assert(!barrierBeforeWrite &&
             "can't add a LDS barrier if the destination buffer is not LDS");
      accel::AccelEmitterParams accelEmitterParams = accelEmitter.getParams();
      // The repeat count is taken from the M or N side depending on which
      // non-K dimension this operand carries.
      int64_t dRepeats = (nonKDimName == "m" ? accelEmitterParams.mRepeats
                                             : accelEmitterParams.nRepeats);
      affine::AffineForOp dRepeatsLoop =
          affine::AffineForOp::create(rewriter, loc, 0, dRepeats, 1);
      {
        // Everything below is emitted inside the loop body; the guard restores
        // the insertion point afterwards.
        PatternRewriter::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(dRepeatsLoop.getBody());
        Value di = dRepeatsLoop.getInductionVar();
        // With more than one repeat, each iteration targets the slice of
        // `destBuffer` selected by the induction variable.
        Value subview = destBuffer;
        if (dRepeats > 1) {
          subview = createSliceOfFirstDim(rewriter, loc, destBuffer, di);
        }
        // InBufferViews provide --> K x D subtile views.
        // Since we are iterating on D dimension, we need to transpose it.
        RegsAsMatrixSubTiles inBufferViewsTr =
            transposeSubTileViews(rewriter, loc, maybeInBufferViews.value());
        Value viewLoadedBuffer = transform(
            rewriter, fromGlobalRegBuffer,
            invertTransforms(rewriter, loc, inBufferViewsTr.threadSubTile));
        // NOTE(review): presumably reads `viewLoadedBuffer` at index `di` into
        // `subview` -- confirm against the ThreadwiseReadIntoOp builder.
        ThreadwiseReadIntoOp::create(rewriter, loc, viewLoadedBuffer, subview,
                                     rewriter.getArrayAttr({}), ValueRange{di},
                                     true, true);
      }
    }
    return success();
  }

  // Allocates a one-dimensional i8 buffer in the workgroup address space that
  // is `numElements * getByteWidth(elemType)` bytes long.
  Value createLDSByteBuffer(PatternRewriter &rewriter, Location loc,
                            int64_t numElements, Type elemType) const {
    const int64_t sizeInBytes = numElements * getByteWidth(elemType);
    auto workgroupSpace = rewriter.getAttr<gpu::AddressSpaceAttr>(
        gpu::GPUDialect::getWorkgroupAddressSpace());
    auto byteBufferType = MemRefType::get({sizeInBytes}, rewriter.getI8Type(),
                                          AffineMap{}, workgroupSpace);
    return GpuAllocOp::create(rewriter, loc, byteBufferType);
  }

  // Allocates the two per-thread register buffers used for one gemm input:
  // the buffer the global load lands in and the buffer that is stored to LDS.
  //
  // Both are one-dimensional private-address-space buffers of `elemType` with
  // `(kPerBlock * dPerBlock) / blockSize` elements.
  //
  // Returns {fromGlobalRegBuffer, toLDSRegBuffer}.
  std::tuple<Value, Value>
  createRegBuffersForGemmIn(Location loc, int64_t kPerBlock, int64_t blockSize,
                            Type elemType, int64_t dPerBlock,
                            PatternRewriter &rewriter) const {
    const int64_t elemsPerThread = (kPerBlock * dPerBlock) / blockSize;
    auto privateSpace = rewriter.getAttr<gpu::AddressSpaceAttr>(
        gpu::GPUDialect::getPrivateAddressSpace());
    Type regBufferType =
        MemRefType::get({elemsPerThread}, elemType, AffineMap{}, privateSpace);

    // Allocation order matters for the emitted IR: load buffer first.
    Value loadRegs = GpuAllocOp::create(rewriter, loc, regBufferType);
    Value storeRegs = GpuAllocOp::create(rewriter, loc, regBufferType);
    return std::make_tuple(loadRegs, storeRegs);
  }

  // Fills `accBuffer` with the zero constant of its element type.
  void zeroAccBuffer(PatternRewriter &rewriter, Location loc,
                     Value accBuffer) const {
    Type accElemType = cast<MemRefType>(accBuffer.getType()).getElementType();
    Value zeroValue = createZeroConstantOp(rewriter, loc, accElemType);
    FillOp::create(rewriter, loc, accBuffer, zeroValue);
  }

  // Allocates the accumulator register buffer for an accelerated gemm.
  //
  // The buffer lives in the private address space and holds
  // `params.nResultVectors * params.mRepeats * params.nRepeats` elements of
  // `params.accVectorType`. When `numBuffers` is greater than 1, the buffer
  // gets a leading dimension of size `numBuffers`.
  Value createBufferForAccelGemmOut(Location loc,
                                    rock::accel::AccelEmitterParams params,
                                    PatternRewriter &rewriter,
                                    int64_t numBuffers = 1) const {
    // Accumulate in int64_t so the product is formed at 64-bit width.
    int64_t vectorsPerBuffer = params.nResultVectors;
    vectorsPerBuffer *= params.mRepeats;
    vectorsPerBuffer *= params.nRepeats;

    SmallVector<int64_t, 2> accShape{vectorsPerBuffer};
    if (numBuffers > 1)
      accShape.insert(accShape.begin(), numBuffers);

    VectorType accElemType = params.accVectorType;
    auto privateSpace = rewriter.getAttr<gpu::AddressSpaceAttr>(
        gpu::GPUDialect::getPrivateAddressSpace());
    MemRefType accBufferType =
        MemRefType::get(accShape, accElemType, AffineMap{},
                        /*memorySpace=*/privateSpace);
    return rock::GpuAllocOp::create(rewriter, loc, accBufferType);
  }

  // Allocates a scalar (non-vector) register buffer for a gemm output.
  //
  // The buffer lives in the private address space and holds
  // `params.numOutputVectorElements()` elements of `gemmOutElemType`. When
  // `numBuffers` is greater than 1, the buffer gets a leading dimension of
  // size `numBuffers`.
  Value createBufferForGemmOut(Location loc, Type gemmOutElemType,
                               rock::accel::AccelEmitterParams params,
                               PatternRewriter &rewriter,
                               int64_t numBuffers = 1) const {
    const int64_t elemsPerBuffer = params.numOutputVectorElements();

    SmallVector<int64_t, 2> outShape;
    if (numBuffers > 1)
      outShape.push_back(numBuffers);
    outShape.push_back(elemsPerBuffer);

    auto privateSpace = rewriter.getAttr<gpu::AddressSpaceAttr>(
        gpu::GPUDialect::getPrivateAddressSpace());
    MemRefType outBufferType =
        MemRefType::get(outShape, gemmOutElemType, AffineMap{},
                        /*memorySpace=*/privateSpace);
    return rock::GpuAllocOp::create(rewriter, loc, outBufferType);
  }

  // Creates the interim register buffers that hold the A and B operands once
  // they are loaded from LDS and before the accelerator intrinsics are called.
  //
  // Each buffer lives in the private address space and holds
  // `params.kBasePerThread` elements of `params.argTypeA` (for A) or
  // `params.argTypeB` (for B). When `mRepeats` (resp. `nRepeats`) is greater
  // than 1, a leading dimension of that size is prepended to the A (resp. B)
  // buffer shape.
  //
  // Returns {A buffer, B buffer}.
  std::tuple<Value, Value> createRegInterrimBufferForAccel(
      Location loc, rock::accel::AccelEmitterParams params,
      PatternRewriter &rewriter, int64_t mRepeats = 1,
      int64_t nRepeats = 1) const {
    auto privateMemoryAddressSpace = rewriter.getAttr<gpu::AddressSpaceAttr>(
        gpu::GPUDialect::getPrivateAddressSpace());
    int64_t kBasePerThread = params.kBasePerThread;

    // A operand: [mRepeats x] kBasePerThread.
    Type argTypeA = params.argTypeA;
    SmallVector<int64_t, 2> aShape{kBasePerThread};
    if (mRepeats > 1) {
      aShape.insert(aShape.begin(), mRepeats);
    }
    auto arrayAType = MemRefType::get(aShape, argTypeA, AffineMap{},
                                      privateMemoryAddressSpace);
    auto arrayA = GpuAllocOp::create(rewriter, loc, arrayAType);

    // B operand: [nRepeats x] kBasePerThread.
    Type argTypeB = params.argTypeB;
    SmallVector<int64_t, 2> bShape{kBasePerThread};
    if (nRepeats > 1) {
      bShape.insert(bShape.begin(), nRepeats);
    }
    auto arrayBType = MemRefType::get(bShape, argTypeB, AffineMap{},
                                      privateMemoryAddressSpace);
    auto arrayB = GpuAllocOp::create(rewriter, loc, arrayBType);
    return {arrayA, arrayB};
  }

  // Computes exp2(gemm0 - rowmax_j) element-wise (base-2 exponential, emitted
  // as math::Exp2Op) and stores the result into the buffer behind
  // `gemm0OutExpThreadwiseView`.
  //
  // For each element, rowmax_j is the maximum of `maxRowBuffer` at the
  // element's dim-0 index and the corresponding element read through
  // `gemm0OutBufferMaxView`. `maxRowBuffer` is only read here, never written.
  // The loop bounds are the two dimensions of `gemm0OutThreadwiseView`.
  void expSubstractMaxFromGemm0(PatternRewriter &rewriter, Location loc,
                                Value gemm0OutThreadwiseView,
                                Value gemm0OutExpThreadwiseView,
                                Value gemm0OutBufferMaxView,
                                Value maxRowBuffer) const {
    // Peel each view down to its underlying buffer and transform stack.
    Value tileMaxBuf, expOutBuf, gemm0Buf;
    ArrayAttr tileMaxMaps, expOutMaps, gemm0Maps;
    std::tie(tileMaxBuf, tileMaxMaps, std::ignore) =
        untransform(rewriter, gemm0OutBufferMaxView);
    std::tie(expOutBuf, expOutMaps, std::ignore) =
        untransform(rewriter, gemm0OutExpThreadwiseView);
    std::tie(gemm0Buf, gemm0Maps, std::ignore) =
        untransform(rewriter, gemm0OutThreadwiseView);

    MemRefType gemm0ViewType =
        cast<MemRefType>(gemm0OutThreadwiseView.getType());
    const int64_t dim0Len = gemm0ViewType.getShape()[0];
    const int64_t dim1Len = gemm0ViewType.getShape()[1];

    Value zero = rewriter.createOrFold<ConstantIndexOp>(loc, 0);
    // Domain 0 has no transforms, so its "lower" coordinates are the loop's
    // upper coordinates; domains 1-3 address the three buffers.
    auto loop = TransformingForOp::create(
        rewriter, loc,
        ArrayRef<ValueRange>{
            {zero, zero}, {zero, zero}, {zero, zero}, {zero, zero}},
        ArrayRef<Attribute>{rewriter.getArrayAttr({}), tileMaxMaps, expOutMaps,
                            gemm0Maps},
        /*bounds=*/ArrayRef<int64_t>{dim0Len, dim1Len},
        /*strides=*/ArrayRef<int64_t>{1, 1},
        /*forceUnroll=*/true, /*useIndexDiffs=*/true);
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(loop.getBody());
      Block::BlockArgListType loopCoords = loop.getLowerCoords(0);
      Block::BlockArgListType tileMaxCoords = loop.getLowerCoords(1);
      Block::BlockArgListType expOutCoords = loop.getLowerCoords(2);
      Block::BlockArgListType gemm0Coords = loop.getLowerCoords(3);

      // rowmax_j = max(running row max, max of the current tile)
      Type rowMaxElemType = getElementTypeOrSelf(maxRowBuffer.getType());
      Value runningMax = InBoundsLoadOp::create(
          rewriter, loc, rowMaxElemType, maxRowBuffer,
          ValueRange{loopCoords[0]});
      Value tileMax = InBoundsLoadOp::create(rewriter, loc, rowMaxElemType,
                                             tileMaxBuf, tileMaxCoords);
      Value rowMax =
          arith::MaximumFOp::create(rewriter, loc, runningMax, tileMax);

      // exp2(gemm0 - rowmax_j)
      Type gemm0ElemType = getElementTypeOrSelf(gemm0Buf.getType());
      Value gemm0Val = InBoundsLoadOp::create(rewriter, loc, gemm0ElemType,
                                              gemm0Buf, gemm0Coords);
      Value shifted = arith::SubFOp::create(rewriter, loc, gemm0Val, rowMax);
      Value shiftedExp = math::Exp2Op::create(rewriter, loc, shifted);

      // Write the result to the exp output buffer.
      InBoundsStoreOp::create(rewriter, loc, shiftedExp, expOutBuf,
                              expOutCoords);
    }
  }

  // Updates the running row sum and row max after a new gemm0 tile:
  //   m_j = max(m_{j-1}, tile row max)
  //   l_j = exp2(m_{j-1} - m_j) * l_{j-1} + rowsum(P_ij)
  // where
  //   l is the rowsum accumulator (`sumRowBuffer`),
  //   m is the rowmax accumulator (`maxRowBuffer`),
  //   the tile row max is read through `gemm0OutBufferMaxView`, and
  //   rowsum(P_ij) is read through `gemm0OutBufferSumView`; it is added as
  //   is, with no further scaling in this function.
  //
  // Besides writing l_j and m_j back to their buffers, the rescale factor
  // exp2(m_{j-1} - m_j) is stored into `expMaxDiffRowBuffer`. The loop runs
  // over dim 0 of `gemm0OutBufferSumView` with a second dimension of size 1.
  void updateRowSum(PatternRewriter &rewriter, Location loc,
                    Value gemm0OutBufferSumView, Value gemm0OutBufferMaxView,
                    Value sumRowBuffer, Value maxRowBuffer,
                    Value expMaxDiffRowBuffer) const {
    Value tileSumBuf, tileMaxBuf;
    ArrayAttr tileSumMaps, tileMaxMaps;
    std::tie(tileMaxBuf, tileMaxMaps, std::ignore) =
        untransform(rewriter, gemm0OutBufferMaxView);
    std::tie(tileSumBuf, tileSumMaps, std::ignore) =
        untransform(rewriter, gemm0OutBufferSumView);

    MemRefType sumViewType = cast<MemRefType>(gemm0OutBufferSumView.getType());
    const int64_t dim0Len = sumViewType.getShape()[0];

    Value zero = rewriter.createOrFold<ConstantIndexOp>(loc, 0);
    auto loop = TransformingForOp::create(
        rewriter, loc,
        ArrayRef<ValueRange>{{zero, zero}, {zero, zero}, {zero, zero}},
        ArrayRef<Attribute>{rewriter.getArrayAttr({}), tileSumMaps,
                            tileMaxMaps},
        /*bounds=*/ArrayRef<int64_t>{dim0Len, 1},
        /*strides=*/ArrayRef<int64_t>{1, 1},
        /*forceUnroll=*/true, /*useIndexDiffs=*/true);
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(loop.getBody());
      Block::BlockArgListType loopCoords = loop.getLowerCoords(0);
      Block::BlockArgListType tileSumCoords = loop.getLowerCoords(1);
      Block::BlockArgListType tileMaxCoords = loop.getLowerCoords(2);
      // The per-row accumulators are indexed by the loop's dim-0 coordinate.
      Value rowIdx = loopCoords[0];

      // l_{j-1} and rowsum(P_ij)
      Type sumElemType = getElementTypeOrSelf(sumRowBuffer.getType());
      Value oldSum = InBoundsLoadOp::create(rewriter, loc, sumElemType,
                                            sumRowBuffer, ValueRange{rowIdx});
      Value tileSum = InBoundsLoadOp::create(rewriter, loc, sumElemType,
                                             tileSumBuf, tileSumCoords);

      // m_{j-1}, the tile max, and m_j
      Type maxElemType = getElementTypeOrSelf(maxRowBuffer.getType());
      Value oldMax = InBoundsLoadOp::create(rewriter, loc, maxElemType,
                                            maxRowBuffer, ValueRange{rowIdx});
      Value tileMax = InBoundsLoadOp::create(rewriter, loc, maxElemType,
                                             tileMaxBuf, tileMaxCoords);
      Value newMax = arith::MaximumFOp::create(rewriter, loc, oldMax, tileMax);

      // Rescale factor exp2(m_{j-1} - m_j); also kept for later use.
      Value maxDelta = arith::SubFOp::create(rewriter, loc, oldMax, newMax);
      Value rescale = math::Exp2Op::create(rewriter, loc, maxDelta);
      InBoundsStoreOp::create(rewriter, loc, rescale, expMaxDiffRowBuffer,
                              ValueRange{rowIdx});

      // l_j = rescale * l_{j-1} + rowsum(P_ij)
      Value rescaledSum = arith::MulFOp::create(rewriter, loc, rescale, oldSum);
      Value newSum =
          arith::AddFOp::create(rewriter, loc, rescaledSum, tileSum);

      // Commit the new accumulator values.
      InBoundsStoreOp::create(rewriter, loc, newSum, sumRowBuffer,
                              ValueRange{rowIdx});
      InBoundsStoreOp::create(rewriter, loc, newMax, maxRowBuffer,
                              ValueRange{rowIdx});
    }
  }

  // Computes the LSE (log-sum-exp) of the gemm0 output rows.
  //
  // This runs at the end of the kernel, so `maxRowBuffer` (m) and
  // `sumRowBuffer` (l) hold final values rather than running ones. With
  //   x = input / log(2)   (scaled so that exp2() can be used)
  //   m = max x
  //   l = sum exp2(x - m)
  // the natural-log result is
  //   log(sum e^input) = log(l * exp2(m)) = (log2(l) + m) * log(2)
  // "m" goes through exp2() because that is the exponential used for "l".
  //
  // m and l are converted to the element type of `lseBufferView` before the
  // arithmetic, and the result is stored through `lseBufferView` for every
  // coordinate of its two dimensions.
  void computeLse(PatternRewriter &rewriter, Location loc, Value lseBufferView,
                  Value sumRowBuffer, Value maxRowBuffer) const {
    MemRefType rowStateType = cast<MemRefType>(sumRowBuffer.getType());
    assert(maxRowBuffer.getType() == sumRowBuffer.getType());
    Type rowStateElemType = rowStateType.getElementType();

    Value lseStorage;
    ArrayAttr lseMaps;
    std::tie(lseStorage, lseMaps, std::ignore) =
        untransform(rewriter, lseBufferView);

    MemRefType lseViewType = cast<MemRefType>(lseBufferView.getType());
    Type lseElemType = lseViewType.getElementType();
    const int64_t dim0Len = lseViewType.getShape()[0];
    const int64_t dim1Len = lseViewType.getShape()[1];

    Value zero = rewriter.createOrFold<ConstantIndexOp>(loc, 0);
    // ln(2); types narrower than 32 bits are allowed an inexact conversion.
    const auto expectedStatus = lseElemType.getIntOrFloatBitWidth() >= 32
                                    ? APFloat::opOK
                                    : APFloat::opInexact;
    Value ln2 = createConstantFloatOp(rewriter, loc, lseElemType, lseElemType,
                                      0.69314718f, expectedStatus);

    auto loop = TransformingForOp::create(
        rewriter, loc, ArrayRef<ValueRange>{{zero, zero}, {zero, zero}},
        ArrayRef<Attribute>{rewriter.getArrayAttr({}), lseMaps},
        /*bounds=*/ArrayRef<int64_t>{dim0Len, dim1Len},
        /*strides=*/ArrayRef<int64_t>{1, 1},
        /*forceUnroll=*/true, /*useIndexDiffs=*/true);
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(loop.getBody());
      // Domain 0 has an empty transform, so lower == upper coordinates.
      Block::BlockArgListType loopCoords = loop.getLowerCoords(0);
      Block::BlockArgListType lseCoords = loop.getLowerCoords(1);
      Value rowIdx = loopCoords[0];

      Value rowMax = InBoundsLoadOp::create(rewriter, loc, rowStateElemType,
                                            maxRowBuffer, ValueRange{rowIdx});
      Value rowSum = InBoundsLoadOp::create(rewriter, loc, rowStateElemType,
                                            sumRowBuffer, ValueRange{rowIdx});

      // Bring both operands to the LSE element type.
      rowMax = createTypeConversionOp(rewriter, loc, rowMax, lseElemType);
      rowSum = createTypeConversionOp(rewriter, loc, rowSum, lseElemType);

      // lse_i = (log2(l_i) + m_i) * log(2)
      // Migraphx expects LSE to be a natural log.
      Value log2Sum = math::Log2Op::create(rewriter, loc, rowSum);
      Value lseBase2 = arith::AddFOp::create(rewriter, loc, log2Sum, rowMax);
      Value lse = arith::MulFOp::create(rewriter, loc, lseBase2, ln2);
      InBoundsStoreOp::create(rewriter, loc, lse, lseStorage, lseCoords);
    }
  }

  // Out-of-loop scaling of the attention output: every element read through
  // `attentionOutAccBufferView` is divided by the accumulated row sum found in
  // `sumRowBuffer` at the element's dim-0 index, and written back in place.
  void scaleFinalOutput(PatternRewriter &rewriter, Location loc,
                        Value attentionOutAccBufferView,
                        Value sumRowBuffer) const {
    Value accStorage;
    ArrayAttr accMaps;
    std::tie(accStorage, accMaps, std::ignore) =
        untransform(rewriter, attentionOutAccBufferView);

    MemRefType accViewType =
        cast<MemRefType>(attentionOutAccBufferView.getType());
    Type accElemType = accViewType.getElementType();
    const int64_t dim0Len = accViewType.getShape()[0];
    const int64_t dim1Len = accViewType.getShape()[1];

    Value zero = rewriter.createOrFold<ConstantIndexOp>(loc, 0);
    auto loop = TransformingForOp::create(
        rewriter, loc, ArrayRef<ValueRange>{{zero, zero}, {zero, zero}},
        ArrayRef<Attribute>{rewriter.getArrayAttr({}), accMaps},
        /*bounds=*/ArrayRef<int64_t>{dim0Len, dim1Len},
        /*strides=*/ArrayRef<int64_t>{1, 1},
        /*forceUnroll=*/true, /*useIndexDiffs=*/true);
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(loop.getBody());
      Block::BlockArgListType loopCoords = loop.getLowerCoords(0);
      Block::BlockArgListType accCoords = loop.getLowerCoords(1);
      Type sumElemType = getElementTypeOrSelf(sumRowBuffer.getType());

      Value accVal = InBoundsLoadOp::create(rewriter, loc, accElemType,
                                            accStorage, accCoords);
      Value rowSum =
          InBoundsLoadOp::create(rewriter, loc, sumElemType, sumRowBuffer,
                                 ValueRange{loopCoords[0]});
      Value normalized = arith::DivFOp::create(rewriter, loc, accVal, rowSum);
      InBoundsStoreOp::create(rewriter, loc, normalized, accStorage,
                              accCoords);
    }
  }

  // Applies the row-state correction used by tiled (flash) attention when a
  // new gemm1 tile is accumulated; see https://arxiv.org/pdf/2205.14135.pdf
  //
  // For every coordinate of `attentionOutAccBufferView` this emits:
  //
  //   acc = acc * expMaxDiff[dim-0 index] + gemm1Out
  //
  // where `acc` is read and written through `attentionOutAccBufferView`,
  // `gemm1Out` is read through `gemm1OutThreadwiseView` and `expMaxDiff` comes
  // from `expMaxDiffRowBuffer`. The rescale factor is not computed here; it is
  // only loaded from `expMaxDiffRowBuffer`.
  //
  // Both views are iterated with the bounds of `attentionOutAccBufferView`,
  // and `gemm1Out` is loaded with the accumulator's element type.
  void createAttentionRowStateCorrections(PatternRewriter &rewriter,
                                          Location loc,
                                          Value gemm1OutThreadwiseView,
                                          Value attentionOutAccBufferView,
                                          Value expMaxDiffRowBuffer) const {
    Value gemm1Storage, accStorage;
    ArrayAttr gemm1Maps, accMaps;
    std::tie(gemm1Storage, gemm1Maps, std::ignore) =
        untransform(rewriter, gemm1OutThreadwiseView);
    std::tie(accStorage, accMaps, std::ignore) =
        untransform(rewriter, attentionOutAccBufferView);

    MemRefType accViewType =
        cast<MemRefType>(attentionOutAccBufferView.getType());
    Type accElemType = accViewType.getElementType();
    const int64_t dim0Len = accViewType.getShape()[0];
    const int64_t dim1Len = accViewType.getShape()[1];

    Value zero = rewriter.createOrFold<ConstantIndexOp>(loc, 0);

    auto loop = TransformingForOp::create(
        rewriter, loc,
        ArrayRef<ValueRange>{{zero, zero}, {zero, zero}, {zero, zero}},
        ArrayRef<Attribute>{rewriter.getArrayAttr({}), gemm1Maps, accMaps},
        /*bounds=*/ArrayRef<int64_t>{dim0Len, dim1Len},
        /*strides=*/ArrayRef<int64_t>{1, 1},
        /*forceUnroll=*/true, /*useIndexDiffs=*/true);
    {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(loop.getBody());

      Block::BlockArgListType loopCoords = loop.getLowerCoords(0);
      Block::BlockArgListType gemm1Coords = loop.getLowerCoords(1);
      Block::BlockArgListType accCoords = loop.getLowerCoords(2);

      // Rescale the accumulated output by the stored max-difference factor.
      Type rescaleElemType =
          getElementTypeOrSelf(expMaxDiffRowBuffer.getType());
      Value rescale = InBoundsLoadOp::create(rewriter, loc, rescaleElemType,
                                             expMaxDiffRowBuffer,
                                             ValueRange{loopCoords[0]});
      Value accVal = InBoundsLoadOp::create(rewriter, loc, accElemType,
                                            accStorage, accCoords);
      Value rescaledAcc = arith::MulFOp::create(rewriter, loc, accVal, rescale);

      // Fold in the contribution of the current gemm1 tile.
      Value gemm1Val = InBoundsLoadOp::create(rewriter, loc, accElemType,
                                              gemm1Storage, gemm1Coords);
      Value updatedAcc =
          arith::AddFOp::create(rewriter, loc, rescaledAcc, gemm1Val);
      InBoundsStoreOp::create(rewriter, loc, updatedAcc, accStorage,
                              accCoords);
    }
  }

  // Takes a view stack whose lower view is m x n (optionally g x m x n) and
  // combines it with one extra map: m x n --> m --> m x constDim(0, n).
  // The extra map passes "m" (and "g", when present) through and pins the n
  // position to the constant 0 with size `zeroNDimSize`.
  //
  // This is used to get the corresponding 0th column index between two
  // matrices that have the same number of rows. A leading "g" dimension is
  // assumed exactly when the lower shape of `subTileView` has three entries.
  ArrayAttr createNZeroBroadcastView(PatternRewriter &rewriter, Location loc,
                                     ArrayAttr subTileView,
                                     int64_t zeroNDimSize) const {
    ArrayRef<int64_t> lowerShape = getLowerShape(subTileView);
    const bool withGDim = lowerShape.size() == 3;

    SmallVector<StringRef> dimNames;
    if (withGDim)
      dimNames.push_back("g");
    dimNames.push_back("m");
    dimNames.push_back("n");
    // Position of the n dimension among the names above.
    const int nPos = withGDim ? 2 : 1;

    TopDownTMBuilder nZeroBuilder(rewriter, dimNames, lowerShape, loc);
    if (withGDim)
      nZeroBuilder.passThrough("g");
    nZeroBuilder.passThrough("m");
    nZeroBuilder.constDim("nzero", nPos, 0, zeroNDimSize);
    TransformMapAttr nZeroMap = nZeroBuilder.get();

    return prependUpperViews(rewriter, subTileView,
                             rewriter.getArrayAttr({nZeroMap}));
  }

  // Applies createNZeroBroadcastView() to the grid, block and thread level
  // subtile views of `subTileViews`, using `nLen`, `nPerBlock` and
  // `nPerThread` respectively as the size of the zeroed n dimension.
  // Only those three fields of the returned value are set here.
  RegsAsMatrixSubTiles makeNZeroSubTile(PatternRewriter &rewriter, Location loc,
                                        RegsAsMatrixSubTiles subTileViews,
                                        int64_t nLen, int64_t nPerBlock,
                                        int64_t nPerThread) const {
    auto withZeroN = [&](ArrayAttr views, int64_t nSize) {
      return createNZeroBroadcastView(rewriter, loc, views, nSize);
    };

    RegsAsMatrixSubTiles nZeroViews;
    nZeroViews.gridSubTile = withZeroN(subTileViews.gridSubTile, nLen);
    nZeroViews.blockSubTile = withZeroN(subTileViews.blockSubTile, nPerBlock);
    nZeroViews.threadSubTile =
        withZeroN(subTileViews.threadSubTile, nPerThread);
    return nZeroViews;
  }

  // Returns `subtileViews` with its grid-level view extended by a padding
  // transform, so that coordinates falling into the padded region of the gemm
  // operands can be told apart from in-bounds ones.
  //
  // `prePadDim1` / `prePadDim2` are the sizes of dimensions 1 and 2 before
  // padding; the amount of right-padding is the difference to the padded
  // lower shape of the grid view. Block and thread views are left untouched.
  RegsAsMatrixSubTiles unpadGridSubTileView(PatternRewriter &rewriter,
                                            Location loc,
                                            RegsAsMatrixSubTiles subtileViews,
                                            int64_t prePadDim1,
                                            int64_t prePadDim2) const {
    // The lower shape of the grid view is G x M x N (padded sizes).
    ArrayRef<int64_t> padded = getLowerShape(subtileViews.gridSubTile);
    const int64_t extraDim1 = padded[1] - prePadDim1;
    const int64_t extraDim2 = padded[2] - prePadDim2;

    TopDownTMBuilder padBuilder{
        rewriter, {"g", "paddedDim1", "paddedDim2"}, padded, loc};
    padBuilder.passThrough("g");
    // No left padding; all padding sits at the end of each dimension.
    padBuilder.pad({"paddedDim1", "paddedDim2"},
                   {0, extraDim1, 0, extraDim2});
    TransformMapAttr padAttr = padBuilder.get();

    ArrayAttr padViews = rewriter.getArrayAttr({padAttr});
    subtileViews.gridSubTile =
        prependUpperViews(rewriter, subtileViews.gridSubTile, padViews);
    return subtileViews;
  }

  // When the kernel pads its operands, the first gemm is computed on a larger
  // matrix. For plain gemm kernels the padded part of the output simply holds
  // zeros, but the attention kernel normalizes rows with softmax afterwards.
  // Zero is not the smallest representable value of the element type, so
  // those zeros would leak into every normalized value of the row.
  //
  // This function therefore emits a transforming for loop over the thread's
  // gemm0 output buffer that stores negative infinity into every element
  // whose coordinates are reported as invalid (out of bounds) by
  // `gemm0OutSubTileViews.gridSubTile`.
  void createFirstGemmNegInfPadding(
      PatternRewriter &rewriter, Location loc,
      layout::GridCoordinates gridCoords, Value gemm0OutBuffer,
      RegsAsMatrixSubTiles gemm0OutSubTileViews) const {
    auto bufferType = cast<MemRefType>(gemm0OutBuffer.getType());
    Type elemType = bufferType.getElementType();
    auto negInf = createConstantFloatOp(
        rewriter, loc, elemType, elemType,
        -std::numeric_limits<float>::infinity(), APFloat::opOK);
    // Id of the current workitem.
    auto workitemId =
        WorkitemIdOp::create(rewriter, loc, rewriter.getIndexType());
    const int64_t numElems = bufferType.getNumElements();
    Value c0 = rewriter.createOrFold<ConstantIndexOp>(loc, 0);

    // Domain 0 walks the grid view (to obtain validity), domain 1 has no
    // transforms and just tracks the raw iteration coordinates.
    auto maskLoop = TransformingForOp::create(
        rewriter, loc,
        ArrayRef<ValueRange>{{gridCoords.g_block, gridCoords.m_block,
                              gridCoords.n_block, workitemId, c0},
                             {c0, c0, c0, c0, c0}},
        ArrayRef<Attribute>{gemm0OutSubTileViews.gridSubTile,
                            rewriter.getArrayAttr({})},
        /*bounds=*/ArrayRef<int64_t>{1, 1, 1, 1, numElems},
        /*strides=*/ArrayRef<int64_t>{1, 1, 1, 1, 1},
        /*forceUnroll=*/true, /*useIndexDiffs=*/true);

    OpBuilder::InsertionGuard restoreInsertionPoint(rewriter);
    rewriter.setInsertionPointToStart(maskLoop.getBody());

    Block::BlockArgListType iterCoords = maskLoop.getLowerCoords(1);
    TypedValue<IntegerType> inBounds = maskLoop.getValidity(0);
    Value falseBit = createConstantIntOp(rewriter, loc, inBounds.getType(),
                                         inBounds.getType(), 0);
    auto outOfBounds = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::eq, inBounds, falseBit);
    scf::IfOp maskIf = scf::IfOp::create(rewriter, loc, outOfBounds,
                                         /*withElseRegion=*/false);

    // The last coordinate is the position inside the thread's buffer.
    OpBuilder storeBuilder = maskIf.getThenBodyBuilder();
    InBoundsStoreOp::create(storeBuilder, loc, negInf, gemm0OutBuffer,
                            ValueRange{iterCoords[4]});
  }

  // Selects the masking rule applied by setGemm0OutputOutOfScope(): KVCache
  // compares the m index against the current sequence length, Causal compares
  // it against the (optionally GQA-divided) n index.
  enum class OutOfScopeType { KVCache, Causal };

  /// Overwrites out-of-scope entries of the first gemm's output with negative
  /// infinity so that they do not contribute to the following softmax.
  ///
  /// Nothing is emitted when `enabled` is false. Otherwise the masking code is
  /// wrapped in an `scf.if` that only fires when `mLoopIV` equals
  /// `gemm0MBlocksLastIter`. Inside, a transforming for loop visits every
  /// element of the thread's `gemm0OutBuffer`, resolves its lower coordinates
  /// through `gemm0OutSubTileViews.gridSubTile` (m is read from lower
  /// coordinate 2, n from lower coordinate 1) and stores -inf when:
  ///  - KVCache: `m >= currentSeqLen`. `currentSeqLen` is treated as a
  ///    length, so the valid m indices are [0, currentSeqLen) and the index
  ///    equal to the length is already out of scope. `currentSeqLen` must be
  ///    non-null in this mode.
  ///  - Causal: `m > n`, where n is first divided by `numRepeatsGQA` when that
  ///    attribute is provided.
  void setGemm0OutputOutOfScope(
      PatternRewriter &rewriter, Location loc, OutOfScopeType outOfScopeType,
      layout::GridCoordinates gridCoords, Value gemm0OutBuffer,
      RegsAsMatrixSubTiles gemm0OutSubTileViews, bool enabled, Value mLoopIV,
      Value gemm0MBlocksLastIter, Value currentSeqLen = nullptr,
      IntegerAttr numRepeatsGQA = nullptr) const {
    if (!enabled)
      return;

    auto isLastIteration =
        arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, mLoopIV,
                              gemm0MBlocksLastIter);
    scf::IfOp lastIterIf = scf::IfOp::create(rewriter, loc, isLastIteration,
                                             /*withElseRegion=*/false);
    OpBuilder thenb = lastIterIf.getThenBodyBuilder();

    Value constNumRepeatsGQA = nullptr;
    if (numRepeatsGQA)
      constNumRepeatsGQA = thenb.createOrFold<arith::ConstantIndexOp>(
          loc, numRepeatsGQA.getInt());

    MemRefType gemm0OutBufferType = cast<MemRefType>(gemm0OutBuffer.getType());
    auto negInfTyped = createConstantFloatOp(
        thenb, loc, gemm0OutBufferType.getElementType(),
        gemm0OutBufferType.getElementType(),
        -std::numeric_limits<float>::infinity(), APFloat::opOK);
    // Get current workitem ID.
    auto tid = WorkitemIdOp::create(thenb, loc, thenb.getIndexType());
    int64_t elementsInThreadBuffer = gemm0OutBufferType.getNumElements();
    Value zero = thenb.createOrFold<ConstantIndexOp>(loc, 0);
    // Domain 0 resolves matrix coordinates through the grid view; domain 1
    // has no transforms and tracks the raw position in the thread buffer.
    auto loop = TransformingForOp::create(
        thenb, loc,
        ArrayRef<ValueRange>{{gridCoords.g_block, gridCoords.m_block,
                              gridCoords.n_block, tid, zero},
                             {zero, zero, zero, zero, zero}},
        ArrayRef<Attribute>{gemm0OutSubTileViews.gridSubTile,
                            thenb.getArrayAttr({})},
        /*bounds=*/ArrayRef<int64_t>{1, 1, 1, 1, elementsInThreadBuffer},
        /*strides=*/ArrayRef<int64_t>{1, 1, 1, 1, 1},
        /*forceUnroll=*/true, /*useIndexDiffs=*/true);

    OpBuilder::InsertionGuard guard(thenb);
    thenb.setInsertionPointToStart(loop.getBody());

    Block::BlockArgListType lowerCoords = loop.getLowerCoords(0);
    Block::BlockArgListType bufferCoords = loop.getLowerCoords(1);
    Value isInvalid;
    Value mIndex = lowerCoords[2];
    switch (outOfScopeType) {
    case OutOfScopeType::KVCache: {
      assert(currentSeqLen != nullptr);
      // currentSeqLen is a length: index == length is already out of scope,
      // hence `uge` rather than `ugt`.
      isInvalid = arith::CmpIOp::create(
          thenb, loc, arith::CmpIPredicate::uge, mIndex, currentSeqLen);
      break;
    }
    case OutOfScopeType::Causal: {
      Value nIndex = lowerCoords[1];
      if (constNumRepeatsGQA)
        nIndex = thenb.createOrFold<arith::DivUIOp>(loc, nIndex,
                                                    constNumRepeatsGQA);

      // An entry is masked when its m index lies strictly past its n index.
      isInvalid = arith::CmpIOp::create(
          thenb, loc, arith::CmpIPredicate::ugt, mIndex, nIndex);
      break;
    }
    }

    scf::IfOp maskIf = scf::IfOp::create(thenb, loc, isInvalid,
                                         /*withElseRegion=*/false);
    OpBuilder storeBuilder = maskIf.getThenBodyBuilder();
    InBoundsStoreOp::create(storeBuilder, loc, negInfTyped, gemm0OutBuffer,
                            ValueRange{bufferCoords[4]});
  }

  /// Applies an elementwise binary operation between every element of
  /// `gemm0OutBuffer` and the scalar constant `splatVal`, writing the result
  /// back into the same buffer.
  ///
  /// The operation is emitted as a `linalg.generic` with identity indexing
  /// maps and all-parallel iterators. `ElementwiseOpType::Int` is used when
  /// the buffer's element type is an integer or index type, otherwise
  /// `ElementwiseOpType::Float`. `gridCoords` and `gemm0OutViews` are not
  /// used by this function.
  template <typename ElementwiseOpType>
  void postProcessFirstGemmSplat(PatternRewriter &rewriter, Location loc,
                                 layout::GridCoordinates gridCoords,
                                 Value gemm0OutBuffer,
                                 RegsAsMatrixSubTiles gemm0OutViews,
                                 TypedAttr splatVal) const {
    auto bufferTy = cast<MemRefType>(gemm0OutBuffer.getType());
    const int64_t rank = bufferTy.getRank();

    // One identity map for the input and one for the output.
    AffineMap identity = rewriter.getMultiDimIdentityMap(rank);
    SmallVector<AffineMap, 2> maps(2, identity);
    SmallVector<utils::IteratorType> iterators(rank,
                                               utils::IteratorType::parallel);

    auto emitBody = [&](OpBuilder &b, Location bodyLoc, ValueRange args) {
      Value splat =
          arith::ConstantOp::create(b, loc, bufferTy.getElementType(), splatVal);
      Value combined;
      if (bufferTy.getElementType().isIntOrIndex())
        combined = ElementwiseOpType::Int::create(b, loc, args[0], splat);
      else
        combined = ElementwiseOpType::Float::create(b, loc, args[0], splat);
      linalg::YieldOp::create(b, bodyLoc, combined);
    };

    // The same buffer is both read and written (in-place update).
    linalg::GenericOp::create(rewriter, loc, ValueRange(gemm0OutBuffer),
                              ValueRange(gemm0OutBuffer), maps, iterators,
                              emitBody);
  }

  /// Builds the transform stack that undoes the GQA reshaping for tensors of
  /// the fusion between the first and the second gemm.
  ///
  /// Returns a null attribute when `op` carries no numRepeatsGQA attribute.
  /// Otherwise `unpaddedShape` must be (gemmG, seqLenQ, seqLenKV) with seqLenQ
  /// divisible by the repeat count, and the result holds two maps:
  ///  1. split seqLenQ into (seqLenQ / numRepeats, numRepeats), producing
  ///     (gemmG, numRepeats, seqLenQ, seqLenKV);
  ///  2. fold (gemmG, numRepeats) into a single leading dimension, producing
  ///     (gemmG * numRepeats, seqLenQ, seqLenKV).
  ArrayAttr undoGQATransforms(PatternRewriter &rewriter, Location loc,
                              GridwiseAttentionAccelOp op,
                              ArrayRef<int64_t> unpaddedShape) const {
    if (!op.getNumRepeatsGQAAttr())
      return nullptr;

    SmallVector<StringRef> upperNames = {"gemmG", "seqLenQ", "seqLenKV"};
    const int64_t repeats = op.getNumRepeatsGQAAttr().getInt();

    assert(unpaddedShape.size() == 3);
    const int64_t g = unpaddedShape[0];
    const int64_t q = unpaddedShape[1];
    const int64_t kv = unpaddedShape[2];
    assert(q % repeats == 0);

    // Step 1: pull the repeat factor out of the seqLenQ dimension.
    rock::TopDownTMBuilder splitQ(rewriter, upperNames, {g, q, kv});
    splitQ.merge({"seqLenQ", "numRepeats"}, {2, 1}, "seqLenQ",
                 {q / repeats, repeats});
    splitQ.passThrough({"gemmG", "seqLenKV"}, {0, 3}, {"gemmG", "seqLenKV"});
    auto splitQAttr = splitQ.get();

    // Step 2: fold the repeat factor into the group dimension.
    auto foldG = rock::TopDownTMBuilder::below(splitQ, splitQAttr);
    foldG.unmerge("gemmG", 0, {"gemmG", "numRepeats"}, {g, repeats});
    foldG.passThrough({"seqLenQ", "seqLenKV"}, {1, 2},
                      {"seqLenQ", "seqLenKV"});
    auto foldGAttr = foldG.get();

    SmallVector<Attribute> maps{splitQAttr, foldGAttr};
    return rewriter.getArrayAttr(maps);
  }

  /// Re-emits the `linalg.generic` ops found in the pre-softmax body of `op`
  /// so that they operate on per-thread register tiles of the first gemm's
  /// output.
  ///
  /// For every generic op (in walk order) this function:
  ///  - rewrites it to use identity indexing maps, pulling any other maps out
  ///    as rock transforms;
  ///  - inverts the transforms between the gemm0-based input and the generic
  ///    op and stacks them onto the running grid subtile view;
  ///  - for each remaining input, allocates a private-address-space tile with
  ///    the shape of `srcGemm0OutBuffer` and fills it with a threadwise read;
  ///  - clones the generic op onto those tiles with 1-D identity indexing maps
  ///    and a single parallel iterator.
  /// The output of one generic op becomes the gemm0-based input of the next.
  ///
  /// Returns `srcGemm0OutBuffer` unchanged if the pre-softmax body is empty or
  /// holds no generic op, the output buffer of the last cloned generic op
  /// otherwise, and failure (with an error emitted on `op`) if any generic op
  /// could not be processed.
  FailureOr<Value> postProcessFirstGemm(
      PatternRewriter &rewriter, Location loc, GridwiseAttentionAccelOp op,
      layout::GridCoordinates gridCoords, Value srcGemm0OutBuffer,
      Value destGemm0OutBuffer, RegsAsMatrixSubTiles gemm0OutViews) const {
    auto privateMemoryAddressSpace = rewriter.getAttr<gpu::AddressSpaceAttr>(
        gpu::GPUDialect::getPrivateAddressSpace());
    // Index of the generic op currently being processed; stays -1 if none is
    // found.
    int64_t linalgOpIndex = -1;
    MemRefType srcBufType = cast<MemRefType>(srcGemm0OutBuffer.getType());
    MemRefType destBufType = cast<MemRefType>(destGemm0OutBuffer.getType());
    Value prevGemm0OutBuffer = srcGemm0OutBuffer;
    ArrayAttr linalgGridSubTileMaps = gemm0OutViews.gridSubTile;
    if (op.getPreSoftmaxBody().getBlocks().empty()) {
      // nothing to process
      return prevGemm0OutBuffer;
    }

    // Block argument number of the gemm0 output in the pre-softmax body;
    // -1 until it has been discovered.
    int64_t firstGemmBlockArgNum = -1;
    Block &preSoftMaxBodyBlock = op.getPreSoftmaxBody().getBlocks().front();
    WalkResult res = op.getPreSoftmaxBody().walk([&](linalg::GenericOp genOp) {
      linalgOpIndex++;
      auto tid = WorkitemIdOp::create(rewriter, loc, rewriter.getIndexType());
      SmallVector<Value> inputTileBuffers;

      // Pull non-identity index maps to rock transforms
      LogicalResult linalgIdentityRes =
          makeLinalgGenericWithIdentityAffMaps(rewriter, genOp);
      if (failed(linalgIdentityRes)) {
        genOp.emitError(
            "Failed to make linalg generic with identity affine maps");
        return WalkResult::interrupt();
      }

      // Obtain transform stack from gemmOutput to linalg generic input.
      ArrayAttr linalgToGemmOutMaps;
      Value gemm0BasedArg =
          genOp.getInputs()[op.getFirstGemmIndices()[linalgOpIndex]];
      Value mayBeFirstGemmBlockArg;
      std::tie(mayBeFirstGemmBlockArg, linalgToGemmOutMaps, std::ignore) =
          untransform(rewriter, gemm0BasedArg);

      // If the gemm0BasedArg is a block argument, we need to get its
      // blockArgNum
      if (auto firstGemmBlockArg =
              dyn_cast<BlockArgument>(mayBeFirstGemmBlockArg)) {
        assert(firstGemmBlockArgNum == -1 &&
               "firstGemmBlockArgNum should be set only once");
        // trace it back to block input
        if (firstGemmBlockArg.getOwner() == &preSoftMaxBodyBlock) {
          firstGemmBlockArgNum = firstGemmBlockArg.getArgNumber();
        } else {
          llvm::report_fatal_error("first gemm block argument does not belong "
                                   "to block of preSoftBody\n");
        }
      }
      // The obtained transforms have the linalg generic as the upper view
      // and lead to gemmOutput as the lower view. However, we need to
      // construct the following sequence:
      //  (bid, tid, iter) > ... > [gemmOutput: k x d]
      //                         > invertTr(linalg input to gemmOutput maps)
      //                         > (linalgOtherInput to op arg maps)
      ArrayAttr gemmOutToLinalgMaps =
          invertTransforms(rewriter, loc, linalgToGemmOutMaps);

      if (!gemmOutToLinalgMaps) {
        genOp.emitError("We can't invert linalg input to gemmOutput maps");
        return WalkResult::interrupt();
      }

      // The grid view accumulates across generic ops: each op's inverted maps
      // are stacked onto the view left behind by the previous one.
      if (!gemmOutToLinalgMaps.empty()) {
        linalgGridSubTileMaps = prependUpperViews(
            rewriter, linalgGridSubTileMaps, gemmOutToLinalgMaps);
      }

      // Materialize every input other than the gemm0-based one as a
      // per-thread tile.
      for (auto [idx, genOpInput] : llvm::enumerate(genOp.getInputs())) {
        if (idx ==
            static_cast<unsigned long>(op.getFirstGemmIndices()[linalgOpIndex]))
          continue;

        Value otherInput;
        ArrayAttr linalgToOtherInputMaps;
        std::tie(otherInput, linalgToOtherInputMaps, std::ignore) =
            untransform(rewriter, genOpInput);

        // The tile has the shape of the gemm0 source buffer but the element
        // type of the other input.
        MemRefType otherInputBufType = cast<MemRefType>(otherInput.getType());
        MemRefType tileBufType = MemRefType::get(
            srcBufType.getShape(), otherInputBufType.getElementType(),
            AffineMap{}, privateMemoryAddressSpace);
        auto tileBuffer = rock::GpuAllocOp::create(rewriter, loc, tileBufType);

        ArrayAttr gemmOutToOtherInputMaps = linalgGridSubTileMaps;
        if (!linalgToOtherInputMaps.empty()) {
          gemmOutToOtherInputMaps = prependUpperViews(
              rewriter, linalgGridSubTileMaps, linalgToOtherInputMaps);
        }
        // If other input is a block argument of the attention op fusion
        if (auto blockArg = dyn_cast<BlockArgument>(otherInput)) {
          // trace it back to block input
          if (blockArg.getOwner() == &preSoftMaxBodyBlock) {
            int64_t blockArgNum = blockArg.getArgNumber();
            // we are processing other inputs. Block Argument number shouldn't
            // be the same as gemm input to first linalg generic op
            assert(firstGemmBlockArgNum != -1 &&
                   "firstGemmBlockArgNum should be set before processing other "
                   "inputs");
            assert(blockArgNum != firstGemmBlockArgNum);

            // if the gemm index is smaller, we need to subtract one from the
            // index as `getPreSoftmaxElemWiseInputs()` doesn't contain
            // gemm0 output explicitly
            if (blockArgNum > firstGemmBlockArgNum) {
              --blockArgNum;
            }
            otherInput = op.getPreSoftmaxElemWiseInputs()[blockArgNum];
          } else {
            llvm::report_fatal_error("Found blockArgument that does not belong "
                                     "to block of preSoftBody\n");
          }
        }
        ThreadwiseReadIntoOp::create(
            rewriter, loc, otherInput, tileBuffer, gemmOutToOtherInputMaps,
            ValueRange{gridCoords.g_block, gridCoords.m_block,
                       gridCoords.n_block, tid},
            true, true);
        inputTileBuffers.push_back(tileBuffer);
      }
      // Insert the first gemm output buffer according to which input
      // it was to the linalg generic
      inputTileBuffers.insert(inputTileBuffers.begin() +
                                  op.getFirstGemmIndices()[linalgOpIndex],
                              prevGemm0OutBuffer);
      // Pick the output tile: a fresh allocation with the destination shape
      // when the generic op's output type differs from the destination
      // buffer's type, the destination buffer itself otherwise.
      Type outputType = genOp.getOutputs().back().getType();
      if (outputType != destGemm0OutBuffer.getType()) {
        MemRefType genOpOutMemrefType = cast<MemRefType>(outputType);
        MemRefType outTileBufType = MemRefType::get(
            destBufType.getShape(), genOpOutMemrefType.getElementType(),
            AffineMap{}, privateMemoryAddressSpace);
        auto outTileBuffer =
            rock::GpuAllocOp::create(rewriter, loc, outTileBufType);
        inputTileBuffers.push_back(outTileBuffer);
      } else {
        // reuse the same dest buffer if types match
        inputTileBuffers.push_back(destGemm0OutBuffer);
      }
      linalg::GenericOp newLinalgOp;

      // Clone the generic op with each operand replaced by its tile buffer
      // (inputs first, output last, matching the order built above).
      mlir::IRMapping mapper;
      for (auto [operand, tilebuffer] :
           llvm::zip(genOp->getOperands(), inputTileBuffers)) {
        mapper.map(operand, tilebuffer);
      }
      newLinalgOp = cast<linalg::GenericOp>(rewriter.clone(*genOp, mapper));
      // Every operand of the clone gets a 1-D identity indexing map.
      SmallVector<AffineMap> indexingMaps;
      for (size_t i = 0; i < inputTileBuffers.size(); i++) {
        indexingMaps.push_back(rewriter.getMultiDimIdentityMap(1));
      }
      newLinalgOp.setIndexingMapsAttr(
          rewriter.getAffineMapArrayAttr(indexingMaps));
      // A single parallel iterator to go with the 1-D maps.
      SmallVector<Attribute, 5> iteratorTypes;
      iteratorTypes.resize(
          1, linalg::IteratorTypeAttr::get(rewriter.getContext(),
                                           utils::IteratorType::parallel));
      newLinalgOp.setIteratorTypesAttr(rewriter.getArrayAttr(iteratorTypes));
      // set previous source buffer for the next linalg generic
      prevGemm0OutBuffer = inputTileBuffers.back();
      return WalkResult::advance();
    });
    if (res.wasInterrupted()) {
      return op.emitError("pre softmax linalg regularization failed.\n");
    }
    // if no linalg generic was found, we just return the srcBuffer
    if (linalgOpIndex == -1) {
      return srcGemm0OutBuffer;
    }
    assert(prevGemm0OutBuffer.getType() == destGemm0OutBuffer.getType() &&
           "after the regularization final output buffer type should match "
           "previously allocated fusion buffer type");
    assert(static_cast<size_t>(linalgOpIndex + 1) ==
               op.getFirstGemmIndices().size() &&
           "number of linalg generic ops and number of firstGemmIndices must "
           "match");
    return prevGemm0OutBuffer;
  }

  /// Emits the loads that bring one gemm operand from its LDS tile buffer
  /// into the pre-accel register buffer.
  ///
  /// The LDS buffer is first wrapped by the accel emitter for loading. Then
  /// an affine loop over the emitter's repeat count (`mRepeats` when `dName`
  /// is "m", `nRepeats` otherwise) issues one threadwise read per repeat.
  /// When there is more than one repeat, each read targets the slice of
  /// `preAccelRegBuffer` selected by the loop index along its first
  /// dimension; otherwise the whole buffer is the destination.
  void loadGemmOperandsFromLDSToRegs(PatternRewriter &rewriter, Location loc,
                                     Value ldsTileBuffer,
                                     Value preAccelRegBuffer, StringRef dName,
                                     int64_t blockSize, int64_t inDPerThread,
                                     const accel::AccelEmitter &accelEmitterPtr,
                                     bool rotateDWithK) const {
    // Id of the current workitem.
    auto workitemId =
        WorkitemIdOp::create(rewriter, loc, rewriter.getIndexType());
    rock::accel::AccelEmitterParams emitterParams = accelEmitterPtr.getParams();
    Value ldsView = accelEmitterPtr.wrapLDSBufferForLoad(
        rewriter, loc, ldsTileBuffer, blockSize, inDPerThread, dName,
        rotateDWithK);

    int64_t numRepeats;
    if (dName == "m")
      numRepeats = emitterParams.mRepeats;
    else
      numRepeats = emitterParams.nRepeats;

    affine::AffineForOp repeatLoop =
        affine::AffineForOp::create(rewriter, loc, 0, numRepeats, 1);

    PatternRewriter::InsertionGuard restoreInsertionPoint(rewriter);
    rewriter.setInsertionPointToStart(repeatLoop.getBody());
    Value repeatIdx = repeatLoop.getInductionVar();
    Value destination =
        numRepeats > 1
            ? createSliceOfFirstDim(rewriter, loc, preAccelRegBuffer, repeatIdx)
            : preAccelRegBuffer;
    ThreadwiseReadIntoOp::create(rewriter, loc, ldsView, destination,
                                 rewriter.getArrayAttr({}),
                                 ValueRange{workitemId, repeatIdx}, true, true);
  }

  /// Returns a transformed view of the 3-D `operand` in which dimensions 1
  /// and 2 are swapped while dimension 0 stays in place.
  Value transposeAttnOperand(PatternRewriter &rewriter, Location loc,
                             TypedValue<MemRefType> operand) const {
    ArrayRef<int64_t> operandShape = operand.getType().getShape();
    BottomUpTMBuilder swapBuilder(rewriter, operandShape, loc);
    swapBuilder.passThrough({0, 1, 2}, {0, 2, 1});
    TransformMapAttr swapMap = swapBuilder.get();
    Value transposed = TransformOp::create(rewriter, loc, operand, swapMap);
    return transposed;
  }

  /// Tells whether the B operand of the second gemm can skip the LDS-based
  /// swizzle.
  ///
  /// This is only the case when the accel emitter selected for the first gemm
  /// is an MFMA emitter and all of the following hold:
  ///  - the emitter reports a K reduction;
  ///  - there is exactly one wave along M (mPerBlock / mPerWave == 1);
  ///  - kpack equals the emitter's row group size.
  bool canBypassLDSForSecondGemm(GridwiseAttentionAccelOp op) const {
    Type qElemType =
        cast<MemRefType>(op.getQueries().getType()).getElementType();
    Type kElemType = cast<MemRefType>(op.getKeys().getType()).getElementType();
    StringRef arch = rock::getArchValue(op);
    RockAccelTuningParamAttrInterface params0 = op.getParams0();
    auto gemm0Emitter = accel::AccelEmitter::select(
        rock::getFeatures(op), qElemType, kElemType, arch, params0);

    auto mfma = dyn_cast<accel::MfmaEmitter>(gemm0Emitter.get());
    if (!mfma)
      return false;
    if (!mfma->isKReduction())
      return false;

    const int64_t wavesAlongM =
        params0.getMPerBlock() / params0.getMPerWave();
    if (wavesAlongM != 1)
      return false;

    // TODO: explore if this could be relaxed.
    // The way the other operand is currently loaded from LDS hands out kPack
    // values of the K dimension at a time. To line up with the MFMA output
    // without going through LDS for this operand, kPack has to equal the
    // row group size.
    return params0.getKpack() == mfma->getRowGroupSize();
  }

  /// Tells whether Q tiles can be loaded into the accel_gemm layout without
  /// going through LDS: the bypass must not be disabled on `op`, and the K
  /// extent of the first gemm (dimension 1 of the queries) must equal
  /// kpack * kpackPerBlock of the first gemm's tuning parameters.
  bool canBypassLDSForQ(GridwiseAttentionAccelOp op) const {
    ArrayRef<int64_t> queriesShape =
        cast<MemRefType>(op.getQueries().getType()).getShape();
    const int64_t kExtent = queriesShape[1];

    RockAccelTuningParamAttrInterface params0 = op.getParams0();
    const int64_t kPerBlock = params0.getKpack() * params0.getKpackPerBlock();

    const bool bypassDisabled = op.getDisableQBypassLDS();
    const bool wholeKInOneBlock = kExtent == kPerBlock;
    return !bypassDisabled && wholeKInOneBlock;
  }

  /// Builds a map from (g_block, n_block, tid, flatiter) to
  /// (g_block, mIter, n_block, tid, iter).
  ///
  /// "flatiter" (extent `numElements`) is split into "mIter" of extent
  /// `mIterLen` and "iter" of extent `numElements / mIterLen`; the other
  /// dimensions are passed through to lower positions 0, 2 and 3.
  TransformMapAttr getFlatToMiterMap(PatternRewriter &rewriter, int64_t gBlocks,
                                     int64_t mIterLen, int64_t nBlocks,
                                     int64_t blockSize,
                                     int64_t numElements) const {
    SmallVector<StringRef> upperNames{"g_block", "n_block", "tid", "flatiter"};
    SmallVector<int64_t> upperSizes{gBlocks, nBlocks, blockSize, numElements};
    const int64_t itersPerM = numElements / mIterLen;

    TopDownTMBuilder splitBuilder(rewriter, upperNames, upperSizes);
    splitBuilder.passThrough({"g_block", "n_block", "tid"}, {0, 2, 3},
                             {"g_block", "n_block", "tid"});
    splitBuilder.merge({"mIter", "iter"}, {1, 4}, "flatiter",
                       {mIterLen, itersPerM});
    return splitBuilder.get();
  }

  /// Emits the index arithmetic that determines which gemm0 M blocks a
  /// workgroup has to iterate over.
  ///
  /// Returns the tuple (start, end, gemm0MBlocksLastIter, currentSeqLen):
  ///  - `start` / `end`: half-open range of M block iterations. Both are null
  ///    when neither causal masking, KV cache nor split-kv is active.
  ///  - `gemm0MBlocksLastIter`: `end - 1`; only set when causal masking or KV
  ///    cache is active, null otherwise.
  ///  - `currentSeqLen`: the sequence length loaded from
  ///    `currentSeqLenTensor` and cast to index; only set when `isKVCache`.
  ///
  /// `splitKV == 1` means split-kv is disabled. `numRepeatsGQA`, when given,
  /// divides the causal row bound.
  std::tuple<Value, Value, Value, Value>
  getMLoopInfo(PatternRewriter &rewriter, Location loc,
               layout::AttnGridCoordinates gridCoordsGemm0,
               Value currentSeqLenTensor, int64_t gemm0M,
               int64_t gemm0MPerBlock, int64_t gemm0NPerBlock, int64_t splitKV,
               bool isCausal, bool isKVCache,
               IntegerAttr numRepeatsGQA = nullptr) const {
    Value gemm0MBlocksLastIter;
    Value currentSeqLen;
    // Bound used to derive `end`: the current sequence length, the causal row
    // bound, or the minimum of both.
    Value effectiveSeqLen;
    Value start, end;
    // This is needed for KV Cache/Causal masking support
    if (isCausal || isKVCache) {
      Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
      if (isKVCache) {
        // add dim 1 for thread_read_into (registers)
        ArrayRef<int64_t> inpShape =
            cast<ShapedType>(currentSeqLenTensor.getType()).getShape();
        SmallVector<StringRef> startNames = {"gemmG"};
        rock::BottomUpTMBuilder addDim(rewriter, startNames, inpShape);
        addDim.addDim("dummy", 1, 1);
        addDim.passThrough(ArrayRef<uint32_t>{0}, ArrayRef<uint32_t>{0});
        auto addDimAttr = addDim.get();
        Value currentSeqLenTensorAddDim = rock::TransformOp::create(
            rewriter, loc, currentSeqLenTensor, addDimAttr);
        Type currentSeqLenElemType =
            getElementTypeOrSelf(currentSeqLenTensorAddDim.getType());

        // create registers
        auto privateMemoryAddressSpace =
            rewriter.getAttr<gpu::AddressSpaceAttr>(
                gpu::GPUDialect::getPrivateAddressSpace());
        auto memrefType = MemRefType::get(
            {1}, currentSeqLenElemType, AffineMap{}, privateMemoryAddressSpace);
        auto currentSeqLenLoad = GpuAllocOp::create(rewriter, loc, memrefType);

        // load from memory to registers; the entry is selected by g_block
        ThreadwiseReadIntoOp::create(
            rewriter, loc, vectorOfBoolShapedLike(currentSeqLenLoad),
            currentSeqLenTensorAddDim, currentSeqLenLoad,
            /*dynamicValidities=*/ValueRange{},
            /*extraViews=*/rewriter.getArrayAttr({}),
            /*extraIndices=*/
            ValueRange{gridCoordsGemm0.g_block}, true, true);

        // load from registers
        Value currentSeqLenValue =
            InBoundsLoadOp::create(rewriter, loc, currentSeqLenElemType,
                                   currentSeqLenLoad, ValueRange{zero});
        currentSeqLen = rewriter.createOrFold<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), currentSeqLenValue);
        effectiveSeqLen = currentSeqLen;
      }
      if (isCausal) {
        // this computes the maximum n of the block
        // NOTE(review): n_block * gemm0NPerBlock is the first row of the
        // block, not the last one -- confirm this is the intended bound.
        Value nIndex = gridCoordsGemm0.n_block;
        Value constGemm0NPerBlock =
            rewriter.createOrFold<arith::ConstantIndexOp>(loc, gemm0NPerBlock);
        Value maxRowOfBlock =
            arith::MulIOp::create(rewriter, loc, nIndex, constGemm0NPerBlock);
        if (numRepeatsGQA) {
          Value constNumRepeatsGQA =
              rewriter.createOrFold<arith::ConstantIndexOp>(
                  loc, numRepeatsGQA.getInt());
          maxRowOfBlock = rewriter.createOrFold<arith::DivUIOp>(
              loc, maxRowOfBlock, constNumRepeatsGQA);
        }

        // if effectiveSeqLen is set, it means KV Cache is enabled,
        // so we need to take the minimum of currentSeqLen and maxRowOfBlock
        if (effectiveSeqLen)
          maxRowOfBlock = arith::MinUIOp::create(rewriter, loc, currentSeqLen,
                                                          maxRowOfBlock);
        effectiveSeqLen = maxRowOfBlock;
      }

      // compute end index: effectiveSeqLen / gemm0MPerBlock + 1
      Value constGemm0MPerBlock =
          rewriter.createOrFold<arith::ConstantIndexOp>(loc, gemm0MPerBlock);
      Value numerator = arith::AddIOp::create(rewriter, loc, effectiveSeqLen,
                                                       constGemm0MPerBlock);
      end = rewriter.createOrFold<arith::DivUIOp>(loc, numerator,
                                                  constGemm0MPerBlock);
      Value one = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 1);

      // start index is zero unless split-kv is enabled
      start = zero;
      if (splitKV != 1) {
        // here, "end" now means number of iterations in total, we need to split
        // those iterations into split-kv blocks.
        // see runEarlyExit() for details about early exit.
        Value constSplitKV =
            rewriter.createOrFold<arith::ConstantIndexOp>(loc, splitKV);
        Value constSplitKVM1 =
            rewriter.createOrFold<arith::ConstantIndexOp>(loc, splitKV - 1);
        // iterations per split = ceil(end / splitKV)
        Value numerator =
            rewriter.create<arith::AddIOp>(loc, end, constSplitKVM1);
        Value gemm0MIterations =
            rewriter.createOrFold<arith::DivUIOp>(loc, numerator, constSplitKV);

        // if split-kv is enabled, we need to compute the start and end indices.
        start = arith::MulIOp::create(rewriter, loc, gridCoordsGemm0.split_block,
                                       gemm0MIterations);
        Value splitPlusOne = arith::AddIOp::create(rewriter, loc,
                                                    gridCoordsGemm0.split_block, one);
        Value endSplitKV = arith::MulIOp::create(rewriter, loc, splitPlusOne,
                                                  gemm0MIterations);
        // never run past the total number of iterations
        end = rewriter.create<arith::MinUIOp>(loc, end, endSplitKV);
      }
      // compute last iteration of the block, this will be used later in
      // setGemm0OutputOutOfScope()
      gemm0MBlocksLastIter =
          rewriter.createOrFold<arith::SubIOp>(loc, end, one);
    } else if (splitKV != 1) {
      // if split-kv is enabled, we need to compute the start and end indices.
      // this is the code for the case where kv-cache and causal are not
      // enabled. the logic is easier, but note that some blocks will early
      // exit, see runEarlyExit() for details.
      Value gemm0MIterations = rewriter.createOrFold<arith::ConstantIndexOp>(
          loc, gemm0M / (gemm0MPerBlock * splitKV));
      Value one = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 1);
      start = rewriter.create<arith::MulIOp>(loc, gridCoordsGemm0.split_block,
                                             gemm0MIterations);
      Value splitPlusOne =
          arith::AddIOp::create(rewriter, loc, gridCoordsGemm0.split_block, one);
      end = arith::MulIOp::create(rewriter, loc, splitPlusOne, gemm0MIterations);
    }
    return std::make_tuple(start, end, gemm0MBlocksLastIter, currentSeqLen);
  }

  /// Creates the loop over gemm0 M blocks at the current insertion point.
  ///
  /// With `dynamicMLoop` an `scf.for` from `start` to `end` with step 1 is
  /// created; otherwise an `affine.for` over the static range
  /// [0, gemm0M / gemm0MPerBlock) with step 1, in which case `start` and
  /// `end` are ignored.
  LoopLikeOpInterface createMLoop(PatternRewriter &rewriter, Location loc,
                                  Value start, Value end, int64_t gemm0M,
                                  int64_t gemm0MPerBlock,
                                  bool dynamicMLoop) const {
    LoopLikeOpInterface mLoopOp;
    if (dynamicMLoop) {
      Value one = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 1);
      mLoopOp = scf::ForOp::create(rewriter, loc, start, end, one);
    } else {
      int64_t gemm0MBlocks = gemm0M / gemm0MPerBlock;
      mLoopOp = affine::AffineForOp::create(rewriter, loc, 0, gemm0MBlocks, 1);
    }
    return mLoopOp;
  }

  /// Guards the code emitted after this call with an `scf.if` so that
  /// split-kv workgroups without any work skip it.
  ///
  /// A guard is emitted if (1) and (2 || 3) hold:
  ///  1. split-kv is enabled (`splitKV != 1`);
  ///  2. `prePadG0M` is set and its value is >= `gemm0MPerBlock`;
  ///  3. causal masking or KV cache is enabled.
  /// For case 3 the condition is `end > start`; it subsumes the padding
  /// check, so it takes precedence when both apply. For case 2 alone the
  /// condition is `start * gemm0MPerBlock < prePadG0M`.
  ///
  /// Side effect: when a guard is emitted, the rewriter's insertion point is
  /// left at the start of the `scf.if` then-block, so everything the caller
  /// creates afterwards lands inside the guard. Otherwise the insertion point
  /// is unchanged.
  void runEarlyExit(PatternRewriter &rewriter, Location loc, Value start,
                    Value end, int64_t splitKV, int64_t gemm0MPerBlock,
                    std::optional<APInt> prePadG0M, bool isCausal,
                    bool isKVCache) const {
    if (splitKV == 1)
      return;

    // condition 2: some padding in gemm0M
    const bool earlyExitDueToPadding =
        prePadG0M.has_value() &&
        (prePadG0M.value().getSExtValue() >= gemm0MPerBlock);
    // condition 3: causal or kvcache
    const bool earlyExitDueToCausalOrKVCache = isCausal || isKVCache;

    if (!earlyExitDueToPadding && !earlyExitDueToCausalOrKVCache)
      return;

    Value someWorkToDo;
    if (earlyExitDueToCausalOrKVCache) {
      // for dynamic kernels, no need to check the padding condition:
      // start/end checks handle padding as well. There is work to do only if
      // end is strictly greater than start.
      someWorkToDo = arith::CmpIOp::create(
          rewriter, loc, arith::CmpIPredicate::ugt, end, start);
    } else {
      // Only the padding condition can hold here.
      Value constGemm0MPerBlock =
          rewriter.createOrFold<arith::ConstantIndexOp>(loc, gemm0MPerBlock);
      Value prePadMValue = rewriter.createOrFold<arith::ConstantIndexOp>(
          loc, prePadG0M.value().getSExtValue());
      Value startIteration =
          arith::MulIOp::create(rewriter, loc, start, constGemm0MPerBlock);

      // There is work to do only if the first row of this split lies before
      // prePadMValue; otherwise the guarded code is skipped.
      someWorkToDo = arith::CmpIOp::create(
          rewriter, loc, arith::CmpIPredicate::ult, startIteration,
          prePadMValue);
    }
    scf::IfOp ifb = scf::IfOp::create(rewriter, loc, someWorkToDo,
                                      /*withElseRegion=*/false);
    // Intentionally not restored: the caller keeps emitting inside the guard.
    rewriter.setInsertionPointToStart(&ifb.getThenRegion().front());
  }

  LogicalResult matchAndRewrite(GridwiseAttentionAccelOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    StringRef arch = rock::getArchValue(op);
    uint32_t blockSize = op.getBlockSize();
    uint32_t gridSize = op.getGridSize();

    // Get 'features' from the op
    auto features = rock::getFeatures(op);
    auto featuresAttr = op.getFeaturesAttr();

    TypedValue<MemRefType> inQ = op.getQueries();
    ArrayRef<int64_t> qShape = cast<MemRefType>(inQ.getType()).getShape();
    Type elemTypeQ = cast<MemRefType>(inQ.getType()).getElementType();
    FailureOr<Type> maybeElemTypeQLoad = getInputFusionElementType(inQ);
    Type elemTypeQLoad =
        failed(maybeElemTypeQLoad) ? elemTypeQ : maybeElemTypeQLoad.value();

    TypedValue<MemRefType> inK = op.getKeys();
    ArrayRef<int64_t> kShape = cast<MemRefType>(inK.getType()).getShape();
    Type elemTypeK = cast<MemRefType>(inK.getType()).getElementType();
    FailureOr<Type> maybeElemTypeKLoad = getInputFusionElementType(inK);
    Type elemTypeKLoad =
        failed(maybeElemTypeKLoad) ? elemTypeK : maybeElemTypeKLoad.value();

    TypedValue<MemRefType> inV = op.getValues();
    Type elemTypeV = inV.getType().getElementType();
    FailureOr<Type> maybeElemTypeVLoad = getInputFusionElementType(inV);
    Type elemTypeVLoad =
        failed(maybeElemTypeVLoad) ? elemTypeV : maybeElemTypeVLoad.value();

    TypedValue<MemRefType> out = op.getOut();
    Value trOut = transposeAttnOperand(rewriter, loc, out);
    ArrayRef<int64_t> outShape = cast<MemRefType>(trOut.getType()).getShape();
    Type elemTypeOut = cast<MemRefType>(trOut.getType()).getElementType();

    Value lse = op.getLse();

    TypedValue<MemRefType> currentSeqLenTensor = op.getCurrentSeqLen();
    bool isKVCache = currentSeqLenTensor != nullptr;
    bool isCausal = op.getCausal();
    int64_t splitKV = op.getSplitKV();

    // Gemm0 out is casted to be softmaxType (if null, it's casted to elemTypeV)
    Type elemTypeSoftmax = op.getSoftmaxType().value_or(elemTypeV);

    auto privateMemoryAddressSpace = rewriter.getAttr<gpu::AddressSpaceAttr>(
        gpu::GPUDialect::getPrivateAddressSpace());

    int64_t gemm0G = qShape[0];
    int64_t gemm0K = qShape[1];
    int64_t gemm0N = qShape[2];
    int64_t gemm0M = kShape[2];

    int64_t gemm1M = outShape[1];
    int64_t gemm1N = outShape[2];

    RockAccelTuningParamAttrInterface gemm0TuningParams = op.getParams0();
    RockAccelTuningParamAttrInterface gemm1TuningParams = op.getParams1();
    int64_t gemm0kpack = gemm0TuningParams.getKpack();
    int64_t gemm0KpacksPerBlock = gemm0TuningParams.getKpackPerBlock();
    int64_t gemm0MPerBlock = gemm0TuningParams.getMPerBlock();
    int64_t gemm0NPerBlock = gemm0TuningParams.getNPerBlock();
    bool forceUnroll = gemm0TuningParams.getForceUnroll();
    int64_t gemm0MBlocks = gemm0M / gemm0MPerBlock;
    int64_t gemm0NBlocks = gemm0N / gemm0NPerBlock;
    int64_t gemm1kpack = gemm1TuningParams.getKpack();

    auto accelEmitterPtrGemm0 = accel::AccelEmitter::select(
        features, elemTypeQ, elemTypeK, arch, gemm0TuningParams);
    if (!accelEmitterPtrGemm0)
      return op.emitOpError("Unable to emit accelerator code.");
    bool doBypassLDSSecondGemm = canBypassLDSForSecondGemm(op);
    bool doBypassLDSForQ = canBypassLDSForQ(op);
    rock::accel::AccelEmitterParams accelParamsGemm0 =
        accelEmitterPtrGemm0->getParams();
    auto accelEmitterPtrGemm1 = accel::AccelEmitter::select(
        features, elemTypeV, elemTypeV, arch, gemm1TuningParams);
    if (!accelEmitterPtrGemm1)
      return op.emitOpError("Unable to emit accelerator code.");
    rock::accel::AccelEmitterParams accelParamsGemm1 =
        accelEmitterPtrGemm1->getParams();

    // Get current workgroup ID.
    auto bid = WorkgroupIdOp::create(rewriter, loc, rewriter.getIndexType());
    // Get current workitem ID.
    auto tid = WorkitemIdOp::create(rewriter, loc, rewriter.getIndexType());

    // Calculate different size derivations
    int64_t gemm0KPerBlock = gemm0kpack * gemm0KpacksPerBlock;
    int64_t gemm1KPerBlock = gemm0MPerBlock;
    int64_t gemm1MPerBlock = gemm1TuningParams.getMPerBlock();
    int64_t gemm1NPerBlock = gemm1TuningParams.getNPerBlock();
    // Note that kPerBlock for Gemm1B is mPerBlock of Gemm0 out
    // Note that mPerBlock for Gemm1A is mPerBlock of Gemm0 out
    // Note that nPerBlock for Gemm1B is nPerBlock of Gemm0 out
    int64_t gemm1MBlocks = gemm1M / gemm1MPerBlock;
    int64_t gemm1NBlocks = gemm1N / gemm1NPerBlock;
    assert(gemm0NPerBlock % gemm0kpack == 0 &&
           "nPerBlock should be divisible by kpack");
    int64_t gemm1KpacksPerBlock = gemm1KPerBlock / gemm1kpack;
    SmallVector<int64_t, 3> gemm0BidGridLengths = {gemm0G, gemm0MBlocks,
                                                   gemm0NBlocks};
    FailureOr<VectorDimInfo> maybeVectorDimInfoQ =
        getVectorDim(rewriter, loc, inQ, elemTypeQLoad, blockSize,
                     gemm0KPerBlock, gemm0NPerBlock, gemm0kpack);
    if (failed(maybeVectorDimInfoQ)) {
      return failure();
    }
    LDSLayoutConfigDim ldsLayoutCfgNG0 = getLDSLayoutConfigDim(
        elemTypeQ, gemm0kpack, maybeVectorDimInfoQ.value());
    if (doBypassLDSForQ) {
      ldsLayoutCfgNG0.doSwapThreadIterSubDims = false;
    }
    if (op.getEnableSoftmax()) {
      // TODO: Workaround for issue
      // https://github.com/ROCm/rocMLIR-internal/issues/1802 If sumRowBuffer
      // and expMaxDiffRowBuffer are filled with doSwapThreadIterSubDims=true,
      // it does not match with the second GEMM N dimension. Find a good
      // solution to this.
      ldsLayoutCfgNG0.doSwapThreadIterSubDims = false;
    }
    FailureOr<VectorDimInfo> maybeVectorDimInfoK =
        getVectorDim(rewriter, loc, inK, elemTypeKLoad, blockSize,
                     gemm0KPerBlock, gemm0MPerBlock, gemm0kpack);
    if (failed(maybeVectorDimInfoK)) {
      return failure();
    }
    LLVM_DEBUG(llvm::dbgs()
               << "elemTypeQLoad: " << elemTypeQLoad << "\n"
               << "elemTypeKLoad: " << elemTypeKLoad << "\n"
               << "elemTypeVLoad: " << elemTypeVLoad << "\n"
               << "qVectorDim: " << maybeVectorDimInfoQ->vectorDim << "\n"
               << "qVectorLen: " << maybeVectorDimInfoQ->vectorLen << "\n"
               << "kVectorDim: " << maybeVectorDimInfoK->vectorDim << "\n"
               << "kVectorLen: " << maybeVectorDimInfoK->vectorLen << "\n");
    LDSLayoutConfigDim ldsLayoutCfgMG0 = getLDSLayoutConfigDim(
        elemTypeK, gemm0kpack, maybeVectorDimInfoK.value());
    ldsLayoutCfgMG0.doRotateWithK = false;
    if (doBypassLDSSecondGemm) {
      ldsLayoutCfgMG0.doSwapThreadIterSubDims = false;
    }
    int64_t gemm0InMPerThread = maybeVectorDimInfoK->inDPerThread;
    int64_t gemm0InNPerThread = maybeVectorDimInfoQ->inDPerThread;
    FailureOr<RegsAsMatrixSubTiles> maybeGemm0OutSubTileViews =
        accelEmitterPtrGemm0->computeOutputTransforms(
            rewriter, loc, gemm0M, gemm0N, blockSize, gemm0BidGridLengths,
            gemm0InMPerThread, gemm0InNPerThread,
            ldsLayoutCfgMG0.doSwapThreadIterSubDims,
            ldsLayoutCfgNG0.doSwapThreadIterSubDims);
    if (failed(maybeGemm0OutSubTileViews)) {
      return failure();
    }
    auto gemm0OutSubTileViews = maybeGemm0OutSubTileViews.value();
    RegsAsMatrixSubTiles gemm0OutSubTileViewsTr =
        transposeSubTileViews(rewriter, loc, gemm0OutSubTileViews);
    int64_t gemm0MPerThread =
        getLowerShape(gemm0OutSubTileViews.threadSubTile)[0];
    int64_t gemm0NPerThread =
        getLowerShape(gemm0OutSubTileViews.threadSubTile)[1];
    int64_t gemm1InNPerThread = gemm0NPerThread;

    // Create shared buffers across gemms and reductions
    int64_t ldsByteBufferQSize = gemm0KPerBlock * gemm0NPerBlock;
    if (doBypassLDSForQ) {
      ldsByteBufferQSize = 0;
    }
    int64_t reductionWorkspaceSize =
        (gemm0MPerBlock / gemm0MPerThread) * gemm0NPerBlock;
    int64_t gemm1LDSByteBufferBSize = gemm1KPerBlock * gemm1NPerBlock;
    if (doBypassLDSSecondGemm) {
      gemm1LDSByteBufferBSize = 0;
    }

    // Buffers for Gemm0
    Value fromGlobalRegBufferQ;
    Value toLDSRegBufferQ;
    if (doBypassLDSForQ) {
      auto loadBufferType =
          MemRefType::get({accelParamsGemm0.nRepeats *
                           accelParamsGemm0.kpackPerThread * gemm0kpack},
                          elemTypeQ, AffineMap{}, privateMemoryAddressSpace);
      fromGlobalRegBufferQ = GpuAllocOp::create(rewriter, loc, loadBufferType);
    } else {
      std::tie(fromGlobalRegBufferQ, toLDSRegBufferQ) =
          createRegBuffersForGemmIn(loc, gemm0KPerBlock, blockSize, elemTypeQ,
                                    gemm0NPerBlock, rewriter);
    }
    auto [fromGlobalRegBufferK, toLDSRegBufferK] = createRegBuffersForGemmIn(
        loc, gemm0KPerBlock, blockSize, elemTypeK, gemm0MPerBlock, rewriter);
    // Note that we dont provide nRepeats because we dont want
    // nRepeats times reg buffer to be created for B of gemm0
    // because we wont be prefetching that.
    auto [preAccelRegBufferK, preAccelRegBuffersQ] =
        createRegInterrimBufferForAccel(loc, accelParamsGemm0, rewriter, 1,
                                        accelParamsGemm0.nRepeats);
    Value accRegBufferGemm0 =
        createBufferForAccelGemmOut(loc, accelParamsGemm0, rewriter);
    // Currently, there is a working assumption that this kernel is meant to
    // support fp32/fp16/bf16. This should be guaranteed by op verifiers.
    Type gemmOutElemType = elemTypeV;
    if (elemTypeQ == rewriter.getI8Type()) {
      gemmOutElemType = rewriter.getI32Type();
    }
    Type fusionOutElemType = elemTypeV;
    op.getPreSoftmaxBody().walk([&](linalg::GenericOp genOp) {
      // Keep visiting to get the fusionOutElement type from the last genOp
      fusionOutElemType =
          cast<ShapedType>(genOp.getOutputs()[0].getType()).getElementType();
    });

    Value gemm0OutBuffer = createBufferForGemmOut(loc, gemmOutElemType,
                                                  accelParamsGemm0, rewriter);
    Value softmaxInputBuffer;
    if (fusionOutElemType != elemTypeSoftmax) {
      softmaxInputBuffer = createBufferForGemmOut(loc, elemTypeSoftmax,
                                                  accelParamsGemm0, rewriter);
    }
    SmallVector<StringRef, 3> bidGridOrder = {"g_block", "m_block", "n_block"};

    Value fusionOutBuffer = createBufferForGemmOut(loc, fusionOutElemType,
                                                   accelParamsGemm0, rewriter);
    // Buffers for reductions and softmax input
    Value softmaxBufferMax, softmaxBufferExp, softmaxBufferSum;
    if (op.getEnableSoftmax()) {
      softmaxBufferMax = createBufferForGemmOut(loc, elemTypeSoftmax,
                                                accelParamsGemm0, rewriter);
      softmaxBufferExp = createBufferForGemmOut(loc, elemTypeSoftmax,
                                                accelParamsGemm0, rewriter);
      softmaxBufferSum = createBufferForGemmOut(loc, elemTypeSoftmax,
                                                accelParamsGemm0, rewriter);
    }
    // Buffers for gemm 1
    Value gemm1RegBufferB;
    if (elemTypeV != elemTypeSoftmax) {
      gemm1RegBufferB =
          createBufferForGemmOut(loc, elemTypeV, accelParamsGemm0, rewriter);
    }
    Value gemm0ExpOutBufferToLDS =
        createBufferForGemmOut(loc, elemTypeV, accelParamsGemm0, rewriter);
    auto [preAccelRegBufferV, preAccelRegBufferQxK] =
        createRegInterrimBufferForAccel(
            loc, accelParamsGemm1, rewriter, 1,
            doBypassLDSSecondGemm ? accelParamsGemm1.nRepeats : 1);

    Value accRegBufferGemm1;
    Value gemm1OutBuffer;
    if (op.getEnableSoftmax()) {
      accRegBufferGemm1 =
          createBufferForAccelGemmOut(loc, accelParamsGemm1, rewriter);
      gemm1OutBuffer = createBufferForGemmOut(loc, elemTypeSoftmax,
                                              accelParamsGemm1, rewriter);
    } else {
      accRegBufferGemm1 = createBufferForAccelGemmOut(loc, accelParamsGemm1,
                                                      rewriter, gemm1MBlocks);
      gemm1OutBuffer = createBufferForGemmOut(
          loc, elemTypeSoftmax, accelParamsGemm1, rewriter, gemm1MBlocks);
    }

    SmallVector<int64_t, 3> gemm1BidGridLengths = {gemm0G, gemm1MBlocks,
                                                   gemm1NBlocks};
    FailureOr<VectorDimInfo> maybeVectorDimInfoV =
        getVectorDim(rewriter, loc, inV, elemTypeVLoad, blockSize,
                     gemm1KPerBlock, gemm1MPerBlock, gemm1kpack);
    if (failed(maybeVectorDimInfoV)) {
      return failure();
    }
    LLVM_DEBUG(llvm::dbgs()
               << "vVectorDim: " << maybeVectorDimInfoV->vectorDim << "\n"
               << "vVectorLen: " << maybeVectorDimInfoV->vectorLen << "\n");
    LDSLayoutConfigDim ldsLayoutCfgMG1 = getLDSLayoutConfigDim(
        elemTypeV, gemm1kpack, maybeVectorDimInfoV.value());
    int64_t gemm1InMPerThread = maybeVectorDimInfoV->inDPerThread;
    FailureOr<RegsAsMatrixSubTiles> maybeGemm1OutSubTileViews =
        accelEmitterPtrGemm1->computeOutputTransforms(
            rewriter, loc, gemm1M, gemm1N, blockSize, gemm1BidGridLengths,
            gemm1InMPerThread, gemm1InNPerThread,
            ldsLayoutCfgMG1.doSwapThreadIterSubDims);
    if (failed(maybeGemm1OutSubTileViews)) {
      return failure();
    }
    auto gemm1OutSubTileViews = maybeGemm1OutSubTileViews.value();
    RegsAsMatrixSubTiles gemm1OutSubTileViewsTr =
        transposeSubTileViews(rewriter, loc, gemm1OutSubTileViews);
    auto [fromGlobalRegBufferV, toLDSRegBufferV] = createRegBuffersForGemmIn(
        loc, gemm1KPerBlock, blockSize, elemTypeV, gemm1MPerBlock, rewriter);
    int64_t gemm1MPerThread =
        getLowerShape(gemm1OutSubTileViewsTr.threadSubTile)[0];

    // Buffers for running row state

    // o buffer; this is exactly same as gemm1OutBuffer;
    // we just need another buffer to do the special accumulation
    Value attentionOutAccBuffer, outAccBufferOutTyped, sumRowBuffer,
        maxRowBuffer, expMaxDiffRowBuffer, lseBuffer;
    ArrayAttr attentionOutAccBufferThreadSubTileViewMaps;
    if (op.getEnableSoftmax()) {
      attentionOutAccBuffer = createBufferForGemmOut(
          loc, elemTypeSoftmax, accelParamsGemm1, rewriter, gemm1MBlocks);
      outAccBufferOutTyped = attentionOutAccBuffer;
      if (elemTypeSoftmax != elemTypeOut) {
        outAccBufferOutTyped = createBufferForGemmOut(
            loc, elemTypeOut, accelParamsGemm1, rewriter, gemm1MBlocks);
      }
      attentionOutAccBufferThreadSubTileViewMaps =
          invertTransforms(rewriter, loc, gemm1OutSubTileViewsTr.threadSubTile);
      // m buffer; this only contains a reduced single value per row
      auto reducedBufferType =
          MemRefType::get({gemm1MPerThread}, elemTypeSoftmax, AffineMap{},
                          /*memorySpace=*/privateMemoryAddressSpace);
      auto negInfSumTyped = createConstantFloatOp(
          rewriter, loc, reducedBufferType.getElementType(),
          reducedBufferType.getElementType(),
          -std::numeric_limits<float>::infinity(), APFloat::opOK);
      maxRowBuffer = rock::GpuAllocOp::create(rewriter, loc, reducedBufferType);
      expMaxDiffRowBuffer =
          rock::GpuAllocOp::create(rewriter, loc, reducedBufferType);
      FillOp::create(rewriter, loc, maxRowBuffer, negInfSumTyped);
      // l buffer; this only contains a reduced single value per row
      sumRowBuffer = rock::GpuAllocOp::create(rewriter, loc, reducedBufferType);
      FillOp::create(rewriter, loc, sumRowBuffer,
                     createZeroConstantOp(rewriter, loc, elemTypeSoftmax));
      if (lse) {
        Type elemTypeLse = cast<MemRefType>(lse.getType()).getElementType();
        lseBuffer = createBufferForGemmOut(loc, elemTypeLse, accelParamsGemm1,
                                           rewriter);
      }

      zeroAccBuffer(rewriter, loc, attentionOutAccBuffer);
    } else {
      outAccBufferOutTyped = gemm1OutBuffer;
      if (elemTypeSoftmax != elemTypeOut) {
        outAccBufferOutTyped = createBufferForGemmOut(
            loc, elemTypeOut, accelParamsGemm1, rewriter, gemm1MBlocks);
      }
      zeroAccBuffer(rewriter, loc, accRegBufferGemm1);
    }

    // if splitKV == 1, we define nullptr, and makeGxNGridLayout() will use
    // fewer instructions
    Value splitKVConst =
        (splitKV > 1) ? rewriter.createOrFold<ConstantIndexOp>(loc, splitKV)
                      : nullptr;
    auto gridCoordsGemm0mIter0 = layout::makeGxNGridLayout(
        rewriter, loc, bid,
        rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0), gemm0NBlocks,
        gridSize, arch, splitKVConst);

    Value gemm0MBlocksLastIter;
    Value currentSeqLen;
    Value start, end;
    // get mLoop
    std::tie(start, end, gemm0MBlocksLastIter, currentSeqLen) =
        getMLoopInfo(rewriter, loc, gridCoordsGemm0mIter0, currentSeqLenTensor,
                     gemm0M, gemm0MPerBlock, gemm0NPerBlock, splitKV, isCausal,
                     isKVCache, op.getNumRepeatsGQAAttr());

    // early exit if there is no work to do for this block
    runEarlyExit(rewriter, loc, start, end, splitKV, gemm0MPerBlock,
                 op.getPrePadG0M(), isCausal, isKVCache);

    // If gemm0K is equal to gemm0KPerBlock that means
    // effectively there is no K loop. Therefore, we
    // can prefetch the Q tile into regs outside of the
    // loop.
    if (gemm0K == gemm0KPerBlock) {
      LLVM_DEBUG(llvm::dbgs()
                 << "rock.attention: gemm0K is equal to gemm0KPerBlock\n");
      LLVM_DEBUG(llvm::dbgs()
                 << "rock.attention: Prefetching Q tile into regs...\n");
      Value zero = rewriter.createOrFold<ConstantIndexOp>(loc, 0);
      // It is fine for the m iteration to be zero here: it is irrelevant to
      // the Q tensor because the first gemm is Kt x Qt.
      auto gridCoordsGemm0LoadQ = layout::makeGxNGridLayout(
          rewriter, loc, bid, zero, gemm0NBlocks, gridSize, arch, splitKVConst);

      if (doBypassLDSForQ) {
        LogicalResult statusLoadQTile = loadAndStoreGemmInputTile(
            loc, inQ, /*kiter=*/zero, elemTypeQLoad, gridCoordsGemm0LoadQ,
            fromGlobalRegBufferQ, toLDSRegBufferQ, preAccelRegBuffersQ, "n",
            gemm0kpack, gemm0KpacksPerBlock, gemm0NPerBlock, blockSize,
            gridSize, bidGridOrder, gemm0BidGridLengths, forceUnroll, rewriter,
            *accelEmitterPtrGemm0, ldsLayoutCfgNG0, false);
        if (failed(statusLoadQTile)) {
          return failure();
        }
      } else {
        Value ldsByteBufferQ =
            createLDSByteBuffer(rewriter, loc, ldsByteBufferQSize, elemTypeQ);
        LogicalResult statusLoadQ = loadAndStoreGemmInputTile(
            loc, inQ, /*kiter=*/zero, elemTypeQLoad, gridCoordsGemm0LoadQ,
            fromGlobalRegBufferQ, toLDSRegBufferQ, ldsByteBufferQ, "n",
            gemm0kpack, gemm0KpacksPerBlock, gemm0NPerBlock, blockSize,
            gridSize, bidGridOrder, gemm0BidGridLengths, forceUnroll, rewriter,
            *accelEmitterPtrGemm0, ldsLayoutCfgNG0, false);
        if (failed(statusLoadQ)) {
          return failure();
        }
        LDSBarrierOp::create(rewriter, loc);

        TypedValue<MemRefType> ldsTileBufferQ = viewBufferAs(
            rewriter, ldsByteBufferQ, vectorTypeOrSelf(elemTypeQ, gemm0kpack));
        loadGemmOperandsFromLDSToRegs(rewriter, loc, ldsTileBufferQ,
                                      preAccelRegBuffersQ, "n", blockSize,
                                      gemm0InNPerThread, *accelEmitterPtrGemm0,
                                      ldsLayoutCfgNG0.doRotateWithK);
        GpuDeallocOp::create(rewriter, loc, ldsByteBufferQ);
      }
    }

    bool dynamicMLoop = splitKV != 1 || isCausal || isKVCache;
    LoopLikeOpInterface mLoopOp = createMLoop(rewriter, loc, start, end, gemm0M,
                                              gemm0MPerBlock, dynamicMLoop);
    {
      PatternRewriter::InsertionGuard guard(rewriter);
      // workaround for mLoopOp.getBody()
      assert(mLoopOp->getRegions().size() == 1);
      rewriter.setInsertionPointToStart(&mLoopOp->getRegion(0).front());
      int64_t kIterationsGemm0 = gemm0K / gemm0KPerBlock;
      Value mLoopIV = mLoopOp.getSingleInductionVar().value();
      zeroAccBuffer(rewriter, loc, accRegBufferGemm0);
      auto gridCoordsGemm0 =
          layout::makeGxNGridLayout(rewriter, loc, bid, mLoopIV, gemm0NBlocks,
                                    gridSize, arch, splitKVConst);
      affine::AffineForOp kLoopOp =
          affine::AffineForOp::create(rewriter, loc, 0, kIterationsGemm0, 1);
      {
        PatternRewriter::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(kLoopOp.getBody());
        Value kLoopIV = kLoopOp.getInductionVar();

        // LDS Barrier (issue 1811): some threads might be loading from LDS
        // while others are in the next iteration (here), writing to LDS. This
        // barrier prevents that.
        std::optional<uint64_t> mLoopIters = std::nullopt;
        // mLoopOp can be a scf::ForOp if we are using KV Cache or Causal
        // masking. If that's the case, we can't know the number of iterations
        // at compile time.
        if (auto mLoopAffineFor =
                dyn_cast<affine::AffineForOp>(mLoopOp.getOperation()))
          mLoopIters = mlir::affine::getConstantTripCount(mLoopAffineFor);

        bool mIterOneIter = mLoopIters.has_value() && mLoopIters.value() == 1;
        auto kLoopIters = mlir::affine::getConstantTripCount(kLoopOp);
        bool kIterOneIter = kLoopIters.has_value() && kLoopIters.value() == 1;
        // no need to have the barrier if there's just one iteration
        bool addBarrierFirstGemm = !kIterOneIter || !mIterOneIter;
        if (addBarrierFirstGemm)
          LLVM_DEBUG(llvm::dbgs()
                     << "adding a barrier in the first gemm loop\n");

        // if gemm0K is equal to gemm0KPerBlock, the Q tile
        // is already prefetched into regs. See above.
        TypedValue<MemRefType> ldsTileBufferQ;
        Value ldsByteBufferQ;
        if (gemm0K != gemm0KPerBlock) {
          ldsByteBufferQ =
              createLDSByteBuffer(rewriter, loc, ldsByteBufferQSize, elemTypeQ);
          LogicalResult statusLoadQ = loadAndStoreGemmInputTile(
              loc, inQ, kLoopIV, elemTypeQLoad, gridCoordsGemm0,
              fromGlobalRegBufferQ, toLDSRegBufferQ, ldsByteBufferQ, "n",
              gemm0kpack, gemm0KpacksPerBlock, gemm0NPerBlock, blockSize,
              gridSize, bidGridOrder, gemm0BidGridLengths, forceUnroll,
              rewriter, *accelEmitterPtrGemm0, ldsLayoutCfgNG0,
              addBarrierFirstGemm);
          if (failed(statusLoadQ)) {
            return failure();
          }
          // no need to add a barrier in the next call to
          // loadAndStoreGemmInputTile()
          addBarrierFirstGemm = false;
          ldsTileBufferQ =
              viewBufferAs(rewriter, ldsByteBufferQ,
                           vectorTypeOrSelf(elemTypeQ, gemm0kpack));
        }
        // if we added a barrier in the previous block (load Q), there's no need
        // to add it again here.
        Value ldsByteBufferK = createLDSByteBuffer(
            rewriter, loc, gemm0KPerBlock * gemm0MPerBlock, elemTypeK);
        LogicalResult statusLoadKTile = loadAndStoreGemmInputTile(
            loc, inK, kLoopIV, elemTypeKLoad, gridCoordsGemm0,
            fromGlobalRegBufferK, toLDSRegBufferK, ldsByteBufferK, "m",
            gemm0kpack, gemm0KpacksPerBlock, gemm0MPerBlock, blockSize,
            gridSize, bidGridOrder, gemm0BidGridLengths, forceUnroll, rewriter,
            *accelEmitterPtrGemm0, ldsLayoutCfgMG0, addBarrierFirstGemm);
        if (failed(statusLoadKTile)) {
          return failure();
        }
        TypedValue<MemRefType> ldsTileBufferK = viewBufferAs(
            rewriter, ldsByteBufferK, vectorTypeOrSelf(elemTypeK, gemm0kpack));
        // LDS barrier.
        LDSBarrierOp::create(rewriter, loc);
        // if gemm0K is equal to gemm0KPerBlock, the Q tile
        // is already prefetched into regs. See above.
        if (gemm0K != gemm0KPerBlock) {
          loadGemmOperandsFromLDSToRegs(
              rewriter, loc, ldsTileBufferQ, preAccelRegBuffersQ, "n",
              blockSize, gemm0InNPerThread, *accelEmitterPtrGemm0,
              ldsLayoutCfgNG0.doRotateWithK);
          GpuDeallocOp::create(rewriter, loc, ldsByteBufferQ);
        }

        // Emit lowered blockwise GEMM 0.
        BlockwiseGemmAccelOp::create(
            rewriter, loc, ldsTileBufferK,
            ldsTileBufferQ ? ldsTileBufferQ : ldsTileBufferK,
            rewriter.getI32IntegerAttr(gemm0InMPerThread),
            rewriter.getI32IntegerAttr(gemm0InNPerThread),
            /*rotateMWithK=*/nullptr,
            (ldsLayoutCfgNG0.doRotateWithK ? rewriter.getUnitAttr() : nullptr),
            /*loadAfromLDS=*/rewriter.getUnitAttr(), /*loadBfromLDS=*/nullptr,
            /*splitKAcrossThreadsFirstA=*/nullptr,
            /*splitKAcrossThreadsFirstB=*/nullptr, preAccelRegBufferK,
            preAccelRegBuffersQ, accRegBufferGemm0, featuresAttr,
            op.getBlockSizeAttr(), gemm0TuningParams);

        GpuDeallocOp::create(rewriter, loc, ldsByteBufferK);
      }
      accelEmitterPtrGemm0->computeOutputConversion(
          rewriter, loc, accRegBufferGemm0, gemm0OutBuffer, forceUnroll);

      int64_t prePadG0M = gemm0M;
      if (op.getPrePadG0M().has_value()) {
        prePadG0M = op.getPrePadG0M().value().getSExtValue();
      }
      int64_t prePadG0N = gemm0N;
      if (op.getPrePadG0N().has_value()) {
        prePadG0N = op.getPrePadG0N().value().getSExtValue();
      }
      RegsAsMatrixSubTiles gemm0OutSubTileViewsTrUnPadded =
          unpadGridSubTileView(rewriter, loc, gemm0OutSubTileViewsTr, prePadG0N,
                               prePadG0M);

      // undo Grouped-Query Attention (GQA) transforms
      // This is needed because the preSoftmaxElementWise inputs (if any), don't
      // have the GQA transformed applied to them. So, we undo the transform to
      // the output of the first GEMM. See postProcessFirstGemm() to understand
      // the transforms done to preSoftmaxElementWise inputs.
      ArrayRef<int64_t> unpaddedShape =
          getLowerShape(gemm0OutSubTileViewsTrUnPadded.gridSubTile);
      ArrayAttr undoGQA = undoGQATransforms(rewriter, loc, op, unpaddedShape);

      // undo the GQA transforms for postProcessFirstGemm()
      if (undoGQA) {
        ArrayAttr linalgGridSubTileMaps =
            gemm0OutSubTileViewsTrUnPadded.gridSubTile;
        linalgGridSubTileMaps =
            prependUpperViews(rewriter, linalgGridSubTileMaps, undoGQA);
        gemm0OutSubTileViewsTrUnPadded.gridSubTile = linalgGridSubTileMaps;
      }

      // Align the preSoftmaxElementWise (if any) linalg.generic to
      // be performed on the output of the first gemm.
      FailureOr<Value> maybeFusionOutBuffer = postProcessFirstGemm(
          rewriter, loc, op, gridCoordsGemm0, gemm0OutBuffer, fusionOutBuffer,
          gemm0OutSubTileViewsTrUnPadded);
      if (failed(maybeFusionOutBuffer)) {
        return op.emitError("post processing first gemm failed.\n");
      }
      gemm0OutBuffer = maybeFusionOutBuffer.value();
      if (fusionOutElemType == elemTypeSoftmax)
        softmaxInputBuffer = gemm0OutBuffer;

      // Softmax
      if (op.getEnableSoftmax()) {
        // convert gemm0OutBuffer to elemTypeSoftmax
        if (fusionOutElemType != elemTypeSoftmax) {
          createTypeConversionFlatAndStore(rewriter, loc, gemm0OutBuffer,
                                           softmaxInputBuffer);
        }
        // Scale gemm0 output by (1/ln2)
        // So that we can use exp2 instead of exp.
        Value ln2Recip = createConstantFloatOp(
            rewriter, loc, elemTypeSoftmax, elemTypeSoftmax, 1.44269504f,
            elemTypeSoftmax.getIntOrFloatBitWidth() >= 32 ? APFloat::opOK
                                                          : APFloat::opInexact);
        postProcessFirstGemmSplat<ElementwiseMultOp>(
            rewriter, loc, gridCoordsGemm0, softmaxInputBuffer,
            gemm0OutSubTileViews,
            ln2Recip.getDefiningOp<arith::ConstantOp>().getValue());

        // Handle padding
        bool hasPadding =
            op.getPrePadG0M().has_value() || op.getPrePadG0N().has_value();
        if (hasPadding) {
          createFirstGemmNegInfPadding(rewriter, loc, gridCoordsGemm0,
                                       softmaxInputBuffer,
                                       gemm0OutSubTileViewsTrUnPadded);
        }
        // Negative Infinite for extra values (KV cache)
        setGemm0OutputOutOfScope(rewriter, loc, OutOfScopeType::KVCache,
                                 gridCoordsGemm0, softmaxInputBuffer,
                                 gemm0OutSubTileViewsTr, isKVCache, mLoopIV,
                                 gemm0MBlocksLastIter, currentSeqLen);

        // Negative Infinite for extra values (causal masking)
        setGemm0OutputOutOfScope(
            rewriter, loc, OutOfScopeType::Causal, gridCoordsGemm0,
            softmaxInputBuffer, gemm0OutSubTileViewsTr, isCausal, mLoopIV,
            gemm0MBlocksLastIter, /*currentSeqLen=*/nullptr,
            op.getNumRepeatsGQAAttr());

        APInt reductionAxis = APInt(64, 1);
        // Softmax max reduction
        Value ldsReductionWorkspaceByteBuffer = createLDSByteBuffer(
            rewriter, loc, reductionWorkspaceSize, elemTypeSoftmax);
        TypedValue<MemRefType> ldsReductionWorkspaceBuffer = viewBufferAs(
            rewriter, ldsReductionWorkspaceByteBuffer, elemTypeSoftmax);
        BlockwiseBroadcastReduceOp::create(
            rewriter, loc, softmaxInputBuffer, ldsReductionWorkspaceBuffer,
            softmaxBufferMax,
            /*extraOut=*/nullptr, reductionAxis, rock::ReduceMethod::Max,
            gemm0OutSubTileViewsTr.blockSubTile,
            gemm0OutSubTileViewsTr.blockSubTileTidSlice.value(),
            gemm0OutSubTileViewsTr.threadSubTile, /*extraViews=*/nullptr,
            blockSize);
        GpuDeallocOp::create(rewriter, loc, ldsReductionWorkspaceByteBuffer);

        // softmax normalization.
        Value gemm0MNThreadwiseView =
            transform(rewriter, softmaxInputBuffer,
                      invertTransforms(rewriter, loc,
                                       gemm0OutSubTileViewsTr.threadSubTile));
        Value gemm0MNExpThreadwiseView =
            transform(rewriter, softmaxBufferExp,
                      invertTransforms(rewriter, loc,
                                       gemm0OutSubTileViewsTr.threadSubTile));
        Value gemm0MNMaxThreadwiseView =
            transform(rewriter, softmaxBufferMax,
                      invertTransforms(rewriter, loc,
                                       gemm0OutSubTileViewsTr.threadSubTile));
        expSubstractMaxFromGemm0(rewriter, loc, gemm0MNThreadwiseView,
                                 gemm0MNExpThreadwiseView,
                                 gemm0MNMaxThreadwiseView, maxRowBuffer);

        // Softmax sum reduction
        Value ldsReductionWorkspaceByteSecondBuffer = createLDSByteBuffer(
            rewriter, loc, reductionWorkspaceSize, elemTypeSoftmax);
        TypedValue<MemRefType> ldsReductionWorkspaceSecondBuffer = viewBufferAs(
            rewriter, ldsReductionWorkspaceByteSecondBuffer, elemTypeSoftmax);
        BlockwiseBroadcastReduceOp::create(
            rewriter, loc, softmaxBufferExp, ldsReductionWorkspaceSecondBuffer,
            softmaxBufferSum, /*extraOut=*/nullptr, reductionAxis,
            rock::ReduceMethod::Sum, gemm0OutSubTileViewsTr.blockSubTile,
            gemm0OutSubTileViewsTr.blockSubTileTidSlice.value(),
            gemm0OutSubTileViewsTr.threadSubTile,
            /*extraViews=*/nullptr, blockSize);
        GpuDeallocOp::create(rewriter, loc,
                             ldsReductionWorkspaceByteSecondBuffer);
        Value gemm0SumThreadwiseView =
            transform(rewriter, softmaxBufferSum,
                      invertTransforms(rewriter, loc,
                                       gemm0OutSubTileViewsTr.threadSubTile));
        Value gemm0MaxThreadwiseView =
            transform(rewriter, softmaxBufferMax,
                      invertTransforms(rewriter, loc,
                                       gemm0OutSubTileViewsTr.threadSubTile));
        updateRowSum(rewriter, loc, gemm0SumThreadwiseView,
                     gemm0MaxThreadwiseView, sumRowBuffer, maxRowBuffer,
                     expMaxDiffRowBuffer);
      }

      // Emit blockwise GEMM 1.
      {
        auto gemm0Out =
            op.getEnableSoftmax() ? softmaxBufferExp : softmaxInputBuffer;
        if (elemTypeV != elemTypeSoftmax) {
          createTypeConversionFlatAndStore(rewriter, loc, gemm0Out,
                                           gemm1RegBufferB);
        } else {
          gemm1RegBufferB = gemm0Out;
        }
        Value wrappedLDSBufferForLoadB;
        Value gemm1LDSByteBufferB;
        TypedValue<MemRefType> gemm1LDSBufferB;
        if (!doBypassLDSSecondGemm) {
          // The output RegsAsSubTile views are N x M where N is reduction dim
          RegsAsMatrixSubTiles gemm0OutSubTileNxMViews = gemm0OutSubTileViews;
          ArrayAttr gemm0ThreadwiseSubtileViewNxMMaps = invertTransforms(
              rewriter, loc, gemm0OutSubTileNxMViews.threadSubTile);
          Value gemm0ExpNMThreadwiseView = transform(
              rewriter, gemm1RegBufferB, gemm0ThreadwiseSubtileViewNxMMaps);
          // Correct the below toLDSViews to be max LDS vectorizable
          // (For now just hacked in the existing view)
          // copyKPerThread is set to 1 because
          // K is not packed as kpack vectors. Therefore, setting
          // copyKPerThread to be 1 will always make the LDS write
          // to be scalars -- which makes the following layout agnostic.
          // We should get rid of storing to LDS altogether with
          // the transposed layout for this gemm.
          gemm1LDSByteBufferB = createLDSByteBuffer(
              rewriter, loc, gemm1LDSByteBufferBSize, elemTypeV);

          LogicalResult storeGemm1ATileStatus = storeGemmInputTile(
              rewriter, loc, gemm1kpack, gemm0ExpNMThreadwiseView,
              gemm0OutSubTileNxMViews, gemm0ExpOutBufferToLDS,
              gemm1LDSByteBufferB, gemm1KpacksPerBlock, "n", gemm1KPerBlock,
              gemm1NPerBlock, /*copyKPerThread=*/1, gemm1InNPerThread,
              forceUnroll, false, false);
          if (failed(storeGemm1ATileStatus)) {
            return failure();
          }
          gemm1LDSBufferB =
              viewBufferAs(rewriter, gemm1LDSByteBufferB,
                           vectorTypeOrSelf(elemTypeV, gemm1kpack));
          wrappedLDSBufferForLoadB = accelEmitterPtrGemm1->wrapLDSBufferForLoad(
              rewriter, loc, gemm1LDSBufferB, op.getBlockSize(),
              gemm1InNPerThread, "n", false);
        }

        affine::AffineForOp g1MLoopOp =
            affine::AffineForOp::create(rewriter, loc, 0, gemm1MBlocks, 1);
        {
          OpBuilder::InsertionGuard guard(rewriter);
          rewriter.setInsertionPointToStart(g1MLoopOp.getBody());
          Value g1MLoopIndVar = g1MLoopOp.getInductionVar();
          if (op.getEnableSoftmax()) {
            zeroAccBuffer(rewriter, loc, accRegBufferGemm1);
          } else {
            if (gemm1MBlocks > 1) {
              accRegBufferGemm1 = createSliceOfFirstDim(
                  rewriter, loc, accRegBufferGemm1, g1MLoopIndVar);
            }
          }
          auto gridCoordsGemm1 = layout::makeGxNGridLayout(
              rewriter, loc, bid, g1MLoopIndVar, gemm1NBlocks, gridSize, arch,
              splitKVConst);

          Value ldsByteBufferV = createLDSByteBuffer(
              rewriter, loc, gemm1KPerBlock * gemm1MPerBlock, elemTypeV);

          // LDS Barrier (issue 1811): some threads might be loading from LDS
          // while others are in the next iteration (here), writing to LDS. This
          // barrier prevents that. No need to have the barrier if there's just
          // one iteration
          auto g1MLoopIters = mlir::affine::getConstantTripCount(g1MLoopOp);
          bool g1MIterOneIter =
              g1MLoopIters.has_value() && g1MLoopIters.value() == 1;
          bool addBarrierSecondGemm = !g1MIterOneIter;
          if (addBarrierSecondGemm)
            LLVM_DEBUG(llvm::dbgs()
                       << "adding a barrier in the second gemm loop\n");

          LogicalResult statusLoadVTile = loadAndStoreGemmInputTile(
              loc, inV,
              /*kIter=*/mLoopIV, elemTypeVLoad, gridCoordsGemm1,
              fromGlobalRegBufferV, toLDSRegBufferV, ldsByteBufferV, "m",
              gemm1kpack, gemm1KpacksPerBlock, gemm1MPerBlock, blockSize,
              gridSize, bidGridOrder, gemm1BidGridLengths, forceUnroll,
              rewriter, *accelEmitterPtrGemm1, ldsLayoutCfgMG1,
              addBarrierSecondGemm);
          if (failed(statusLoadVTile)) {
            return failure();
          }
          TypedValue<MemRefType> ldsTileBufferV =
              viewBufferAs(rewriter, ldsByteBufferV,
                           vectorTypeOrSelf(elemTypeV, gemm1kpack));
          // LDS barrier.
          LDSBarrierOp::create(rewriter, loc);
          // Emit GEMM 1.

          if (doBypassLDSSecondGemm) {
            ArrayAttr gemm1ThreadwiseSubtileViewDxKMaps = invertTransforms(
                rewriter, loc, gemm0OutSubTileViewsTr.threadSubTile);
            Value gemm1BDxKThreadwiseView = transform(
                rewriter, gemm1RegBufferB, gemm1ThreadwiseSubtileViewDxKMaps);
            affine::AffineForOp nRepeatsLoop = affine::AffineForOp::create(
                rewriter, loc, 0, accelParamsGemm1.nRepeats, 1);
            {
              PatternRewriter::InsertionGuard guard(rewriter);
              rewriter.setInsertionPointToStart(nRepeatsLoop.getBody());
              Value ni = nRepeatsLoop.getInductionVar();
              Value subview = preAccelRegBufferQxK;
              if (accelParamsGemm1.nRepeats > 1) {
                subview = createSliceOfFirstDim(rewriter, loc,
                                                preAccelRegBufferQxK, ni);
              }
              ThreadwiseReadIntoOp::create(
                  rewriter, loc, gemm1BDxKThreadwiseView, subview,
                  rewriter.getArrayAttr({}), ValueRange{ni}, true, true);
            }
          }

          BlockwiseGemmAccelOp::create(
              rewriter, loc, ldsTileBufferV,
              gemm1LDSBufferB ? gemm1LDSBufferB : ldsTileBufferV,
              rewriter.getI32IntegerAttr(gemm1InMPerThread),
              rewriter.getI32IntegerAttr(gemm1InNPerThread),
              (ldsLayoutCfgMG1.doRotateWithK ? rewriter.getUnitAttr()
                                             : nullptr),
              /*rotateNWithK=*/nullptr,
              /*loadAfromLDS=*/rewriter.getUnitAttr(),
              /*loadBfromLDS=*/
              !doBypassLDSSecondGemm ? rewriter.getUnitAttr() : nullptr,
              /*splitKAcrossThreadsFirstA=*/
              doBypassLDSSecondGemm ? rewriter.getUnitAttr() : nullptr,
              /*splitKAcrossThreadsFirstB=*/nullptr, preAccelRegBufferV,
              preAccelRegBufferQxK, accRegBufferGemm1, featuresAttr,
              op.getBlockSizeAttr(), gemm1TuningParams);

          GpuDeallocOp::create(rewriter, loc, ldsByteBufferV);
          if (!doBypassLDSSecondGemm)
            GpuDeallocOp::create(rewriter, loc, gemm1LDSByteBufferB);

          // There is no second k-loop
          // Therefore can get the output straight away
          Value gemm1OutBufferPerG1MBlock = gemm1OutBuffer;
          if (!op.getEnableSoftmax() && gemm1MBlocks > 1) {
            gemm1OutBufferPerG1MBlock = createSliceOfFirstDim(
                rewriter, loc, gemm1OutBuffer, g1MLoopIndVar);
          }

          accelEmitterPtrGemm1->computeOutputConversion(
              rewriter, loc, accRegBufferGemm1, gemm1OutBufferPerG1MBlock,
              forceUnroll);
          if (op.getEnableSoftmax()) {
            Value attentionOutAccBufferPerG1MBlock = attentionOutAccBuffer;
            if (gemm1MBlocks > 1) {
              attentionOutAccBufferPerG1MBlock = createSliceOfFirstDim(
                  rewriter, loc, attentionOutAccBuffer, g1MLoopIndVar);
            }
            ArrayAttr invertedGemm1threadSubTileMaps = invertTransforms(
                rewriter, loc, gemm1OutSubTileViewsTr.threadSubTile);
            Value gemm1MNThreadwiseView =
                transform(rewriter, gemm1OutBufferPerG1MBlock,
                          invertedGemm1threadSubTileMaps);
            // Rescale/correct output, rowMax and rowSums
            Value attentionOutAccBufferView =
                transform(rewriter, attentionOutAccBufferPerG1MBlock,
                          attentionOutAccBufferThreadSubTileViewMaps);
            createAttentionRowStateCorrections(
                rewriter, loc, gemm1MNThreadwiseView, attentionOutAccBufferView,
                expMaxDiffRowBuffer);
          }
        }
      }
    }

    if (op.getEnableSoftmax()) {
      affine::AffineForOp g1MLoopOp =
          affine::AffineForOp::create(rewriter, loc, 0, gemm1MBlocks, 1);
      {
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(g1MLoopOp.getBody());
        Value g1MLoopIndVar = g1MLoopOp.getInductionVar();
        Value attentionOutAccBufferPerG1MBlock = attentionOutAccBuffer;
        if (gemm1MBlocks > 1) {
          attentionOutAccBufferPerG1MBlock = createSliceOfFirstDim(
              rewriter, loc, attentionOutAccBuffer, g1MLoopIndVar);
        }
        Value attentionOutAccBufferView =
            transform(rewriter, attentionOutAccBufferPerG1MBlock,
                      attentionOutAccBufferThreadSubTileViewMaps);
        scaleFinalOutput(rewriter, loc, attentionOutAccBufferView,
                         sumRowBuffer);
      }
    }
    Value outAccBuffer =
        op.getEnableSoftmax() ? attentionOutAccBuffer : gemm1OutBuffer;
    if (elemTypeSoftmax != elemTypeOut) {
      // We flatten output buffer in case gemm1MBlocks > 1
      // where those are iterated.
      createTypeConversionFlatAndStore(rewriter, loc, outAccBuffer,
                                       outAccBufferOutTyped);
    }
    if (lse) {
      // it must be guaranteed by the verifier
      assert(op.getEnableSoftmax());
      assert(lseBuffer);
      Value lseBufferView = transform(
          rewriter, lseBuffer, attentionOutAccBufferThreadSubTileViewMaps);
      computeLse(rewriter, loc, lseBufferView, sumRowBuffer, maxRowBuffer);
    }

    MemRefType outAccBufferOutType =
        cast<MemRefType>(outAccBufferOutTyped.getType());
    int64_t numElementsAttnOut = outAccBufferOutType.getNumElements();
    // This map will create an upper view [gblock, nblock, flatiter] -> [gblock,
    // miter, nblock, iter]
    TransformMapAttr flatToMiterMap =
        getFlatToMiterMap(rewriter, gemm0G, gemm1MBlocks, gemm1NBlocks,
                          blockSize, numElementsAttnOut);
    ArrayAttr outGridSubTile =
        prependUpperViews(rewriter, rewriter.getArrayAttr({flatToMiterMap}),
                          gemm1OutSubTileViews.gridSubTile);
    Value zero = rewriter.createOrFold<ConstantIndexOp>(loc, 0);

    // Note that we don't use splitKV here because that dimension belongs to the
    // batch size already for output tensors
    auto gridCoordsGemm1 = layout::makeGxNGridLayout(
        rewriter, loc, bid, zero, gemm1NBlocks, gridSize, arch);
    Value outAccBufferOutTypedFlat =
        getFlattenedMemref(rewriter, outAccBufferOutTyped);
    ThreadwiseWriteAllOp::create(
        rewriter, loc, outAccBufferOutTypedFlat, trOut, outGridSubTile,
        /*extraIndices=*/
        ValueRange{gridCoordsGemm1.g_block, gridCoordsGemm1.n_block, tid},
        op.getStoreMethod(), forceUnroll,
        /*useIndexDiffs=*/true);

    // store LSE to device memory
    if (lse) {
      // drop gemmM dimension
      TopDownTMBuilder viewBuilder(rewriter, {"gemmG", "gemmM", "gemmN"},
                                   {gemm0G, gemm1M, gemm1N});
      viewBuilder.passThrough({"gemmG", "gemmN"}, {0, 1}, {"gemmG", "gemmN"});
      viewBuilder.ignore("gemmM");
      auto dropM = rewriter.getArrayAttr({viewBuilder.get()});

      MemRefType lseBufferOutType = cast<MemRefType>(lseBuffer.getType());
      int64_t numElementsLseOut = lseBufferOutType.getNumElements();
      auto flatToMiterMapAttr = getFlatToMiterMap(
          rewriter, gemm0G, 1, gemm1NBlocks, blockSize, numElementsLseOut);
      // slice mIter
      BottomUpTMBuilder sliceBuilder(
          rewriter, {"g_block", "mIter", "n_block", "tid", "iter"},
          {gemm0G, gemm1MBlocks, gemm1NBlocks, blockSize, numElementsLseOut},
          loc);
      sliceBuilder.passThrough({"g_block", "n_block", "tid", "iter"},
                               {0, 2, 3, 4},
                               {"g_block", "n_block", "tid", "iter"});
      sliceBuilder.slice({"mIter"}, {"mIter"}, {0}, {1});
      auto sliceAttr = sliceBuilder.get();

      ArrayAttr flatToMiterSlice = prependUpperViews(
          rewriter, rewriter.getArrayAttr({flatToMiterMapAttr}),
          rewriter.getArrayAttr({sliceAttr}));
      ArrayAttr outGridSubTile = prependUpperViews(
          rewriter, flatToMiterSlice, gemm1OutSubTileViews.gridSubTile);
      ArrayAttr lseMap = prependUpperViews(rewriter, outGridSubTile, dropM);
      ThreadwiseWriteAllOp::create(
          rewriter, loc, lseBuffer, lse, lseMap,
          /*extraIndices=*/
          ValueRange{gridCoordsGemm1.g_block, gridCoordsGemm1.n_block, tid},
          rock::StoreMethod::Set, forceUnroll,
          /*useIndexDiffs=*/true);
    }

    rewriter.eraseOp(op);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// GridwiseGemmAccel lowering.
//===----------------------------------------------------------------------===//

struct GridwiseGemmAccelRewritePattern
    : public OpRewritePattern<GridwiseGemmAccelOp> {
  using OpRewritePattern<GridwiseGemmAccelOp>::OpRewritePattern;

  // Emit the LDS -> register reads that precede the MMA loop: one threadwise
  // read of the A tile into `regsA` and one of the B tile into `regsB`, i.e.
  // A[0:mRepeats, 0:kBasePerThread] and B[0:nRepeats, 0:kBasePerThread].
  //
  // `ldsAView` / `ldsBView` are handed to the accel emitter's
  // wrapLDSBufferForLoad together with `blockSize`, the per-thread extents
  // (`inMPerThread` / `inNPerThread`) and the rotate flags. `tid` is supplied
  // as the extra index of each read. `regsC` is accepted but not referenced in
  // this body, and `regsA` / `regsB` are only used as read destinations (they
  // are never rebound here).
  void generateReadLoop(
      Location loc, PatternRewriter &b,
      const std::unique_ptr<rock::accel::AccelEmitter> &accelEmitterPtr,
      Value tid, Value ldsAView, Value ldsBView, Value &regsA, Value &regsB,
      Value regsC, int64_t blockSize, int64_t inMPerThread,
      int64_t inNPerThread, bool rotateMWithK, bool rotateNWithK) const {
    // wrapLDSBufferForLoad exposes a single set of Ks for private memory:
    // A/B[m/n, 0:kBasePerThread]. Both wraps are created up front, before any
    // of the reads below.
    Value wrappedA = accelEmitterPtr->wrapLDSBufferForLoad(
        b, loc, ldsAView, blockSize, inMPerThread, "m", rotateMWithK);
    Value wrappedB = accelEmitterPtr->wrapLDSBufferForLoad(
        b, loc, ldsBView, blockSize, inNPerThread, "n", rotateNWithK);

    // Layer one more transform on top of the wrapped view: given a single
    // flat index, split it into the "m"/"n" index and "k" while letting tid
    // pass through, so that wrapLDSBufferForLoad's own maps receive the
    // indices they expect. Then read the whole flattened view into `regs`.
    auto emitFlattenedRead = [&](Value wrapped, Value regs, StringRef dName,
                                 StringRef flatName) {
      ArrayRef<int64_t> shape =
          cast<ShapedType>(wrapped.getType()).getShape();
      // The wrapped view is expected to be (tid, d, k).
      assert(shape.size() == 3);
      assert(shape[0] == blockSize);
      const int64_t dLen = shape[1];
      const int64_t kLen = shape[2];

      TopDownTMBuilder splitter(b, {"tid", flatName},
                                {blockSize, dLen * kLen}, loc);
      splitter.passThrough("tid");
      splitter.merge({dName, "k"}, {1, 2}, flatName, {dLen, kLen});
      ArrayAttr splitMaps = b.getArrayAttr({splitter.get()});
      Value flatView = rock::transform(b, wrapped, splitMaps);
      ThreadwiseReadIntoOp::create(b, loc, flatView, regs,
                                   b.getArrayAttr({}), ValueRange{tid},
                                   /*forceUnroll=*/true,
                                   /*useIndexDiffs=*/true);
    };

    // Read from LDS buffer for A, then for B.
    emitFlattenedRead(wrappedA, regsA, "m", "mk");
    emitFlattenedRead(wrappedB, regsB, "n", "nk");
  }

  LogicalResult matchAndRewrite(GridwiseGemmAccelOp op,
                                PatternRewriter &b) const override {
    Location loc = op.getLoc();

    // Obtain data types of inputs.
    auto elementTypeA = op.getA().getType().getElementType();
    auto maybeElementTypeALoad = getInputFusionElementType(op.getA());
    auto elementTypeALoad = failed(maybeElementTypeALoad)
                                ? elementTypeA
                                : maybeElementTypeALoad.value();

    auto elementTypeB = op.getB().getType().getElementType();
    auto maybeElementTypeBLoad = getInputFusionElementType(op.getB());
    auto elementTypeBLoad = failed(maybeElementTypeBLoad)
                                ? elementTypeB
                                : maybeElementTypeBLoad.value();
    auto destType = op.getC().getType().getElementType();

    // Get 'features' from the op
    auto features = rock::getFeatures(op);
    auto featuresAttr = op.getFeaturesAttr();

    // Prepare some useful constants.
    Value matA = op.getA();
    Value matB = op.getB();

    // Obtain critical matrix dimensions.
    ArrayRef<int64_t> aShape, bShape, cShape;
    aShape = op.getA().getType().getShape();
    bShape = op.getB().getType().getShape();
    cShape = op.getC().getType().getShape();
    // Obtain critical matrix dimensions.
    int64_t G = aShape[0];
    int64_t K = aShape[1];
    int64_t M = aShape[2];
    int64_t N = bShape[2];

    // Obtain critical tuning parameters.
    StringRef arch = rock::getArchValue(op);
    uint32_t blockSize = op.getBlockSize();
    uint32_t gridSize = op.getGridSize();
    RockAccelTuningParamAttrInterface tuningParams = op.getParams();
    int64_t kpack = tuningParams.getKpack();
    // TODO: kPerBlock, as defined in parameter selection etc,
    // is in units of kPack, not individual k. This should be changed
    // at some future point, but it'll be worked around for now.
    int64_t kpacksPerBlock = tuningParams.getKpackPerBlock();
    int64_t mPerBlock = tuningParams.getMPerBlock();
    int64_t nPerBlock = tuningParams.getNPerBlock();
    int64_t mBlocks = M / mPerBlock;
    int64_t nBlocks = N / nPerBlock;
    bool forceUnroll = tuningParams.getForceUnroll();
    int64_t kPerBlock = kpacksPerBlock * kpack;

    if (!isValidBlockSize(blockSize, kPerBlock, mPerBlock, nPerBlock)) {
      return emitError(loc) << "Block size too large, rejecting as invalid.\n";
    }

    int64_t aCopyPerThread = (kPerBlock * mPerBlock) / blockSize;
    int64_t bCopyPerThread = (kPerBlock * nPerBlock) / blockSize;

    int64_t aCopyKpacksPerThread =
        math_util::integer_divide_ceil(aCopyPerThread, kpack);
    int64_t bCopyKpacksPerThread =
        math_util::integer_divide_ceil(bCopyPerThread, kpack);

    // Get the vector copy layout for A and B
    FailureOr<VectorDimInfo> maybeVecDimInfoA = getVectorDim(
        b, loc, matA, elementTypeALoad, blockSize, kPerBlock, mPerBlock, kpack);
    if (failed(maybeVecDimInfoA)) {
      return failure();
    }
    FailureOr<VectorDimInfo> maybeVecDimInfoB = getVectorDim(
        b, loc, matB, elementTypeBLoad, blockSize, kPerBlock, nPerBlock, kpack);
    if (failed(maybeVecDimInfoB)) {
      return failure();
    }
    auto copyMPerThread = maybeVecDimInfoA->inDPerThread;
    auto copyNPerThread = maybeVecDimInfoB->inDPerThread;
    LLVM_DEBUG(llvm::dbgs()
               << "gridSize: " << gridSize << "\n"
               << "blockSize: " << blockSize << "\n"
               << "elementTypeALoad: " << elementTypeALoad << "\n"
               << "elementTypeBLoad: " << elementTypeBLoad << "\n"
               << "aCopyPerThread: " << aCopyPerThread << "\n"
               << "bCopyPerThread: " << bCopyPerThread << "\n"
               << "aCopyKpacksPerThread: " << aCopyKpacksPerThread << "\n"
               << "bCopyKpacksPerThread: " << bCopyKpacksPerThread << "\n"
               << "aVectorDim: " << maybeVecDimInfoA->vectorDim << "\n"
               << "aVectorLen: " << maybeVecDimInfoA->vectorLen << "\n"
               << "bVectorDim: " << maybeVecDimInfoB->vectorDim << "\n"
               << "bVectorLen: " << maybeVecDimInfoB->vectorLen << "\n"
               << "vectorTiebreaker: " << maybeVecDimInfoA->vectorTiebreaker
               << "\n"
               << "kPerBlock: " << kPerBlock << "\n"
               << "mPerBlock: " << mPerBlock << "\n"
               << "nPerBlock: " << nPerBlock << "\n"
               << "aCopyKPerThread: " << maybeVecDimInfoA->inKPerThread << "\n"
               << "bCopyKPerThread: " << maybeVecDimInfoB->inKPerThread << "\n"
               << "copyMPerThread: " << copyMPerThread << "\n"
               << "copyNPerThread: " << copyNPerThread << "\n");
    SmallVector<int64_t, 3> bidGridLengths = {G, mBlocks, nBlocks};
    SmallVector<StringRef, 3> bidGridOrder = {"g_block", "m_block", "n_block"};
    FailureOr<RegsAsMatrixSubTiles> maybeABufferViews = getLoadRegsAsTileViews(
        b, loc, matA, "m", bidGridOrder, bidGridLengths, blockSize, kPerBlock,
        mPerBlock, maybeVecDimInfoA->inKPerThread,
        maybeVecDimInfoA->inDPerThread,
        maybeVecDimInfoA->vectorDim == GemmDimension::K);
    if (failed(maybeABufferViews)) {
      return failure();
    }
    Value wrappedA = transform(b, matA, maybeABufferViews->gridSubTile);
    FailureOr<RegsAsMatrixSubTiles> maybeBBufferViews = getLoadRegsAsTileViews(
        b, loc, matB, "n", bidGridOrder, bidGridLengths, blockSize, kPerBlock,
        nPerBlock, maybeVecDimInfoB->inKPerThread,
        maybeVecDimInfoB->inDPerThread,
        maybeVecDimInfoB->vectorDim == GemmDimension::K);
    if (failed(maybeBBufferViews)) {
      return failure();
    }
    Value wrappedB = transform(b, matB, maybeBBufferViews->gridSubTile);

    // Get current workgroup ID.
    auto bid = WorkgroupIdOp::create(b, loc, b.getIndexType());
    // Get current workitem ID.
    auto tid = WorkitemIdOp::create(b, loc, b.getIndexType());

    Value loadBufferA =
        gpuAlloc(b, loc, aCopyPerThread, elementTypeA, AddressSpace::Private);
    Value loadBufferB =
        gpuAlloc(b, loc, bCopyPerThread, elementTypeB, AddressSpace::Private);

    auto zeroConstantOp = ConstantIndexOp::create(b, loc, 0);
    // Compute grid coordinates
    auto gridCoords = layout::makeGroupedGridLayout(
        b, loc, bid,
        {G, mBlocks, nBlocks, rock::getNumCUValue(op), elementTypeA, destType},
        arch);

    Value storeBufferA =
        gpuAlloc(b, loc, aCopyPerThread, elementTypeA, AddressSpace::Private);
    Value storeBufferB =
        gpuAlloc(b, loc, bCopyPerThread, elementTypeB, AddressSpace::Private);

    bool isKContiguousDimA = maybeVecDimInfoA->vectorDim == GemmDimension::K;
    bool isKContiguousDimB = maybeVecDimInfoB->vectorDim == GemmDimension::K;
    LDSLayoutConfigDim ldsLayoutConfigA =
        getLDSLayoutConfigDim(elementTypeA, kpack, maybeVecDimInfoA.value());
    LDSLayoutConfigDim ldsLayoutConfigB =
        getLDSLayoutConfigDim(elementTypeB, kpack, maybeVecDimInfoB.value());

    // We invert the transforms that are iter --> K x D slice of the tensor
    // so that we can view loadBuffer as a K x D tensor
    ArrayAttr loadBufferAViews =
        invertTransforms(b, loc, maybeABufferViews->threadSubTile);
    Value viewLoadBufferA = transform(b, loadBufferA, loadBufferAViews);
    // Prior to the LDS store, we need to re-arrange the register buffer to
    // maximize LDS vectorization. Hence, we create the view w.r.t. global
    // that corresponds to such a re-arranged register buffer.
    FailureOr<RegsAsMatrixSubTiles> maybeALdsStoreViews =
        getPackedRegsAsTileViews(
            b, loc, matA, "m", bidGridOrder, bidGridLengths, blockSize,
            kPerBlock, mPerBlock, maybeVecDimInfoA->inKPerThread,
            maybeVecDimInfoA->inDPerThread, kpack, isKContiguousDimA,
            ldsLayoutConfigA.doSwapThreadIterSubDims);
    if (failed(maybeALdsStoreViews)) {
      return failure();
    }
    ArrayAttr storeBufferAViews =
        invertTransforms(b, loc, maybeALdsStoreViews->threadSubTile);
    Value viewStoreBufferA = transform(b, storeBufferA, storeBufferAViews);
    ArrayAttr loadBufferBViews =
        invertTransforms(b, loc, maybeBBufferViews->threadSubTile);
    Value viewLoadBufferB = transform(b, loadBufferB, loadBufferBViews);
    // Prior to the LDS store, we need to re-arrange the register buffer to
    // maximize LDS vectorization. Hence, we create the view w.r.t. global
    // that corresponds to such a re-arranged register buffer.
    FailureOr<RegsAsMatrixSubTiles> maybeBLdsStoreViews =
        getPackedRegsAsTileViews(
            b, loc, matB, "n", bidGridOrder, bidGridLengths, blockSize,
            kPerBlock, nPerBlock, maybeVecDimInfoB->inKPerThread,
            maybeVecDimInfoB->inDPerThread, kpack, isKContiguousDimB,
            ldsLayoutConfigB.doSwapThreadIterSubDims);
    if (failed(maybeBLdsStoreViews)) {
      return failure();
    }
    ArrayAttr storeBufferBViews =
        invertTransforms(b, loc, maybeBLdsStoreViews->threadSubTile);
    Value viewStoreBufferB = transform(b, storeBufferB, storeBufferBViews);
    // Obtain Accelerator-related attributes.
    int64_t mPerWave = tuningParams.getMPerWave();
    int64_t nPerWave = tuningParams.getNPerWave();

    auto accelEmitterPtr = accel::AccelEmitter::select(
        features, elementTypeA, elementTypeB, arch, tuningParams);

    if (!accelEmitterPtr)
      return op.emitOpError("Unable to emit accelerator code.");

    // Extract relevant accelerator parameters
    rock::accel::AccelEmitterParams params = accelEmitterPtr->getParams();
    int64_t nResultVectors = params.nResultVectors;
    int64_t mRepeats = params.mRepeats;
    int64_t nRepeats = params.nRepeats;
    int64_t kBasePerThread = params.kBasePerThread;
    Type argTypeA = params.argTypeA;
    Type argTypeB = params.argTypeB;
    VectorType accVectorType = params.accVectorType;
    int64_t numOutputVectorElements = params.numOutputVectorElements();
    bool useIndexDiffs = true;

    LLVM_DEBUG(llvm::dbgs()
               << "M: " << M << "\n"
               << "N: " << N << "\n"
               << "K: " << K << "\n"
               << "G: " << G << "\n"
               << "mPerBlock: " << mPerBlock << "\n"
               << "nPerBlock: " << nPerBlock << "\n"
               << "kPerBlock: " << kPerBlock << "\n"
               << "kpack: " << kpack << "\n"
               << "mBlocks = M / mPerBlock: " << mBlocks << "\n"
               << "nBlocks = N / nPerBlock: " << nBlocks << "\n"
               << "mPerWave: " << mPerWave << "\n"
               << "nPerWave: " << nPerWave << "\n"
               << "aVectorLen: " << maybeVecDimInfoA->vectorLen << "\n"
               << "bVectorLen: " << maybeVecDimInfoB->vectorLen << "\n"
               << "aVectorDim: " << maybeVecDimInfoA->vectorDim << "\n"
               << "bVectorDim: " << maybeVecDimInfoB->vectorDim << "\n");

    // Allocate LDS and create subviews.

    // Compute required LDS sizes.
    int64_t ldsBlockASize =
        kpacksPerBlock * mPerBlock * kpack * getByteWidth(elementTypeA);
    int64_t ldsBlockBSize =
        kpacksPerBlock * nPerBlock * kpack * getByteWidth(elementTypeB);
    LLVM_DEBUG(llvm::dbgs() << "LDS block sizes (bytes): " << ldsBlockASize
                            << " " << ldsBlockBSize << "\n");
    if (failed(checkLDSSize(op, ldsBlockASize, ldsBlockBSize)))
      return op.emitOpError("requires too much LDS");

    // Allocate LDS.
    auto workgroupMemoryAddressSpace = b.getAttr<gpu::AddressSpaceAttr>(
        gpu::GPUDialect::getWorkgroupAddressSpace());
    auto ldsMemRefAType =
        MemRefType::get({ldsBlockASize}, b.getI8Type(), AffineMap{},
                        workgroupMemoryAddressSpace);
    auto ldsByteBufferA = GpuAllocOp::create(b, loc, ldsMemRefAType);
    auto ldsMemRefBType =
        MemRefType::get({ldsBlockBSize}, b.getI8Type(), AffineMap{},
                        workgroupMemoryAddressSpace);
    auto ldsByteBufferB = GpuAllocOp::create(b, loc, ldsMemRefBType);

    Type ldsReadTypeA = vectorTypeOrSelf(elementTypeA, kpack);
    FailureOr<Value> maybeWrappedLdsA = wrapLDSBufferForStore(
        b, loc, ldsByteBufferA, ldsReadTypeA, kpacksPerBlock, "m", mPerBlock,
        maybeVecDimInfoA->inKPerThread, maybeVecDimInfoA->inDPerThread,
        ldsLayoutConfigA.doRotateWithK);
    if (failed(maybeWrappedLdsA))
      return maybeWrappedLdsA;
    // This is KxD view of the flat LDS buffer
    Value wrappedLdsA = std::move(*maybeWrappedLdsA);
    // This will produce a (tid, iter) --> flat LDS view
    wrappedLdsA = transform(b, wrappedLdsA, maybeALdsStoreViews->blockSubTile);

    Type ldsReadTypeB = vectorTypeOrSelf(elementTypeB, kpack);
    FailureOr<Value> maybeWrappedLdsB = wrapLDSBufferForStore(
        b, loc, ldsByteBufferB, ldsReadTypeB, kpacksPerBlock, "n", nPerBlock,
        maybeVecDimInfoB->inKPerThread, maybeVecDimInfoB->inDPerThread,
        ldsLayoutConfigB.doRotateWithK);
    if (failed(maybeWrappedLdsB))
      return maybeWrappedLdsB;
    // This is KxD view of the flat LDS buffer
    Value wrappedLdsB = std::move(*maybeWrappedLdsB);
    // This will produce a (tid, iter) --> flat LDS view
    wrappedLdsB = transform(b, wrappedLdsB, maybeBLdsStoreViews->blockSubTile);

    Value ldsViewForGemmA = viewBufferAs(b, ldsByteBufferA, ldsReadTypeA);
    Value ldsViewForGemmB = viewBufferAs(b, ldsByteBufferB, ldsReadTypeB);
    int64_t nOutputVectors = nResultVectors * mRepeats * nRepeats;

    // TODO: add a heuristic to decide whether it should use scheduleV1 or V2.
    int64_t scheduleVersion = tuningParams.getScheduleVersion();
    int64_t initiationInterval;

    // Logic to setup buffers for blockwise_gemm_accel.
    int64_t arrayALen = kBasePerThread;
    int64_t arrayBLen = kBasePerThread;
    if (scheduleVersion == 2) {
      arrayALen *= mRepeats;
      arrayBLen *= nRepeats;
      initiationInterval = 1;
    } else if (scheduleVersion == 1) {
      initiationInterval = 2;
    } else {
      llvm_unreachable("unknown gemm schedule version. only "
                       "gemmScheduleVersions 1 or 2 are supported.");
    }

    auto arrayA = gpuAlloc(b, loc, arrayALen, argTypeA, AddressSpace::Private);
    auto arrayB = gpuAlloc(b, loc, arrayBLen, argTypeB, AddressSpace::Private);
    auto regCAllocOp =
        gpuAlloc(b, loc, nOutputVectors, accVectorType, AddressSpace::Private);

    Value zeroConstantCOp = createZeroConstantOp(b, loc, accVectorType);
    FillOp::create(b, loc, regCAllocOp, zeroConstantCOp);

    // Emit loop.
    Value nIterations = ConstantIndexOp::create(b, loc, K / kPerBlock);
    Value step = ConstantIndexOp::create(b, loc, 1);

    auto loopOp = scf::ForOp::create(b, loc, zeroConstantOp, nIterations, step);
    loopOp->setAttr(
        PipelineAttr::getMnemonic(),
        rock::PipelineAttr::get(b.getContext(), initiationInterval));
    {
      PatternRewriter::InsertionGuard guard(b);
      b.setInsertionPointToStart(loopOp.getBody());
      Value iv = loopOp.getInductionVar();
      auto stage0 = StageOp::create(b, loc, "GlobalRead");
      {
        PatternRewriter::InsertionGuard guard(b);
        b.setInsertionPointToStart(&stage0.getRegion().emplaceBlock());

        b.create<ThreadwiseReadIntoOp>(
            loc, vectorOfBoolShapedLike(loadBufferA), wrappedA, loadBufferA,
            /*dynamicValidities=*/ValueRange{},
            /*extraViews=*/b.getArrayAttr({}),
            /*extraIndices=*/
            ValueRange{/*kIter=*/iv, gridCoords.g_block, gridCoords.m_block,
                       gridCoords.n_block, tid},
            true, true);
        ThreadwiseReadIntoOp::create(
            b, loc, vectorOfBoolShapedLike(loadBufferB), wrappedB, loadBufferB,
            /*dynamicValidities=*/ValueRange{},
            /*extraViews=*/b.getArrayAttr({}),
            /*extraIndices=*/
            ValueRange{/*kIter=*/iv, gridCoords.g_block, gridCoords.m_block,
                       gridCoords.n_block, tid},
            true, true);
        rock::YieldOp::create(b, loc);
      }

      auto stage1 = StageOp::create(b, loc, "LDSWrite");
      {
        PatternRewriter::InsertionGuard guard(b);
        b.setInsertionPointToStart(&stage1.getRegion().emplaceBlock());

        // Emit potentially-transposing copies to store buffer. This is here
        // both to enable code motion for fusions and to prevent the accesses to
        // the memory from breaking software pipelining.
        ThreadwiseCopyOp::create(b, loc, viewLoadBufferA, ValueRange{},
                                 viewStoreBufferA, ValueRange{}, false, false);
        ThreadwiseCopyOp::create(b, loc, viewLoadBufferB, ValueRange{},
                                 viewStoreBufferB, ValueRange{}, false, false);
        // Emit blockwise stores
        ThreadwiseWriteAllOp::create(b, loc, storeBufferA, wrappedLdsA,
                                     /*extraViews=*/b.getArrayAttr({}),
                                     /*extraIndices=*/ValueRange{tid},
                                     StoreMethod::Set,
                                     /*forceUnroll=*/forceUnroll,
                                     /*useIndexDiffs=*/true);
        ThreadwiseWriteAllOp::create(b, loc, storeBufferB, wrappedLdsB,
                                     /*extraViews=*/b.getArrayAttr({}),
                                     /*extraIndices=*/ValueRange{tid},
                                     StoreMethod::Set,
                                     /*forceUnroll=*/forceUnroll,
                                     /*useIndexDiffs=*/true);
        rock::YieldOp::create(b, loc);
      }

      if (scheduleVersion == 1) {
        // Emit blockwise GEMM. This will load data from LDS and
        // compute the MMA at the same time
        auto stage2 = StageOp::create(b, loc, "MMA");
        {
          PatternRewriter::InsertionGuard guard(b);
          b.setInsertionPointToStart(&stage2.getRegion().emplaceBlock());

          BlockwiseGemmAccelOp::create(
              b, loc, ldsViewForGemmA, ldsViewForGemmB,
              b.getI32IntegerAttr(copyMPerThread),
              b.getI32IntegerAttr(copyNPerThread),
              (ldsLayoutConfigA.doRotateWithK ? b.getUnitAttr() : nullptr),
              (ldsLayoutConfigB.doRotateWithK ? b.getUnitAttr() : nullptr),
              /*loadAfromLDS=*/b.getUnitAttr(),
              /*loadBfromLDS=*/b.getUnitAttr(),
              /*splitKAcrossThreadsFirstA=*/nullptr,
              /*splitKAcrossThreadsFirstB=*/nullptr, arrayA, arrayB,
              regCAllocOp, featuresAttr, op.getBlockSizeAttr(),
              op.getParamsAttr());
          rock::YieldOp::create(b, loc);
        }
      } else {
        // If we are running double-buffered pipelines, it makes sense to also
        // parallelize the LDSRead/MMA stages. We do this here by splitting the
        // MMA loop into two separate stages.
        auto stage2 = StageOp::create(b, loc, "LDSRead");
        {
          // Read from LDS into registers
          PatternRewriter::InsertionGuard guard(b);
          b.setInsertionPointToStart(&stage2.getRegion().emplaceBlock());
          generateReadLoop(loc, b, accelEmitterPtr, tid, ldsViewForGemmA,
                           ldsViewForGemmB, arrayA, arrayB, regCAllocOp,
                           blockSize, copyMPerThread, copyNPerThread,
                           ldsLayoutConfigA.doRotateWithK,
                           ldsLayoutConfigB.doRotateWithK);
          rock::YieldOp::create(b, loc);
        }
        auto stage3 = StageOp::create(b, loc, "MMA");
        {
          // Compute the matrix-multiplication
          PatternRewriter::InsertionGuard guard(b);
          b.setInsertionPointToStart(&stage3.getRegion().emplaceBlock());
          BlockwiseGemmAccelOp::create(
              b, loc, ldsViewForGemmA, ldsViewForGemmB,
              b.getI32IntegerAttr(copyMPerThread),
              b.getI32IntegerAttr(copyNPerThread),
              (ldsLayoutConfigA.doRotateWithK ? b.getUnitAttr() : nullptr),
              (ldsLayoutConfigB.doRotateWithK ? b.getUnitAttr() : nullptr),
              /*loadAfromLDS=*/nullptr, /*loadBfromLDS=*/nullptr,
              /*splitKAcrossThreadsFirstA=*/nullptr,
              /*splitKAcrossThreadsFirstB=*/nullptr, arrayA, arrayB,
              regCAllocOp, featuresAttr, op.getBlockSizeAttr(),
              op.getParamsAttr());
          rock::YieldOp::create(b, loc);
        }
      }
    }

    // the LDS allocated to load A and B matrices won't be used anymore
    GpuDeallocOp::create(b, loc, ldsByteBufferA);
    GpuDeallocOp::create(b, loc, ldsByteBufferB);

    // Matrix C write out logic.
    Value convertedC = gpuAlloc(b, loc, numOutputVectorElements, destType,
                                AddressSpace::Private);

    FailureOr<RegsAsMatrixSubTiles> maybeIdToMatrixCMaps =
        accelEmitterPtr->computeOutputTransforms(
            b, loc, M, N, blockSize, bidGridLengths,
            maybeVecDimInfoA->inDPerThread, maybeVecDimInfoB->inDPerThread,
            ldsLayoutConfigA.doSwapThreadIterSubDims,
            ldsLayoutConfigB.doSwapThreadIterSubDims);
    if (failed(maybeIdToMatrixCMaps)) {
      return failure();
    }
    ArrayAttr idToMatrixCMaps = maybeIdToMatrixCMaps.value().gridSubTile;

    accelEmitterPtr->computeOutputConversion(b, loc, regCAllocOp, convertedC,
                                             forceUnroll);

    ThreadwiseWriteAllOp::create(
        b, loc, convertedC, op.getC(), idToMatrixCMaps,
        /*extraIndices=*/
        ValueRange{gridCoords.g_block, gridCoords.m_block, gridCoords.n_block,
                   tid},
        op.getStoreMethod(), forceUnroll, useIndexDiffs);
    b.eraseOp(op);
    return success();
  }
};

} // end anonymous namespace

void RockGridwiseGemmToBlockwisePass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  ConversionTarget target(*ctx);
  target.addIllegalOp<rock::GridwiseGemmOp, rock::GridwiseGemmAccelOp,
                      GridwiseAttentionAccelOp>();
  target.addLegalDialect<arith::ArithDialect, rock::RockDialect,
                         memref::MemRefDialect, affine::AffineDialect,
                         vector::VectorDialect, linalg::LinalgDialect,
                         scf::SCFDialect, math::MathDialect>();
  target.addLegalOp<gpu::PrintfOp>();

  RewritePatternSet patterns(ctx);
  patterns.add<GridwiseGemmRewritePattern, GridwiseGemmAccelRewritePattern,
               GridwiseAttentionAccelRewritePattern>(ctx);
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns)))) {
    signalPassFailure();
  }
}
