| //===-- BoxedProcedure.cpp ------------------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "flang/Optimizer/CodeGen/CodeGen.h" |
| |
| #include "flang/Optimizer/Builder/FIRBuilder.h" |
| #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" |
| #include "flang/Optimizer/Dialect/FIRDialect.h" |
| #include "flang/Optimizer/Dialect/FIROps.h" |
| #include "flang/Optimizer/Dialect/FIRType.h" |
| #include "flang/Optimizer/Dialect/Support/FIRContext.h" |
| #include "flang/Optimizer/Support/FatalError.h" |
| #include "flang/Optimizer/Support/InternalNames.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "llvm/ADT/DenseMap.h" |
| |
| namespace fir { |
| #define GEN_PASS_DEF_BOXEDPROCEDUREPASS |
| #include "flang/Optimizer/CodeGen/CGPasses.h.inc" |
| } // namespace fir |
| |
| #define DEBUG_TYPE "flang-procedure-pointer" |
| |
| using namespace fir; |
| |
| namespace { |
/// Configuration knobs for the procedure pointer pass.
struct BoxedProcedureOptions {
  /// When set, the boxproc abstraction is lowered to function pointers, with
  /// thunks emitted where required.
  bool useThunks{true};
};
| |
| /// This type converter rewrites all `!fir.boxproc<Func>` types to `Func` types. |
| class BoxprocTypeRewriter : public mlir::TypeConverter { |
| public: |
| using mlir::TypeConverter::convertType; |
| |
| /// Does the type \p ty need to be converted? |
| /// Any type that is a `!fir.boxproc` in whole or in part will need to be |
| /// converted to a function type to lower the IR to function pointer form in |
| /// the default implementation performed in this pass. Other implementations |
| /// are possible, so those may convert `!fir.boxproc` to some other type or |
| /// not at all depending on the implementation target's characteristics and |
| /// preference. |
| bool needsConversion(mlir::Type ty) { |
| if (mlir::isa<BoxProcType>(ty)) |
| return true; |
| if (auto funcTy = mlir::dyn_cast<mlir::FunctionType>(ty)) { |
| for (auto t : funcTy.getInputs()) |
| if (needsConversion(t)) |
| return true; |
| for (auto t : funcTy.getResults()) |
| if (needsConversion(t)) |
| return true; |
| return false; |
| } |
| if (auto tupleTy = mlir::dyn_cast<mlir::TupleType>(ty)) { |
| for (auto t : tupleTy.getTypes()) |
| if (needsConversion(t)) |
| return true; |
| return false; |
| } |
| if (auto recTy = mlir::dyn_cast<RecordType>(ty)) { |
| auto visited = visitedTypes.find(ty); |
| if (visited != visitedTypes.end()) |
| return visited->second; |
| [[maybe_unused]] auto newIt = visitedTypes.try_emplace(ty, false); |
| assert(newIt.second && "expected ty to not be in the map"); |
| bool wasAlreadyVisitingRecordType = needConversionIsVisitingRecordType; |
| needConversionIsVisitingRecordType = true; |
| bool result = false; |
| for (auto t : recTy.getTypeList()) { |
| if (needsConversion(t.second)) { |
| result = true; |
| break; |
| } |
| } |
| // Only keep the result cached if the fir.type visited was a "top-level |
| // type". Nested types with a recursive reference to the "top-level type" |
| // may incorrectly have been resolved as not needed conversions because it |
| // had not been determined yet if the "top-level type" needed conversion. |
| // This is not an issue to determine the "top-level type" need of |
| // conversion, but the result should not be kept and later used in other |
| // contexts. |
| needConversionIsVisitingRecordType = wasAlreadyVisitingRecordType; |
| if (needConversionIsVisitingRecordType) |
| visitedTypes.erase(ty); |
| else |
| visitedTypes.find(ty)->second = result; |
| return result; |
| } |
| if (auto boxTy = mlir::dyn_cast<BaseBoxType>(ty)) |
| return needsConversion(boxTy.getEleTy()); |
| if (isa_ref_type(ty)) |
| return needsConversion(unwrapRefType(ty)); |
| if (auto t = mlir::dyn_cast<SequenceType>(ty)) |
| return needsConversion(unwrapSequenceType(ty)); |
| return false; |
| } |
| |
| BoxprocTypeRewriter(mlir::Location location) : loc{location} { |
| addConversion([](mlir::Type ty) { return ty; }); |
| addConversion( |
| [&](BoxProcType boxproc) { return convertType(boxproc.getEleTy()); }); |
| addConversion([&](mlir::TupleType tupTy) { |
| llvm::SmallVector<mlir::Type> memTys; |
| for (auto ty : tupTy.getTypes()) |
| memTys.push_back(convertType(ty)); |
| return mlir::TupleType::get(tupTy.getContext(), memTys); |
| }); |
| addConversion([&](mlir::FunctionType funcTy) { |
| llvm::SmallVector<mlir::Type> inTys; |
| llvm::SmallVector<mlir::Type> resTys; |
| for (auto ty : funcTy.getInputs()) |
| inTys.push_back(convertType(ty)); |
| for (auto ty : funcTy.getResults()) |
| resTys.push_back(convertType(ty)); |
| return mlir::FunctionType::get(funcTy.getContext(), inTys, resTys); |
| }); |
| addConversion([&](ReferenceType ty) { |
| return ReferenceType::get(convertType(ty.getEleTy())); |
| }); |
| addConversion([&](PointerType ty) { |
| return PointerType::get(convertType(ty.getEleTy())); |
| }); |
| addConversion( |
| [&](HeapType ty) { return HeapType::get(convertType(ty.getEleTy())); }); |
| addConversion([&](fir::LLVMPointerType ty) { |
| return fir::LLVMPointerType::get(convertType(ty.getEleTy())); |
| }); |
| addConversion( |
| [&](BoxType ty) { return BoxType::get(convertType(ty.getEleTy())); }); |
| addConversion([&](ClassType ty) { |
| return ClassType::get(convertType(ty.getEleTy())); |
| }); |
| addConversion([&](SequenceType ty) { |
| // TODO: add ty.getLayoutMap() as needed. |
| return SequenceType::get(ty.getShape(), convertType(ty.getEleTy())); |
| }); |
| addConversion([&](RecordType ty) -> mlir::Type { |
| if (!needsConversion(ty)) |
| return ty; |
| if (auto converted = convertedTypes.lookup(ty)) |
| return converted; |
| auto rec = RecordType::get(ty.getContext(), |
| ty.getName().str() + boxprocSuffix.str()); |
| if (rec.isFinalized()) |
| return rec; |
| [[maybe_unused]] auto it = convertedTypes.try_emplace(ty, rec); |
| assert(it.second && "expected ty to not be in the map"); |
| std::vector<RecordType::TypePair> ps = ty.getLenParamList(); |
| std::vector<RecordType::TypePair> cs; |
| for (auto t : ty.getTypeList()) { |
| if (needsConversion(t.second)) |
| cs.emplace_back(t.first, convertType(t.second)); |
| else |
| cs.emplace_back(t.first, t.second); |
| } |
| rec.finalize(ps, cs); |
| return rec; |
| }); |
| addArgumentMaterialization(materializeProcedure); |
| addSourceMaterialization(materializeProcedure); |
| addTargetMaterialization(materializeProcedure); |
| } |
| |
| static mlir::Value materializeProcedure(mlir::OpBuilder &builder, |
| BoxProcType type, |
| mlir::ValueRange inputs, |
| mlir::Location loc) { |
| assert(inputs.size() == 1); |
| return builder.create<ConvertOp>(loc, unwrapRefType(type.getEleTy()), |
| inputs[0]); |
| } |
| |
| void setLocation(mlir::Location location) { loc = location; } |
| |
| private: |
| // Maps to deal with recursive derived types (avoid infinite loops). |
| // Caching is also beneficial for apps with big types (dozens of |
| // components and or parent types), so the lifetime of the cache |
| // is the whole pass. |
| llvm::DenseMap<mlir::Type, bool> visitedTypes; |
| bool needConversionIsVisitingRecordType = false; |
| llvm::DenseMap<mlir::Type, mlir::Type> convertedTypes; |
| mlir::Location loc; |
| }; |
| |
| /// A `boxproc` is an abstraction for a Fortran procedure reference. Typically, |
| /// Fortran procedures can be referenced directly through a function pointer. |
| /// However, Fortran has one-level dynamic scoping between a host procedure and |
| /// its internal procedures. This allows internal procedures to directly access |
| /// and modify the state of the host procedure's variables. |
| /// |
| /// There are any number of possible implementations possible. |
| /// |
| /// The implementation used here is to convert `boxproc` values to function |
| /// pointers everywhere. If a `boxproc` value includes a frame pointer to the |
| /// host procedure's data, then a thunk will be created at runtime to capture |
| /// the frame pointer during execution. In LLVM IR, the frame pointer is |
| /// designated with the `nest` attribute. The thunk's address will then be used |
| /// as the call target instead of the original function's address directly. |
| class BoxedProcedurePass |
| : public fir::impl::BoxedProcedurePassBase<BoxedProcedurePass> { |
| public: |
| BoxedProcedurePass() { options = {true}; } |
| BoxedProcedurePass(bool useThunks) { options = {useThunks}; } |
| |
| inline mlir::ModuleOp getModule() { return getOperation(); } |
| |
| void runOnOperation() override final { |
| if (options.useThunks) { |
| auto *context = &getContext(); |
| mlir::IRRewriter rewriter(context); |
| BoxprocTypeRewriter typeConverter(mlir::UnknownLoc::get(context)); |
| mlir::Dialect *firDialect = context->getLoadedDialect("fir"); |
| getModule().walk([&](mlir::Operation *op) { |
| bool opIsValid = true; |
| typeConverter.setLocation(op->getLoc()); |
| if (auto addr = mlir::dyn_cast<BoxAddrOp>(op)) { |
| mlir::Type ty = addr.getVal().getType(); |
| mlir::Type resTy = addr.getResult().getType(); |
| if (llvm::isa<mlir::FunctionType>(ty) || |
| llvm::isa<fir::BoxProcType>(ty)) { |
| // Rewrite all `fir.box_addr` ops on values of type `!fir.boxproc` |
| // or function type to be `fir.convert` ops. |
| rewriter.setInsertionPoint(addr); |
| rewriter.replaceOpWithNewOp<ConvertOp>( |
| addr, typeConverter.convertType(addr.getType()), addr.getVal()); |
| opIsValid = false; |
| } else if (typeConverter.needsConversion(resTy)) { |
| rewriter.startOpModification(op); |
| op->getResult(0).setType(typeConverter.convertType(resTy)); |
| rewriter.finalizeOpModification(op); |
| } |
| } else if (auto func = mlir::dyn_cast<mlir::func::FuncOp>(op)) { |
| mlir::FunctionType ty = func.getFunctionType(); |
| if (typeConverter.needsConversion(ty)) { |
| rewriter.startOpModification(func); |
| auto toTy = |
| mlir::cast<mlir::FunctionType>(typeConverter.convertType(ty)); |
| if (!func.empty()) |
| for (auto e : llvm::enumerate(toTy.getInputs())) { |
| unsigned i = e.index(); |
| auto &block = func.front(); |
| block.insertArgument(i, e.value(), func.getLoc()); |
| block.getArgument(i + 1).replaceAllUsesWith( |
| block.getArgument(i)); |
| block.eraseArgument(i + 1); |
| } |
| func.setType(toTy); |
| rewriter.finalizeOpModification(func); |
| } |
| } else if (auto embox = mlir::dyn_cast<EmboxProcOp>(op)) { |
| // Rewrite all `fir.emboxproc` ops to either `fir.convert` or a thunk |
| // as required. |
| mlir::Type toTy = typeConverter.convertType( |
| mlir::cast<BoxProcType>(embox.getType()).getEleTy()); |
| rewriter.setInsertionPoint(embox); |
| if (embox.getHost()) { |
| // Create the thunk. |
| auto module = embox->getParentOfType<mlir::ModuleOp>(); |
| FirOpBuilder builder(rewriter, module); |
| auto loc = embox.getLoc(); |
| mlir::Type i8Ty = builder.getI8Type(); |
| mlir::Type i8Ptr = builder.getRefType(i8Ty); |
| mlir::Type buffTy = SequenceType::get({32}, i8Ty); |
| auto buffer = builder.create<AllocaOp>(loc, buffTy); |
| mlir::Value closure = |
| builder.createConvert(loc, i8Ptr, embox.getHost()); |
| mlir::Value tramp = builder.createConvert(loc, i8Ptr, buffer); |
| mlir::Value func = |
| builder.createConvert(loc, i8Ptr, embox.getFunc()); |
| builder.create<fir::CallOp>( |
| loc, factory::getLlvmInitTrampoline(builder), |
| llvm::ArrayRef<mlir::Value>{tramp, func, closure}); |
| auto adjustCall = builder.create<fir::CallOp>( |
| loc, factory::getLlvmAdjustTrampoline(builder), |
| llvm::ArrayRef<mlir::Value>{tramp}); |
| rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, |
| adjustCall.getResult(0)); |
| opIsValid = false; |
| } else { |
| // Just forward the function as a pointer. |
| rewriter.replaceOpWithNewOp<ConvertOp>(embox, toTy, |
| embox.getFunc()); |
| opIsValid = false; |
| } |
| } else if (auto global = mlir::dyn_cast<GlobalOp>(op)) { |
| auto ty = global.getType(); |
| if (typeConverter.needsConversion(ty)) { |
| rewriter.startOpModification(global); |
| auto toTy = typeConverter.convertType(ty); |
| global.setType(toTy); |
| rewriter.finalizeOpModification(global); |
| } |
| } else if (auto mem = mlir::dyn_cast<AllocaOp>(op)) { |
| auto ty = mem.getType(); |
| if (typeConverter.needsConversion(ty)) { |
| rewriter.setInsertionPoint(mem); |
| auto toTy = typeConverter.convertType(unwrapRefType(ty)); |
| bool isPinned = mem.getPinned(); |
| llvm::StringRef uniqName = |
| mem.getUniqName().value_or(llvm::StringRef()); |
| llvm::StringRef bindcName = |
| mem.getBindcName().value_or(llvm::StringRef()); |
| rewriter.replaceOpWithNewOp<AllocaOp>( |
| mem, toTy, uniqName, bindcName, isPinned, mem.getTypeparams(), |
| mem.getShape()); |
| opIsValid = false; |
| } |
| } else if (auto mem = mlir::dyn_cast<AllocMemOp>(op)) { |
| auto ty = mem.getType(); |
| if (typeConverter.needsConversion(ty)) { |
| rewriter.setInsertionPoint(mem); |
| auto toTy = typeConverter.convertType(unwrapRefType(ty)); |
| llvm::StringRef uniqName = |
| mem.getUniqName().value_or(llvm::StringRef()); |
| llvm::StringRef bindcName = |
| mem.getBindcName().value_or(llvm::StringRef()); |
| rewriter.replaceOpWithNewOp<AllocMemOp>( |
| mem, toTy, uniqName, bindcName, mem.getTypeparams(), |
| mem.getShape()); |
| opIsValid = false; |
| } |
| } else if (auto coor = mlir::dyn_cast<CoordinateOp>(op)) { |
| auto ty = coor.getType(); |
| mlir::Type baseTy = coor.getBaseType(); |
| if (typeConverter.needsConversion(ty) || |
| typeConverter.needsConversion(baseTy)) { |
| rewriter.setInsertionPoint(coor); |
| auto toTy = typeConverter.convertType(ty); |
| auto toBaseTy = typeConverter.convertType(baseTy); |
| rewriter.replaceOpWithNewOp<CoordinateOp>(coor, toTy, coor.getRef(), |
| coor.getCoor(), toBaseTy); |
| opIsValid = false; |
| } |
| } else if (auto index = mlir::dyn_cast<FieldIndexOp>(op)) { |
| auto ty = index.getType(); |
| mlir::Type onTy = index.getOnType(); |
| if (typeConverter.needsConversion(ty) || |
| typeConverter.needsConversion(onTy)) { |
| rewriter.setInsertionPoint(index); |
| auto toTy = typeConverter.convertType(ty); |
| auto toOnTy = typeConverter.convertType(onTy); |
| rewriter.replaceOpWithNewOp<FieldIndexOp>( |
| index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); |
| opIsValid = false; |
| } |
| } else if (auto index = mlir::dyn_cast<LenParamIndexOp>(op)) { |
| auto ty = index.getType(); |
| mlir::Type onTy = index.getOnType(); |
| if (typeConverter.needsConversion(ty) || |
| typeConverter.needsConversion(onTy)) { |
| rewriter.setInsertionPoint(index); |
| auto toTy = typeConverter.convertType(ty); |
| auto toOnTy = typeConverter.convertType(onTy); |
| rewriter.replaceOpWithNewOp<LenParamIndexOp>( |
| index, toTy, index.getFieldId(), toOnTy, index.getTypeparams()); |
| opIsValid = false; |
| } |
| } else if (op->getDialect() == firDialect) { |
| rewriter.startOpModification(op); |
| for (auto i : llvm::enumerate(op->getResultTypes())) |
| if (typeConverter.needsConversion(i.value())) { |
| auto toTy = typeConverter.convertType(i.value()); |
| op->getResult(i.index()).setType(toTy); |
| } |
| rewriter.finalizeOpModification(op); |
| } |
| // Ensure block arguments are updated if needed. |
| if (opIsValid && op->getNumRegions() != 0) { |
| rewriter.startOpModification(op); |
| for (mlir::Region ®ion : op->getRegions()) |
| for (mlir::Block &block : region.getBlocks()) |
| for (mlir::BlockArgument blockArg : block.getArguments()) |
| if (typeConverter.needsConversion(blockArg.getType())) { |
| mlir::Type toTy = |
| typeConverter.convertType(blockArg.getType()); |
| blockArg.setType(toTy); |
| } |
| rewriter.finalizeOpModification(op); |
| } |
| }); |
| } |
| } |
| |
| private: |
| BoxedProcedureOptions options; |
| }; |
| } // namespace |
| |
| std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass() { |
| return std::make_unique<BoxedProcedurePass>(); |
| } |
| |
| std::unique_ptr<mlir::Pass> fir::createBoxedProcedurePass(bool useThunks) { |
| return std::make_unique<BoxedProcedurePass>(useThunks); |
| } |