//===-- CUFOpConversion.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Transforms/CUFOpConversion.h"
#include "flang/Common/Fortran.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Runtime/CUDA/allocatable.h"
#include "flang/Runtime/CUDA/common.h"
#include "flang/Runtime/CUDA/descriptor.h"
#include "flang/Runtime/CUDA/memory.h"
#include "flang/Runtime/allocatable.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace fir {
#define GEN_PASS_DEF_CUFOPCONVERSION
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
using namespace fir;
using namespace mlir;
using namespace Fortran::runtime;
using namespace Fortran::runtime::cuda;
namespace {
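// Map a CUF data attribute to the corresponding runtime memory-type constant.
// Any attribute without a matching memory type is a fatal error.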
static inline unsigned getMemType(cuf::DataAttribute attr) {
if (attr == cuf::DataAttribute::Device)
return kMemTypeDevice;
if (attr == cuf::DataAttribute::Managed)
return kMemTypeManaged;
if (attr == cuf::DataAttribute::Unified)
return kMemTypeUnified;
if (attr == cuf::DataAttribute::Pinned)
return kMemTypePinned;
llvm::report_fatal_error("unsupported memory type");
}
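// Return true if the operation carries the pinned CUF data attribute.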
template <typename OpTy>
static bool isPinned(OpTy op) {
if (op.getDataAttr() && *op.getDataAttr() == cuf::DataAttribute::Pinned)
return true;
return false;
}
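// Module variables have both a host and a device descriptor that must be kept
// in sync. Detect this case by checking whether the box operand is declared
// over a fir.address_of of a global; pinned variables are excluded.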
template <typename OpTy>
static bool hasDoubleDescriptors(OpTy op) {
if (auto declareOp =
mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
if (mlir::isa_and_nonnull<fir::AddrOfOp>(
declareOp.getMemref().getDefiningOp())) {
if (isPinned(declareOp))
return false;
return true;
}
} else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
op.getBox().getDefiningOp())) {
if (mlir::isa_and_nonnull<fir::AddrOfOp>(
declareOp.getMemref().getDefiningOp())) {
if (isPinned(declareOp))
return false;
return true;
}
}
return false;
}
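// Insert a fir.convert to `toTy` unless `val` already has that type.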
static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Type toTy,
mlir::Value val) {
if (val.getType() != toTy)
return rewriter.create<fir::ConvertOp>(loc, toTy, val);
return val;
}
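// When the operand is declared over a device, managed, or constant global,
// emit a call to the CUFGetDeviceAddress runtime entry point to retrieve the
// corresponding device address. Otherwise return the operand unchanged.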
mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter,
mlir::OpOperand &operand,
const mlir::SymbolTable &symtab) {
mlir::Value v = operand.get();
auto declareOp = v.getDefiningOp<fir::DeclareOp>();
if (!declareOp)
return v;
auto addrOfOp = declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
if (!addrOfOp)
return v;
auto globalOp = symtab.lookup<fir::GlobalOp>(
addrOfOp.getSymbol().getRootReference().getValue());
if (!globalOp)
return v;
bool isDevGlobal{false};
auto attr = globalOp.getDataAttrAttr();
if (attr) {
switch (attr.getValue()) {
case cuf::DataAttribute::Device:
case cuf::DataAttribute::Managed:
case cuf::DataAttribute::Constant:
isDevGlobal = true;
break;
default:
break;
}
}
if (!isDevGlobal)
return v;
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(operand.getOwner());
auto loc = declareOp.getLoc();
auto mod = declareOp->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::func::FuncOp callee =
fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
auto fTy = callee.getFunctionType();
auto toTy = fTy.getInput(0);
mlir::Value inputArg =
createConvertOp(rewriter, loc, toTy, declareOp.getResult());
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, inputArg, sourceFile, sourceLine)};
auto call = rewriter.create<fir::CallOp>(loc, callee, args);
return call->getResult(0);
}
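// Rewrite a cuf.allocate or cuf.deallocate operation into a call to the given
// runtime function, materializing the stat, errmsg, and source location
// arguments it expects.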
template <typename OpTy>
static mlir::LogicalResult convertOpToCall(OpTy op,
mlir::PatternRewriter &rewriter,
mlir::func::FuncOp func) {
auto mod = op->template getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
mlir::Value hasStat = op.getHasStat() ? builder.createBool(loc, true)
: builder.createBool(loc, false);
mlir::Value errmsg;
if (op.getErrmsg()) {
errmsg = op.getErrmsg();
} else {
mlir::Type boxNoneTy = fir::BoxType::get(builder.getNoneType());
errmsg = builder.create<fir::AbsentOp>(loc, boxNoneTy).getResult();
}
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, op.getBox(), hasStat, errmsg, sourceFile, sourceLine)};
auto callOp = builder.create<fir::CallOp>(loc, func, args);
rewriter.replaceOp(op, callOp);
return mlir::success();
}
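// Convert cuf.allocate into a runtime call: allocations seen through double
// descriptors (module variables) go through CUFAllocatableAllocate so the host
// and device descriptors stay synchronized, while local descriptors use the
// standard AllocatableAllocate entry point.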
struct CUFAllocateOpConversion
: public mlir::OpRewritePattern<cuf::AllocateOp> {
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult
matchAndRewrite(cuf::AllocateOp op,
mlir::PatternRewriter &rewriter) const override {
// TODO: Allocation with source will need a new entry point in the runtime.
if (op.getSource())
return mlir::failure();
// TODO: Allocation using a different stream.
if (op.getStream())
return mlir::failure();
// TODO: Pinned is a reference to a logical value that can be set to true
// when the pinned allocation succeeds. This will require a new entry point.
if (op.getPinned())
return mlir::failure();
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
if (hasDoubleDescriptors(op)) {
// Allocations of module variables are done with a custom runtime entry point
// so the descriptors can be synchronized.
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableAllocate)>(
loc, builder);
return convertOpToCall(op, rewriter, func);
}
// Allocation of a local descriptor falls back on the standard runtime
// AllocatableAllocate since the dedicated allocator is already set in the
// descriptor before the call.
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(AllocatableAllocate)>(loc,
builder);
return convertOpToCall<cuf::AllocateOp>(op, rewriter, func);
}
};
struct CUFDeallocateOpConversion
: public mlir::OpRewritePattern<cuf::DeallocateOp> {
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult
matchAndRewrite(cuf::DeallocateOp op,
mlir::PatternRewriter &rewriter) const override {
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
if (hasDoubleDescriptors(op)) {
// Deallocations of module variables are done with a custom runtime entry
// point so the descriptors can be synchronized.
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocatableDeallocate)>(
loc, builder);
return convertOpToCall(op, rewriter, func);
}
// Deallocation of a local descriptor falls back on the standard runtime
// AllocatableDeallocate since the dedicated deallocator is already set in the
// descriptor before the call.
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(AllocatableDeallocate)>(loc,
builder);
return convertOpToCall<cuf::DeallocateOp>(op, rewriter, func);
}
};
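// Return true when the operation executes in a device context: either nested
// inside a cuf.kernel region or inside a function whose CUF proc attribute is
// neither Host nor HostDevice.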
static bool inDeviceContext(mlir::Operation *op) {
if (op->getParentOfType<cuf::KernelOp>())
return true;
if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) {
if (auto cudaProcAttr =
funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
cuf::getProcAttrName())) {
return cudaProcAttr.getValue() != cuf::ProcAttribute::Host &&
cudaProcAttr.getValue() != cuf::ProcAttribute::HostDevice;
}
}
return false;
}
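// Return the element width in bytes of a scalar or sequence type. Only
// integer, floating-point, logical, and complex element types are supported;
// anything else is a fatal error.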
static int computeWidth(mlir::Location loc, mlir::Type type,
fir::KindMapping &kindMap) {
auto eleTy = fir::unwrapSequenceType(type);
int width = 0;
if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
width = t.getWidth() / 8;
} else if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)}) {
width = t.getWidth() / 8;
} else if (eleTy.isInteger(1)) {
width = 1;
} else if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)}) {
int kind = t.getFKind();
width = kindMap.getLogicalBitsize(kind) / 8;
} else if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
int elemSize =
mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
width = 2 * elemSize;
} else {
llvm::report_fatal_error("unsupported type");
}
return width;
}
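// Convert cuf.alloc. In device context the allocation becomes a plain
// fir.alloca. On the host, non-descriptor allocations become CUFMemAlloc calls
// with a byte count computed from the element width and the (possibly dynamic)
// shape, while descriptor allocations become CUFAllocDesciptor calls sized
// from the converted box type.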
struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
using OpRewritePattern::OpRewritePattern;
CUFAllocOpConversion(mlir::MLIRContext *context, mlir::DataLayout *dl,
const fir::LLVMTypeConverter *typeConverter)
: OpRewritePattern(context), dl{dl}, typeConverter{typeConverter} {}
mlir::LogicalResult
matchAndRewrite(cuf::AllocOp op,
mlir::PatternRewriter &rewriter) const override {
if (inDeviceContext(op.getOperation())) {
// In device context, simply replace the cuf.alloc operation with a
// fir.alloca; the matching cuf.free will be removed.
rewriter.replaceOpWithNewOp<fir::AllocaOp>(
op, op.getInType(), op.getUniqName() ? *op.getUniqName() : "",
op.getBindcName() ? *op.getBindcName() : "", op.getTypeparams(),
op.getShape());
return mlir::success();
}
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
if (!mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType())) {
// Convert scalar and known-size array allocations.
mlir::Value bytes;
fir::KindMapping kindMap{fir::getKindMapping(mod)};
if (fir::isa_trivial(op.getInType())) {
int width = computeWidth(loc, op.getInType(), kindMap);
bytes =
builder.createIntegerConstant(loc, builder.getIndexType(), width);
} else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
op.getInType())) {
mlir::Value width = builder.createIntegerConstant(
loc, builder.getIndexType(),
computeWidth(loc, seqTy.getEleTy(), kindMap));
mlir::Value nbElem;
if (fir::sequenceWithNonConstantShape(seqTy)) {
assert(!op.getShape().empty() && "expect shape with dynamic arrays");
nbElem = builder.loadIfRef(loc, op.getShape()[0]);
for (unsigned i = 1; i < op.getShape().size(); ++i) {
nbElem = rewriter.create<mlir::arith::MulIOp>(
loc, nbElem, builder.loadIfRef(loc, op.getShape()[i]));
}
} else {
nbElem = builder.createIntegerConstant(loc, builder.getIndexType(),
seqTy.getConstantArraySize());
}
bytes = rewriter.create<mlir::arith::MulIOp>(loc, nbElem, width);
}
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFMemAlloc)>(loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
mlir::Value memTy = builder.createIntegerConstant(
loc, builder.getI32Type(), getMemType(op.getDataAttr()));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, bytes, memTy, sourceFile, sourceLine)};
auto callOp = builder.create<fir::CallOp>(loc, func, args);
auto convOp = builder.createConvert(loc, op.getResult().getType(),
callOp.getResult(0));
rewriter.replaceOp(op, convOp);
return mlir::success();
}
// Convert descriptor allocations to a runtime call.
auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType());
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDesciptor)>(loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
mlir::Type structTy = typeConverter->convertBoxTypeAsStruct(boxTy);
std::size_t boxSize = dl->getTypeSizeInBits(structTy) / 8;
mlir::Value sizeInBytes =
builder.createIntegerConstant(loc, builder.getIndexType(), boxSize);
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, sizeInBytes, sourceFile, sourceLine)};
auto callOp = builder.create<fir::CallOp>(loc, func, args);
auto convOp = builder.createConvert(loc, op.getResult().getType(),
callOp.getResult(0));
rewriter.replaceOp(op, convOp);
return mlir::success();
}
private:
mlir::DataLayout *dl;
const fir::LLVMTypeConverter *typeConverter;
};
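// Convert cuf.free. In device context the operation is simply erased (its
// allocation was lowered to a fir.alloca). On the host, frees of plain data
// call CUFMemFree and frees of descriptors call CUFFreeDesciptor.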
struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult
matchAndRewrite(cuf::FreeOp op,
mlir::PatternRewriter &rewriter) const override {
if (inDeviceContext(op.getOperation())) {
rewriter.eraseOp(op);
return mlir::success();
}
if (!mlir::isa<fir::ReferenceType>(op.getDevptr().getType()))
return failure();
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getDevptr().getType());
if (!mlir::isa<fir::BaseBoxType>(refTy.getEleTy())) {
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFMemFree)>(loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
mlir::Value memTy = builder.createIntegerConstant(
loc, builder.getI32Type(), getMemType(op.getDataAttr()));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, op.getDevptr(), memTy, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
return mlir::success();
}
// Convert cuf.free on descriptors.
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDesciptor)>(loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, op.getDevptr(), sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
return mlir::success();
}
};
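// Convert cuf.data_transfer into the appropriate runtime call depending on
// whether the source and destination are raw pointers or descriptors:
// CUFDataTransferPtrPtr, CUFDataTransferDescDesc, CUFMemsetDescriptor for a
// trivial scalar assigned to a descriptor, or CUFDataTransferDescPtr /
// CUFDataTransferPtrDesc for the mixed cases. The mode argument encodes the
// transfer direction (host-to-device, device-to-host, or device-to-device).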
struct CUFDataTransferOpConversion
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
using OpRewritePattern::OpRewritePattern;
CUFDataTransferOpConversion(mlir::MLIRContext *context,
const mlir::SymbolTable &symtab)
: OpRewritePattern(context), symtab{symtab} {}
mlir::LogicalResult
matchAndRewrite(cuf::DataTransferOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::Type srcTy = fir::unwrapRefType(op.getSrc().getType());
mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
mlir::Location loc = op.getLoc();
unsigned mode = 0;
if (op.getTransferKind() == cuf::DataTransferKind::HostDevice) {
mode = kHostToDevice;
} else if (op.getTransferKind() == cuf::DataTransferKind::DeviceHost) {
mode = kDeviceToHost;
} else if (op.getTransferKind() == cuf::DataTransferKind::DeviceDevice) {
mode = kDeviceToDevice;
} else {
mlir::emitError(loc, "unsupported transfer kind\n");
}
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
fir::KindMapping kindMap{fir::getKindMapping(mod)};
mlir::Value modeValue =
builder.createIntegerConstant(loc, builder.getI32Type(), mode);
// Convert data transfers that do not involve any descriptor.
if (!mlir::isa<fir::BaseBoxType>(srcTy) &&
!mlir::isa<fir::BaseBoxType>(dstTy)) {
if (fir::isa_trivial(srcTy) && !fir::isa_trivial(dstTy)) {
// TODO: scalar to array data transfer.
mlir::emitError(loc,
"not yet implemented: scalar to array data transfer\n");
return mlir::failure();
}
mlir::Type i64Ty = builder.getI64Type();
mlir::Value nbElement;
if (op.getShape()) {
auto shapeOp =
mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp());
nbElement = rewriter.create<fir::ConvertOp>(loc, i64Ty,
shapeOp.getExtents()[0]);
for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) {
auto operand = rewriter.create<fir::ConvertOp>(
loc, i64Ty, shapeOp.getExtents()[i]);
nbElement =
rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
}
} else {
if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(dstTy))
nbElement = builder.createIntegerConstant(
loc, i64Ty, seqTy.getConstantArraySize());
}
int width = computeWidth(loc, dstTy, kindMap);
mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
mlir::Value bytes =
nbElement
? rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue)
: widthValue;
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrPtr)>(loc,
builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
mlir::Value dst = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
mlir::Value src = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
llvm::SmallVector<mlir::Value> args{
fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
modeValue, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
return mlir::success();
}
// Conversion of data transfer involving at least one descriptor.
if (mlir::isa<fir::BaseBoxType>(srcTy) &&
mlir::isa<fir::BaseBoxType>(dstTy)) {
// Transfer between two descriptors.
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescDesc)>(
loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
mlir::Value dst = op.getDst();
mlir::Value src = op.getSrc();
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, dst, src, modeValue, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
} else if (mlir::isa<fir::BaseBoxType>(dstTy) && fir::isa_trivial(srcTy)) {
// Scalar to descriptor transfer.
mlir::Value val = op.getSrc();
if (op.getSrc().getDefiningOp() &&
mlir::isa<mlir::arith::ConstantOp>(op.getSrc().getDefiningOp())) {
mlir::Value alloc = builder.createTemporary(loc, srcTy);
builder.create<fir::StoreOp>(loc, op.getSrc(), alloc);
val = alloc;
}
mlir::func::FuncOp func =
fir::runtime::getRuntimeFunc<mkRTKey(CUFMemsetDescriptor)>(loc,
builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(3));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, op.getDst(), val, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
} else {
// Type used to compute the width.
mlir::Type computeType = dstTy;
auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
bool dstIsDesc = false;
if (mlir::isa<fir::BaseBoxType>(dstTy)) {
dstIsDesc = true;
computeType = srcTy;
seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
}
int width = computeWidth(loc, computeType, kindMap);
mlir::Value nbElement;
mlir::Type idxTy = rewriter.getIndexType();
if (!op.getShape()) {
nbElement = rewriter.create<mlir::arith::ConstantOp>(
loc, idxTy,
rewriter.getIntegerAttr(idxTy, seqTy.getConstantArraySize()));
} else {
auto shapeOp =
mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp());
nbElement =
createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[0]);
for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) {
auto operand =
createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[i]);
nbElement =
rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
}
}
mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
loc, idxTy, rewriter.getIntegerAttr(idxTy, width));
mlir::Value bytes =
rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
mlir::func::FuncOp func =
dstIsDesc
? fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferDescPtr)>(
loc, builder)
: fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
mlir::Value dst = op.getDst();
mlir::Value src = op.getSrc();
llvm::SmallVector<mlir::Value> args{
fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
modeValue, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
}
return mlir::success();
}
private:
const mlir::SymbolTable &symtab;
};
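// Convert cuf.kernel_launch into gpu.launch_func targeting the kernel in the
// CUDA device module, forwarding the grid and block sizes and, when the callee
// carries a CUF cluster-dims attribute, the cluster dimensions.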
struct CUFLaunchOpConversion
: public mlir::OpRewritePattern<cuf::KernelLaunchOp> {
public:
using OpRewritePattern::OpRewritePattern;
CUFLaunchOpConversion(mlir::MLIRContext *context,
const mlir::SymbolTable &symTab)
: OpRewritePattern(context), symTab{symTab} {}
mlir::LogicalResult
matchAndRewrite(cuf::KernelLaunchOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = op.getLoc();
auto idxTy = mlir::IndexType::get(op.getContext());
auto zero = rewriter.create<mlir::arith::ConstantOp>(
loc, rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(0));
auto gridSizeX =
rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridX());
auto gridSizeY =
rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridY());
auto gridSizeZ =
rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridZ());
auto blockSizeX =
rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockX());
auto blockSizeY =
rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockY());
auto blockSizeZ =
rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockZ());
auto kernelName = mlir::SymbolRefAttr::get(
rewriter.getStringAttr(cudaDeviceModuleName),
{mlir::SymbolRefAttr::get(
rewriter.getContext(),
op.getCallee().getLeafReference().getValue())});
mlir::Value clusterDimX, clusterDimY, clusterDimZ;
if (auto funcOp = symTab.lookup<mlir::func::FuncOp>(
op.getCallee().getLeafReference())) {
if (auto clusterDimsAttr = funcOp->getAttrOfType<cuf::ClusterDimsAttr>(
cuf::getClusterDimsAttrName())) {
clusterDimX = rewriter.create<mlir::arith::ConstantIndexOp>(
loc, clusterDimsAttr.getX().getInt());
clusterDimY = rewriter.create<mlir::arith::ConstantIndexOp>(
loc, clusterDimsAttr.getY().getInt());
clusterDimZ = rewriter.create<mlir::arith::ConstantIndexOp>(
loc, clusterDimsAttr.getZ().getInt());
}
}
auto gpuLaunchOp = rewriter.create<mlir::gpu::LaunchFuncOp>(
loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero,
op.getArgs());
if (clusterDimX && clusterDimY && clusterDimZ) {
gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX);
gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
gpuLaunchOp.getClusterSizeZMutable().assign(clusterDimZ);
}
rewriter.replaceOp(op, gpuLaunchOp);
return mlir::success();
}
private:
const mlir::SymbolTable &symTab;
};
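// Pass driver: collect the CUF-to-FIR conversion patterns and apply them as a
// partial conversion over the module, with the FIR, arith, and gpu dialects
// marked legal.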
class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
public:
void runOnOperation() override {
auto *ctx = &getContext();
mlir::RewritePatternSet patterns(ctx);
mlir::ConversionTarget target(*ctx);
mlir::Operation *op = getOperation();
mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
if (!module)
return signalPassFailure();
mlir::SymbolTable symtab(module);
std::optional<mlir::DataLayout> dl =
fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
/*forceUnifiedTBAATree=*/false, *dl);
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
mlir::gpu::GPUDialect>();
cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
patterns);
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
mlir::emitError(mlir::UnknownLoc::get(ctx),
"error in CUF op conversion\n");
signalPassFailure();
}
}
};
} // namespace
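// Populate the rewrite patterns that convert CUF operations into FIR
// operations and runtime calls.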
void cuf::populateCUFToFIRConversionPatterns(
const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter);
patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion,
CUFFreeOpConversion>(patterns.getContext());
patterns.insert<CUFDataTransferOpConversion, CUFLaunchOpConversion>(
patterns.getContext(), symtab);
}