blob: a9503857090ab8e179f78f8238dfa7b6b6aa8b21 [file]
//===- IndexedAccessOpInterfaceImpl.cpp -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Implement IndexedAccessOpInterface on GPU dialect operations that have
// %memref[%i0, %i1, ...] arguments to allow them to be manipulated by
// generic memref-dialect passes.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/Transforms/IndexedAccessOpInterfaceImpl.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::memref;
using namespace mlir::gpu;
/// Given a GPU matrix type that will be loaded or stored, the leading dimension
/// of the matrix in memory, and whether or not the matrix is transposed,
/// compute the number of elements of linear memory that the load/store spans.
///
/// Let dR and dC be the non-contiguous and contiguous matrix dimensions,
/// respectively. Without `transpose`, dR is shape[0] and dC is shape[1]; with
/// `transpose`, the two roles are exchanged. Consecutive non-contiguous slices
/// start `leadingDim` elements apart. The access therefore reaches the start
/// of slice dR - 1 at offset leadingDim * (dR - 1), and then touches the first
/// dC elements of that slice. The total span is dC + leadingDim * (dR - 1).
static int64_t get1DAccessSize(MMAMatrixType matrixType, int64_t leadingDim,
                               bool transpose) {
  ArrayRef<int64_t> shape = matrixType.getShape();
  assert(shape.size() == 2 && "expected matrices to be 2D");
  int64_t nonContiguous = shape[0]; // dR
  int64_t contiguous = shape[1];    // dC
  // A transposed access exchanges which dimension is contiguous in memory.
  if (transpose)
    std::swap(contiguous, nonContiguous);
  return contiguous + leadingDim * (nonContiguous - 1);
}
namespace {
/// External model exposing `gpu.subgroup_mma_load_matrix` through
/// IndexedAccessOpInterface: the accessed memref is the load's source operand.
struct SubgroupMmaLoadMatrixOpImpl final
    : IndexedAccessOpInterface::ExternalModel<SubgroupMmaLoadMatrixOpImpl,
                                              SubgroupMmaLoadMatrixOp> {
  /// Returns the source memref operand of the load.
  TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
    auto mmaLoad = cast<SubgroupMmaLoadMatrixOp>(op);
    return mmaLoad.getSrcMemref();
  }

  /// Returns the index operands of the load.
  Operation::operand_range getIndices(Operation *op) const {
    auto mmaLoad = cast<SubgroupMmaLoadMatrixOp>(op);
    return mmaLoad.getIndices();
  }

  /// Reports the access as a 1-D shape (its linear span in memory). This makes
  /// it clear that both linearization and folding in expand/collapse_shape
  /// operations are allowed.
  SmallVector<int64_t> getAccessedShape(Operation *op) const {
    auto mmaLoad = cast<SubgroupMmaLoadMatrixOp>(op);
    auto resultType = cast<MMAMatrixType>(mmaLoad.getRes().getType());
    int64_t leadDim = mmaLoad.getLeadDimension().getZExtValue();
    // An absent transpose attribute means the access is not transposed.
    bool isTransposed = mmaLoad.getTranspose().value_or(false);
    int64_t span = get1DAccessSize(resultType, leadDim, isTransposed);
    return {span};
  }

  /// Updates the load in place so that it reads from `newMemref` at
  /// `newIndices`. It always returns std::nullopt. Presumably this tells the
  /// interface that no replacement values are needed; confirm against the
  /// interface definition.
  std::optional<SmallVector<Value>>
  updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
                         ValueRange newIndices) const {
    auto mmaLoad = cast<SubgroupMmaLoadMatrixOp>(op);
    auto retarget = [&]() {
      mmaLoad.getSrcMemrefMutable().assign(newMemref);
      mmaLoad.getIndicesMutable().assign(newIndices);
    };
    rewriter.modifyOpInPlace(mmaLoad, retarget);
    return std::nullopt;
  }

  /// This op's indices are unconditionally reported as in bounds.
  bool hasInboundsIndices(Operation * /*op*/) const { return true; }
};
/// External model exposing `gpu.subgroup_mma_store_matrix` through
/// IndexedAccessOpInterface: the accessed memref is the store's destination
/// operand.
struct SubgroupMmaStoreMatrixOpImpl final
    : IndexedAccessOpInterface::ExternalModel<SubgroupMmaStoreMatrixOpImpl,
                                              SubgroupMmaStoreMatrixOp> {
  /// Returns the destination memref operand of the store.
  TypedValue<MemRefType> getAccessedMemref(Operation *op) const {
    auto mmaStore = cast<SubgroupMmaStoreMatrixOp>(op);
    return mmaStore.getDstMemref();
  }

  /// Returns the index operands of the store.
  Operation::operand_range getIndices(Operation *op) const {
    auto mmaStore = cast<SubgroupMmaStoreMatrixOp>(op);
    return mmaStore.getIndices();
  }

  /// Reports the access as a 1-D shape (its linear span in memory). This makes
  /// it clear that both linearization and folding in expand/collapse_shape
  /// operations are allowed.
  SmallVector<int64_t> getAccessedShape(Operation *op) const {
    auto mmaStore = cast<SubgroupMmaStoreMatrixOp>(op);
    auto storedType = mmaStore.getSrc().getType();
    int64_t leadDim = mmaStore.getLeadDimension().getZExtValue();
    // An absent transpose attribute means the access is not transposed.
    bool isTransposed = mmaStore.getTranspose().value_or(false);
    int64_t span = get1DAccessSize(storedType, leadDim, isTransposed);
    return {span};
  }

  /// Updates the store in place so that it writes to `newMemref` at
  /// `newIndices`. It always returns std::nullopt. Presumably this tells the
  /// interface that no replacement values are needed; confirm against the
  /// interface definition.
  std::optional<SmallVector<Value>>
  updateMemrefAndIndices(Operation *op, RewriterBase &rewriter, Value newMemref,
                         ValueRange newIndices) const {
    auto mmaStore = cast<SubgroupMmaStoreMatrixOp>(op);
    auto retarget = [&]() {
      mmaStore.getDstMemrefMutable().assign(newMemref);
      mmaStore.getIndicesMutable().assign(newIndices);
    };
    rewriter.modifyOpInPlace(mmaStore, retarget);
    return std::nullopt;
  }

  /// This op's indices are unconditionally reported as in bounds.
  bool hasInboundsIndices(Operation * /*op*/) const { return true; }
};
} // namespace
/// Registers a GPU-dialect extension on `registry` that attaches the
/// IndexedAccessOpInterface external models to the subgroup MMA load and
/// store ops.
void mlir::gpu::registerIndexedAccessOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect * /*dialect*/) {
    MLIRContext &context = *ctx;
    SubgroupMmaLoadMatrixOp::attachInterface<SubgroupMmaLoadMatrixOpImpl>(
        context);
    SubgroupMmaStoreMatrixOp::attachInterface<SubgroupMmaStoreMatrixOpImpl>(
        context);
  });
}