// blob: dd204126be5dbce078baadb0d450b2895e1333bc [file] [log] [blame]
//===-- CUFAddConstructor.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/Builder/BoxValue.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/CodeGen/Target.h"
#include "flang/Optimizer/CodeGen/TypeConverter.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/DataLayout.h"
#include "flang/Optimizer/Transforms/CUFCommon.h"
#include "flang/Runtime/CUDA/registration.h"
#include "flang/Runtime/entry-names.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/ADT/SmallVector.h"
namespace fir {
#define GEN_PASS_DEF_CUFADDCONSTRUCTOR
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir
using namespace Fortran::runtime::cuda;
namespace {
/// Symbol name given to the CUDA Fortran constructor function.
static constexpr llvm::StringRef cudaFortranCtorName =
    "__cudaFortranConstructor";
/// Module pass that creates the `__cudaFortranConstructor` function and lists
/// it in an `llvm.global_ctors` operation.
///
/// The generated constructor:
///   - calls the CUFRegisterAllocator runtime entry point;
///   - if the module contains a GPU module named `cudaDeviceModuleName`,
///     registers that module, each kernel it contains, and every `fir.global`
///     whose CUDA data attribute is `device` or `constant`.
///
/// The pass requires a data layout on the module (it is needed to compute the
/// size of the registered globals) and fails if none is available.
struct CUFAddConstructor
    : public fir::impl::CUFAddConstructorBase<CUFAddConstructor> {
  /// Build the constructor function and the associated global_ctors entry.
  void runOnOperation() override {
    mlir::ModuleOp mod = getOperation();
    mlir::Location loc = mod.getLoc();
    auto *ctx = mod.getContext();

    // Check the data layout before touching the module: it is dereferenced
    // further down, so carrying on without it would be undefined behavior.
    std::optional<mlir::DataLayout> dl =
        fir::support::getOrSetDataLayout(mod, /*allowDefaultLayout=*/false);
    if (!dl) {
      mlir::emitError(loc, "data layout attribute is required to perform " +
                               getName() + " pass");
      signalPassFailure();
      return;
    }

    mlir::SymbolTable symTab(mod);
    mlir::OpBuilder opBuilder{mod.getBodyRegion()};
    fir::FirOpBuilder builder(opBuilder, mod);
    fir::KindMapping kindMap{fir::getKindMapping(mod)};

    auto voidTy = mlir::LLVM::LLVMVoidType::get(ctx);
    auto idxTy = builder.getIndexType();
    auto funcTy =
        mlir::LLVM::LLVMFunctionType::get(voidTy, {}, /*isVarArg=*/false);

    // Private declaration of, and symbol reference to, CUFRegisterAllocator.
    builder.setInsertionPointToEnd(mod.getBody());
    auto registerFuncOp = builder.create<mlir::LLVM::LLVMFuncOp>(
        loc, RTNAME_STRING(CUFRegisterAllocator), funcTy);
    registerFuncOp.setVisibility(mlir::SymbolTable::Visibility::Private);
    auto cufRegisterAllocatorRef =
        mlir::SymbolRefAttr::get(ctx, RTNAME_STRING(CUFRegisterAllocator));

    // Create the constructor function that calls CUFRegisterAllocator.
    auto ctorFunc = builder.create<mlir::LLVM::LLVMFuncOp>(
        loc, cudaFortranCtorName, funcTy);
    ctorFunc.setLinkage(mlir::LLVM::Linkage::Internal);
    builder.setInsertionPointToStart(ctorFunc.addEntryBlock(builder));
    builder.create<mlir::LLVM::CallOp>(loc, funcTy, cufRegisterAllocatorRef);

    auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName);
    if (gpuMod) {
      auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(ctx);
      auto registeredMod = builder.create<cuf::RegisterModuleOp>(
          loc, llvmPtrTy, mlir::SymbolRefAttr::get(ctx, gpuMod.getName()));

      fir::LLVMTypeConverter typeConverter(mod, /*applyTBAA=*/false,
                                           /*forceUnifiedTBAATree=*/false, *dl);

      // Register kernels: each one is referenced as
      // @<cudaDeviceModuleName>::@<kernel name>.
      for (auto gpuFunc : gpuMod.getOps<mlir::gpu::GPUFuncOp>()) {
        if (!gpuFunc.isKernel())
          continue;
        auto kernelName = mlir::SymbolRefAttr::get(
            builder.getStringAttr(cudaDeviceModuleName),
            {mlir::SymbolRefAttr::get(builder.getContext(),
                                      gpuFunc.getName())});
        builder.create<cuf::RegisterKernelOp>(loc, kernelName, registeredMod);
      }

      // Register variables: only globals carrying a CUDA data attribute are
      // considered.
      for (fir::GlobalOp globalOp : mod.getOps<fir::GlobalOp>()) {
        auto attr = globalOp.getDataAttrAttr();
        if (!attr)
          continue;

        switch (attr.getValue()) {
        case cuf::DataAttribute::Device:
        case cuf::DataAttribute::Constant: {
          mlir::func::FuncOp registerVarFunc =
              fir::runtime::getRuntimeFunc<mkRTKey(CUFRegisterVariable)>(
                  loc, builder);
          auto fTy = registerVarFunc.getFunctionType();

          // Global variable name, with an explicit NUL terminator appended.
          std::string gblNameStr = globalOp.getSymbol().getValue().str();
          gblNameStr += '\0';
          mlir::Value gblName = fir::getBase(
              fir::factory::createStringLiteral(builder, loc, gblNameStr));

          // Global variable size in bytes. Box types are sized through their
          // converted struct type; any other type goes through the generic
          // FIR size computation.
          std::optional<uint64_t> size;
          if (auto boxTy =
                  mlir::dyn_cast<fir::BaseBoxType>(globalOp.getType())) {
            mlir::Type structTy = typeConverter.convertBoxTypeAsStruct(boxTy);
            size = dl->getTypeSizeInBits(structTy) / 8;
          }
          if (!size) {
            size = fir::getTypeSizeAndAlignmentOrCrash(loc, globalOp.getType(),
                                                       *dl, kindMap)
                       .first;
          }
          auto sizeVal = builder.createIntegerConstant(loc, idxTy, *size);

          // Global variable address.
          mlir::Value addr = builder.create<fir::AddrOfOp>(
              loc, globalOp.resultType(), globalOp.getSymbol());

          llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
              builder, loc, fTy, registeredMod, addr, gblName, sizeVal)};
          builder.create<fir::CallOp>(loc, registerVarFunc, args);
        } break;
        case cuf::DataAttribute::Managed:
          TODO(loc, "registration of managed variables");
        default:
          // Other data attributes need no registration here.
          break;
        }
      }
    }
    builder.create<mlir::LLVM::ReturnOp>(loc, mlir::ValueRange{});

    // Create the llvm.global_ctor with the function.
    // TODO: We might want to have a utility that retrieves it if already
    // created and adds new functions.
    builder.setInsertionPointToEnd(mod.getBody());
    llvm::SmallVector<mlir::Attribute> funcs;
    funcs.push_back(mlir::FlatSymbolRefAttr::get(ctx, ctorFunc.getSymName()));
    llvm::SmallVector<int> priorities;
    priorities.push_back(0);
    builder.create<mlir::LLVM::GlobalCtorsOp>(
        loc, builder.getArrayAttr(funcs), builder.getI32ArrayAttr(priorities));
  }
};
} // end anonymous namespace