// Source blob: 16618284abaca96eb33ac6cec4f1fcaf0c8ec0f2
//===-- BoxedProcedure.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/LowLevelIntrinsics.h"
#include "flang/Optimizer/Builder/Runtime/Trampoline.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
#include "flang/Optimizer/Support/FatalError.h"
#include "flang/Optimizer/Support/InternalNames.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
namespace fir {
#define GEN_PASS_DEF_BOXEDPROCEDUREPASS
#include "flang/Optimizer/CodeGen/CGPasses.h.inc"
} // namespace fir
#define DEBUG_TYPE "flang-procedure-pointer"
using namespace fir;
namespace {
/// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types.
class BoxprocTypeRewriter : public mlir::TypeConverter {
public:
using mlir::TypeConverter::convertType;
/// Does the type \p ty need to be converted?
/// Any type that is a `!fir.boxproc` in whole or in part will need to be
/// converted to a function type to lower the IR to function pointer form in
/// the default implementation performed in this pass. Other implementations
/// are possible, so those may convert `!fir.boxproc` to some other type or
/// not at all depending on the implementation target's characteristics and
/// preference.
bool needsConversion(mlir::Type ty) {
if (mlir::isa<BoxProcType>(ty))
return true;
if (auto funcTy = mlir::dyn_cast<mlir::FunctionType>(ty)) {
for (auto t : funcTy.getInputs())
if (needsConversion(t))
return true;
for (auto t : funcTy.getResults())
if (needsConversion(t))
return true;
return false;
}
if (auto tupleTy = mlir::dyn_cast<mlir::TupleType>(ty)) {
for (auto t : tupleTy.getTypes())
if (needsConversion(t))
return true;
return false;
}
if (auto recTy = mlir::dyn_cast<RecordType>(ty)) {
auto [visited, inserted] = visitedTypes.try_emplace(ty, false);
if (!inserted)
return visited->second;
bool wasAlreadyVisitingRecordType = needConversionIsVisitingRecordType;
needConversionIsVisitingRecordType = true;
bool result = false;
for (auto t : recTy.getTypeList()) {
if (needsConversion(t.second)) {
result = true;
break;
}
}
// Only keep the result cached if the fir.type visited was a "top-level
// type". Nested types with a recursive reference to the "top-level type"
// may incorrectly have been resolved as not needed conversions because it
// had not been determined yet if the "top-level type" needed conversion.
// This is not an issue to determine the "top-level type" need of
// conversion, but the result should not be kept and later used in other
// contexts.
needConversionIsVisitingRecordType = wasAlreadyVisitingRecordType;
if (needConversionIsVisitingRecordType)
visitedTypes.erase(ty);
else
visitedTypes.find(ty)->second = result;
return result;
}
if (auto boxTy = mlir::dyn_cast<BaseBoxType>(ty))
return needsConversion(boxTy.getEleTy());
if (isa_ref_type(ty))
return needsConversion(unwrapRefType(ty));
if (auto t = mlir::dyn_cast<SequenceType>(ty))
return needsConversion(unwrapSequenceType(ty));
if (auto t = mlir::dyn_cast<TypeDescType>(ty))
return needsConversion(t.getOfTy());
return false;
}
BoxprocTypeRewriter(mlir::Location location) : loc{location} {
addConversion([](mlir::Type ty) { return ty; });
addConversion(
[&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); });
addConversion([&](mlir::TupleType tupTy) {
llvm::SmallVector<mlir::Type> memTys;
for (auto ty : tupTy.getTypes())
memTys.push_back(convertType(ty));
return mlir::TupleType::get(tupTy.getContext(), memTys);
});
addConversion([&](mlir::FunctionType funcTy) {
llvm::SmallVector<mlir::Type> inTys;
llvm::SmallVector<mlir::Type> resTys;
for (auto ty : funcTy.getInputs())
inTys.push_back(convertType(ty));
for (auto ty : funcTy.getResults())
resTys.push_back(convertType(ty));
return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys);
});
addConversion([&](ReferenceType ty) {
return ReferenceType::get(convertType(ty.getEleTy()));
});
addConversion([&](PointerType ty) {
return PointerType::get(convertType(ty.getEleTy()));
});
addConversion(
[&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); });
addConversion([&](fir::LLVMPointerType ty) {
return fir::LLVMPointerType::get(convertType(ty.getEleTy()));
});
addConversion(
[&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); });
addConversion([&](ClassType ty) {
return ClassType::get(convertType(ty.getEleTy()));
});
addConversion([&](SequenceType ty) {
// TODO: add ty.getLayoutMap() as needed.
return SequenceType::get(ty.getShape(), convertType(ty.getEleTy()));
});
addConversion([&](RecordType ty) -> mlir::Type {
if (!needsConversion(ty))
return ty;
if (auto converted = convertedTypes.lookup(ty))
return converted;
auto rec = RecordType::get(ty.getContext(),
ty.getName().str() + boxprocSuffix.str());
if (rec.isFinalized())
return rec;
[[maybe_unused]] auto it = convertedTypes.try_emplace(ty, rec);
assert(it.second && "expected ty to not be in the map");
std::vector<RecordType::TypePair> ps = ty.getLenParamList();
std::vector<RecordType::TypePair> cs;
for (auto t : ty.getTypeList()) {
if (needsConversion(t.second))
cs.emplace_back(t.first, convertType(t.second));
else
cs.emplace_back(t.first, t.second);
}
rec.finalize(ps, cs);
rec.pack(ty.isPacked());
return rec;
});
addConversion([&](TypeDescType ty) {
return TypeDescType::get(convertType(ty.getOfTy()));
});
addSourceMaterialization(materializeProcedure);
addTargetMaterialization(materializeProcedure);
}
static mlir::Value materializeProcedure(mlir::OpBuilder &builder,
BoxProcType type,
mlir::ValueRange inputs,
mlir::Location loc) {
assert(inputs.size() == 1);
return ConvertOp::create(builder, loc, unwrapRefType(type.getEleTy()),
inputs[0]);
}
void setLocation(mlir::Location location) { loc = location; }
private:
// Maps to deal with recursive derived types (avoid infinite loops).
// Caching is also beneficial for apps with big types (dozens of
// components and or parent types), so the lifetime of the cache
// is the whole pass.
llvm::DenseMap<mlir::Type, bool> visitedTypes;
bool needConversionIsVisitingRecordType = false;
llvm::DenseMap<mlir::Type, mlir::Type> convertedTypes;
mlir::Location loc;
};
/// A `boxproc` is an abstraction for a Fortran procedure reference. Typically,
/// Fortran procedures can be referenced directly through a function pointer.
/// However, Fortran has one-level dynamic scoping between a host procedure and
/// its internal procedures. This allows internal procedures to directly access
/// and modify the state of the host procedure's variables.
///
/// Any number of implementations are possible.
///
/// The implementation used here is to convert `boxproc` values to function
/// pointers everywhere. If a `boxproc` value includes a frame pointer to the
/// host procedure's data, then a thunk will be created at runtime to capture
/// the frame pointer during execution. In LLVM IR, the frame pointer is
/// designated with the `nest` attribute. The thunk's address will then be used
/// as the call target instead of the original function's address directly.
class BoxedProcedurePass
    : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> {
public:
  using BoxedProcedurePassBase<BoxedProcedurePass>::BoxedProcedurePassBase;

  inline mlir::ModuleOp getModule() { return getOperation(); }

  /// Rewrite every `!fir.boxproc` in the module when `useThunks` is set;
  /// otherwise do nothing.
  void runOnOperation() override final {
    if (useThunks) {
      auto *context = &getContext();
      mlir::IRRewriter rewriter(context);
      BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context));

      // When using safe trampolines, we need to track handles per
      // function so we can insert FreeTrampoline calls at each return.
      // Process functions individually to manage this state.
      if (useSafeTrampoline) {
        getModule().walk([&](mlir::func::FuncOp funcOp) {
          trampolineHandles.clear();
          trampolineCallableMap.clear();
          processFunction(funcOp, rewriter, typeConverter);
          insertTrampolineFrees(funcOp, rewriter);
        });
        // Also process everything that is not part of a function (globals
        // and the operations nested in them, etc.).
        processModuleLevelOps(rewriter, typeConverter);
      } else {
        getModule().walk([&](mlir::Operation *op) {
          processOp(op, rewriter, typeConverter);
        });
      }
    }
  }

private:
  /// Trampoline handles collected while processing a function.
  /// Each entry is a Value representing the opaque handle returned
  /// by _FortranATrampolineInit, which must be freed before the
  /// function returns.
  llvm::SmallVector<mlir::Value> trampolineHandles;

  /// Cache of trampoline callable addresses keyed by the func SSA value
  /// of the emboxproc. This deduplicates trampolines when the same
  /// internal procedure is emboxed multiple times in one host function.
  llvm::DenseMap<mlir::Value, mlir::Value> trampolineCallableMap;

  /// Process all ops within a function (the function op included).
  void processFunction(mlir::func::FuncOp funcOp, mlir::IRRewriter &rewriter,
                       BoxprocTypeRewriter &typeConverter) {
    funcOp.walk(
        [&](mlir::Operation *op) { processOp(op, rewriter, typeConverter); });
  }

  /// Process the module-level operations that are not functions (globals,
  /// etc.) together with every operation nested inside them, so that, e.g.,
  /// the body of a global initializer is rewritten consistently with the
  /// global's own type.
  ///
  /// Operations that are, or live inside, a `func.func` are skipped: those
  /// are handled per function by processFunction().
  void processModuleLevelOps(mlir::IRRewriter &rewriter,
                             BoxprocTypeRewriter &typeConverter) {
    // Early-increment iteration keeps the loop valid even if processOp
    // replaces the top-level operation being visited.
    for (mlir::Operation &op :
         llvm::make_early_inc_range(getModule().getBody()->getOperations())) {
      if (mlir::isa<mlir::func::FuncOp>(op))
        continue;
      // Post-order walk, like the one used for functions: processOp may
      // replace the operation it is given.
      op.walk([&](mlir::Operation *nested) {
        if (mlir::isa<mlir::func::FuncOp>(nested) ||
            nested->getParentOfType<mlir::func::FuncOp>())
          return;
        processOp(nested, rewriter, typeConverter);
      });
    }
  }

  /// Insert _FortranATrampolineFree calls before every return in the function.
  void insertTrampolineFrees(mlir::func::FuncOp funcOp,
                             mlir::IRRewriter &rewriter) {
    if (trampolineHandles.empty())
      return;

    auto module{funcOp->getParentOfType<mlir::ModuleOp>()};
    // Insert TrampolineFree calls before every func.return in this function.
    // At this pass stage (after CFGConversion), func.return is the only
    // terminator that exits the function. Other terminators are either
    // intra-function branches (cf.br, cf.cond_br, fir.select*) or
    // fir.unreachable (after STOP/ERROR STOP), which don't need cleanup
    // since the process is terminating.
    funcOp.walk([&](mlir::func::ReturnOp retOp) {
      rewriter.setInsertionPoint(retOp);
      FirOpBuilder builder(rewriter, module);
      auto loc{retOp.getLoc()};
      for (mlir::Value handle : trampolineHandles)
        fir::runtime::genTrampolineFree(builder, loc, handle);
    });
  }

  /// Process a single operation for boxproc type rewriting.
  ///
  /// Depending on the operation this either replaces it (`fir.box_addr` on a
  /// procedure, `fir.emboxproc`, allocations, coordinate and index ops) or
  /// updates its types in place (functions, globals, and any other operation
  /// through its result types and type attributes). Block argument types of
  /// an operation that was kept are updated last.
  void processOp(mlir::Operation *op, mlir::IRRewriter &rewriter,
                 BoxprocTypeRewriter &typeConverter) {
    // Set to false as soon as `op` has been replaced and must not be touched.
    bool opIsValid{true};
    typeConverter.setLocation(op->getLoc());
    if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) {
      mlir::Type ty{addr.getVal().getType()};
      mlir::Type resTy{addr.getResult().getType()};
      if (llvm::isa<mlir::FunctionType>(ty) ||
          llvm::isa<fir::BoxProcType>(ty)) {
        // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc`
        // or function type to be `fir.convert` ops.
        rewriter.setInsertionPoint(addr);
        rewriter.replaceOpWithNewOp<ConvertOp>(
            addr, typeConverter.convertType(addr.getType()), addr.getVal());
        opIsValid = false;
      } else if (typeConverter.needsConversion(resTy)) {
        rewriter.startOpModification(op);
        op->getResult(0).setType(typeConverter.convertType(resTy));
        rewriter.finalizeOpModification(op);
      }
    } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) {
      mlir::FunctionType ty{func.getFunctionType()};
      if (typeConverter.needsConversion(ty)) {
        rewriter.startOpModification(func);
        auto toTy{
            mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty))};
        if (!func.empty())
          for (auto e : llvm::enumerate(toTy.getInputs())) {
            // Swap each entry block argument for one of the converted type:
            // insert the new argument, redirect the uses, drop the old one.
            auto i{static_cast<unsigned>(e.index())};
            auto &block{func.front()};
            block.insertArgument(i, e.value(), func.getLoc());
            block.getArgument(i + 1).replaceAllUsesWith(block.getArgument(i));
            block.eraseArgument(i + 1);
          }
        func.setType(toTy);
        rewriter.finalizeOpModification(func);
      }
    } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) {
      // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk
      // as required.
      mlir::Type toTy{typeConverter.convertType(
          mlir::cast<BoxProcType>(embox.getType()).getEleTy())};
      rewriter.setInsertionPoint(embox);
      if (embox.getHost()) {
        auto module{embox->getParentOfType<mlir::ModuleOp>()};
        auto loc{embox.getLoc()};

        if (useSafeTrampoline) {
          // Runtime trampoline pool path (W^X compliant).
          // Insert Init/Adjust in the function's entry block so the
          // handle dominates all func.return ops where TrampolineFree
          // is emitted. This is necessary because fir.emboxproc may
          // appear inside control flow branches. A cache avoids
          // creating duplicate trampolines for the same internal
          // procedure within a single host function.
          mlir::Value funcVal{embox.getFunc()};
          auto cacheIt{trampolineCallableMap.find(funcVal)};
          if (cacheIt != trampolineCallableMap.end()) {
            rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
                                                   cacheIt->second);
          } else {
            auto parentFunc{embox->getParentOfType<mlir::func::FuncOp>()};
            if (!parentFunc) {
              // Without an enclosing function there is no entry block to
              // hoist into and no return to free the trampoline at.
              mlir::emitError(loc,
                              "fir.emboxproc with a host link must be nested "
                              "in a func.func to use safe trampolines");
              signalPassFailure();
              return;
            }
            auto &entryBlock{parentFunc.front()};

            // The host link (closure pointer) must already be in the entry
            // block. In practice it is always either a function block argument
            // or an alloca emitted at function entry by the lowering — cloning
            // just the defining op would miss any stores that initialise it,
            // producing incorrect code. Diagnose a violation of that invariant
            // rather than attempting a broken clone. This is checked before
            // the IR is modified, and the pass is marked as failed because the
            // `fir.emboxproc` is left unconverted.
            mlir::Value hostValInEntry{embox.getHost()};
            if (auto *hostDef{hostValInEntry.getDefiningOp()}) {
              if (hostDef->getBlock() != &entryBlock) {
                mlir::emitError(loc,
                                "host link value is not defined in the entry "
                                "block of the host function; cannot hoist "
                                "TrampolineInit safely");
                signalPassFailure();
                return;
              }
            }

            auto savedIP{rewriter.saveInsertionPoint()};

            // Find the right insertion point in the entry block.
            // Walk up from the emboxproc to find its top-level
            // ancestor in the entry block. For an emboxproc directly
            // in the entry block, this is the emboxproc itself.
            // For one inside a structured op (fir.if, fir.do_loop),
            // this is that structured op. For one inside an explicit
            // branch target (cf.cond_br → ^bb1), we fall back to the
            // entry block terminator.
            mlir::Operation *entryAncestor{embox.getOperation()};
            while (entryAncestor->getBlock() != &entryBlock) {
              entryAncestor = entryAncestor->getParentOp();
              if (!entryAncestor ||
                  mlir::isa<mlir::func::FuncOp>(entryAncestor))
                break;
            }
            bool ancestorInEntry{
                entryAncestor &&
                !mlir::isa<mlir::func::FuncOp>(entryAncestor) &&
                entryAncestor->getBlock() == &entryBlock};

            // If the func value is not in the entry block (e.g.,
            // address_of generated inside a structured fir.if),
            // clone it into the entry block.
            mlir::Value funcValInEntry{funcVal};
            if (auto *funcDef{funcVal.getDefiningOp()}) {
              if (funcDef->getBlock() != &entryBlock) {
                if (ancestorInEntry)
                  rewriter.setInsertionPoint(entryAncestor);
                else
                  rewriter.setInsertionPoint(entryBlock.getTerminator());
                auto *cloned{rewriter.clone(*funcDef)};
                funcValInEntry = cloned->getResult(0);
              }
            }

            // Insert Init/Adjust at the determined position.
            FirOpBuilder builder(rewriter, module);
            if (ancestorInEntry)
              builder.setInsertionPoint(entryAncestor);
            else
              builder.setInsertionPoint(entryBlock.getTerminator());

            mlir::Type i8Ty{builder.getI8Type()};
            mlir::Type i8Ptr{builder.getRefType(i8Ty)};
            mlir::Value nullPtr{builder.createNullConstant(loc, i8Ptr)};
            mlir::Value closure{
                builder.createConvert(loc, i8Ptr, hostValInEntry)};
            mlir::Value func{builder.createConvert(loc, i8Ptr, funcValInEntry)};

            // _FortranATrampolineInit(nullptr, func, closure) -> handle
            mlir::Value handle{fir::runtime::genTrampolineInit(
                builder, loc, nullPtr, func, closure)};

            // _FortranATrampolineAdjust(handle) -> callable address
            mlir::Value callableAddr{
                fir::runtime::genTrampolineAdjust(builder, loc, handle)};

            trampolineHandles.push_back(handle);
            trampolineCallableMap[funcVal] = callableAddr;

            rewriter.restoreInsertionPoint(savedIP);
            rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, callableAddr);
          }
        } else {
          // Legacy stack-based trampoline path.
          FirOpBuilder builder(rewriter, module);
          mlir::Type i8Ty{builder.getI8Type()};
          mlir::Type i8Ptr{builder.getRefType(i8Ty)};
          const auto triple{fir::getTargetTriple(module)};
          // For PPC32 and PPC64, the thunk is populated by a call to
          // __trampoline_setup, which is defined in
          // compiler-rt/lib/builtins/trampoline_setup.c and requires the
          // thunk size greater than 32 bytes. For AArch64, RISCV and
          // x86_64, the thunk setup doesn't go through
          // __trampoline_setup and fits in 32 bytes.
          fir::SequenceType::Extent thunkSize{triple.getTrampolineSize()};
          mlir::Type buffTy{SequenceType::get({thunkSize}, i8Ty)};
          auto buffer{AllocaOp::create(builder, loc, buffTy)};
          mlir::Value closure{
              builder.createConvert(loc, i8Ptr, embox.getHost())};
          mlir::Value tramp{builder.createConvert(loc, i8Ptr, buffer)};
          mlir::Value func{builder.createConvert(loc, i8Ptr, embox.getFunc())};
          fir::CallOp::create(
              builder, loc, factory::getLlvmInitTrampoline(builder),
              llvm::ArrayRef<mlir::Value>{tramp, func, closure});
          auto adjustCall{fir::CallOp::create(
              builder, loc, factory::getLlvmAdjustTrampoline(builder),
              llvm::ArrayRef<mlir::Value>{tramp})};
          rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy,
                                                 adjustCall.getResult(0));
        }
        opIsValid = false;
      } else {
        // Just forward the function as a pointer.
        rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, embox.getFunc());
        opIsValid = false;
      }
    } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) {
      auto ty{global.getType()};
      if (typeConverter.needsConversion(ty)) {
        rewriter.startOpModification(global);
        auto toTy{typeConverter.convertType(ty)};
        global.setType(toTy);
        rewriter.finalizeOpModification(global);
      }
    } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) {
      auto ty{mem.getType()};
      if (typeConverter.needsConversion(ty)) {
        rewriter.setInsertionPoint(mem);
        auto toTy{typeConverter.convertType(unwrapRefType(ty))};
        bool isPinned{mem.getPinned()};
        llvm::StringRef uniqName{mem.getUniqName().value_or(llvm::StringRef())};
        llvm::StringRef bindcName{
            mem.getBindcName().value_or(llvm::StringRef())};
        rewriter.replaceOpWithNewOp<AllocaOp>(mem, toTy, uniqName, bindcName,
                                              isPinned, mem.getTypeparams(),
                                              mem.getShape());
        opIsValid = false;
      }
    } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) {
      auto ty{mem.getType()};
      if (typeConverter.needsConversion(ty)) {
        rewriter.setInsertionPoint(mem);
        auto toTy{typeConverter.convertType(unwrapRefType(ty))};
        llvm::StringRef uniqName{mem.getUniqName().value_or(llvm::StringRef())};
        llvm::StringRef bindcName{
            mem.getBindcName().value_or(llvm::StringRef())};
        rewriter.replaceOpWithNewOp<AllocMemOp>(mem, toTy, uniqName, bindcName,
                                                mem.getTypeparams(),
                                                mem.getShape());
        opIsValid = false;
      }
    } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) {
      auto ty{coor.getType()};
      mlir::Type baseTy{coor.getBaseType()};
      if (typeConverter.needsConversion(ty) ||
          typeConverter.needsConversion(baseTy)) {
        rewriter.setInsertionPoint(coor);
        auto toTy{typeConverter.convertType(ty)};
        auto toBaseTy{typeConverter.convertType(baseTy)};
        rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(),
                                                  coor.getCoor(), toBaseTy,
                                                  coor.getFieldIndicesAttr());
        opIsValid = false;
      }
    } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) {
      auto ty{index.getType()};
      mlir::Type onTy{index.getOnType()};
      if (typeConverter.needsConversion(ty) ||
          typeConverter.needsConversion(onTy)) {
        rewriter.setInsertionPoint(index);
        auto toTy{typeConverter.convertType(ty)};
        auto toOnTy{typeConverter.convertType(onTy)};
        rewriter.replaceOpWithNewOp<FieldIndexOp>(
            index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
        opIsValid = false;
      }
    } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) {
      auto ty{index.getType()};
      mlir::Type onTy{index.getOnType()};
      if (typeConverter.needsConversion(ty) ||
          typeConverter.needsConversion(onTy)) {
        rewriter.setInsertionPoint(index);
        auto toTy{typeConverter.convertType(ty)};
        auto toOnTy{typeConverter.convertType(onTy)};
        rewriter.replaceOpWithNewOp<LenParamIndexOp>(
            index, toTy, index.getFieldId(), toOnTy, index.getTypeparams());
        opIsValid = false;
      }
    } else {
      rewriter.startOpModification(op);
      // Convert the result types if needed
      for (auto i : llvm::enumerate(op->getResultTypes()))
        if (typeConverter.needsConversion(i.value())) {
          auto toTy{typeConverter.convertType(i.value())};
          op->getResult(i.index()).setType(toTy);
        }

      // Convert the type attributes if needed
      for (const mlir::NamedAttribute &attr : op->getAttrDictionary())
        if (auto tyAttr = llvm::dyn_cast<mlir::TypeAttr>(attr.getValue()))
          if (typeConverter.needsConversion(tyAttr.getValue())) {
            auto toTy{typeConverter.convertType(tyAttr.getValue())};
            op->setAttr(attr.getName(), mlir::TypeAttr::get(toTy));
          }
      rewriter.finalizeOpModification(op);
    }
    // Ensure block arguments are updated if needed.
    if (opIsValid && op->getNumRegions() != 0) {
      rewriter.startOpModification(op);
      for (mlir::Region &region : op->getRegions())
        for (mlir::Block &block : region.getBlocks())
          for (mlir::BlockArgument blockArg : block.getArguments())
            if (typeConverter.needsConversion(blockArg.getType())) {
              mlir::Type toTy{typeConverter.convertType(blockArg.getType())};
              blockArg.setType(toTy);
            }
      rewriter.finalizeOpModification(op);
    }
  }
};
} // namespace