| //===-- TargetRewrite.cpp -------------------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Target rewrite: rewriting of ops to make target-specific lowerings manifest. |
| // LLVM expects different lowering idioms to be used for distinct target |
| // triples. These distinctions are handled by this pass. |
| // |
| // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
| // |
| //===----------------------------------------------------------------------===// |
| |
#include "PassDetail.h"
#include "Target.h"
#include "flang/Lower/Todo.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Support/FIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <functional>
#include <utility>

using namespace fir;

#define DEBUG_TYPE "flang-target-rewrite"
| |
| namespace { |
| |
| /// Fixups for updating a FuncOp's arguments and return values. |
| struct FixupTy { |
| enum class Codes { |
| ArgumentAsLoad, |
| ArgumentType, |
| CharPair, |
| ReturnAsStore, |
| ReturnType, |
| Split, |
| Trailing |
| }; |
| |
| FixupTy(Codes code, std::size_t index, std::size_t second = 0) |
| : code{code}, index{index}, second{second} {} |
| FixupTy(Codes code, std::size_t index, |
| std::function<void(mlir::FuncOp)> &&finalizer) |
| : code{code}, index{index}, finalizer{finalizer} {} |
| FixupTy(Codes code, std::size_t index, std::size_t second, |
| std::function<void(mlir::FuncOp)> &&finalizer) |
| : code{code}, index{index}, second{second}, finalizer{finalizer} {} |
| |
| Codes code; |
| std::size_t index; |
| std::size_t second{}; |
| llvm::Optional<std::function<void(mlir::FuncOp)>> finalizer{}; |
| }; // namespace |
| |
| /// Target-specific rewriting of the FIR. This is a prerequisite pass to code |
| /// generation that traverses the FIR and modifies types and operations to a |
| /// form that is appropriate for the specific target. LLVM IR has specific |
| /// idioms that are used for distinct target processor and ABI combinations. |
| class TargetRewrite : public TargetRewriteBase<TargetRewrite> { |
| public: |
| TargetRewrite(const TargetRewriteOptions &options) { |
| noCharacterConversion = options.noCharacterConversion; |
| noComplexConversion = options.noComplexConversion; |
| } |
| |
| void runOnOperation() override final { |
| auto &context = getContext(); |
| mlir::OpBuilder rewriter(&context); |
| |
| auto mod = getModule(); |
| if (!forcedTargetTriple.empty()) { |
| setTargetTriple(mod, forcedTargetTriple); |
| } |
| |
| auto specifics = CodeGenSpecifics::get(getOperation().getContext(), |
| getTargetTriple(getOperation()), |
| getKindMapping(getOperation())); |
| setMembers(specifics.get(), &rewriter); |
| |
| // Perform type conversion on signatures and call sites. |
| if (mlir::failed(convertTypes(mod))) { |
| mlir::emitError(mlir::UnknownLoc::get(&context), |
| "error in converting types to target abi"); |
| signalPassFailure(); |
| } |
| |
| // Convert ops in target-specific patterns. |
| mod.walk([&](mlir::Operation *op) { |
| if (auto call = dyn_cast<fir::CallOp>(op)) { |
| if (!hasPortableSignature(call.getFunctionType())) |
| convertCallOp(call); |
| } else if (auto dispatch = dyn_cast<DispatchOp>(op)) { |
| if (!hasPortableSignature(dispatch.getFunctionType())) |
| convertCallOp(dispatch); |
| } |
| }); |
| |
| clearMembers(); |
| } |
| |
| mlir::ModuleOp getModule() { return getOperation(); } |
| |
| template <typename A, typename B, typename C> |
| std::function<mlir::Value(mlir::Operation *)> |
| rewriteCallComplexResultType(A ty, B &newResTys, B &newInTys, C &newOpers) { |
| auto m = specifics->complexReturnType(ty.getElementType()); |
| // Currently targets mandate COMPLEX is a single aggregate or packed |
| // scalar, including the sret case. |
| assert(m.size() == 1 && "target lowering of complex return not supported"); |
| auto resTy = std::get<mlir::Type>(m[0]); |
| auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); |
| auto loc = mlir::UnknownLoc::get(resTy.getContext()); |
| if (attr.isSRet()) { |
| assert(isa_ref_type(resTy)); |
| mlir::Value stack = |
| rewriter->create<fir::AllocaOp>(loc, dyn_cast_ptrEleTy(resTy)); |
| newInTys.push_back(resTy); |
| newOpers.push_back(stack); |
| return [=](mlir::Operation *) -> mlir::Value { |
| auto memTy = ReferenceType::get(ty); |
| auto cast = rewriter->create<ConvertOp>(loc, memTy, stack); |
| return rewriter->create<fir::LoadOp>(loc, cast); |
| }; |
| } |
| newResTys.push_back(resTy); |
| return [=](mlir::Operation *call) -> mlir::Value { |
| auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); |
| rewriter->create<fir::StoreOp>(loc, call->getResult(0), mem); |
| auto memTy = ReferenceType::get(ty); |
| auto cast = rewriter->create<ConvertOp>(loc, memTy, mem); |
| return rewriter->create<fir::LoadOp>(loc, cast); |
| }; |
| } |
| |
| template <typename A, typename B, typename C> |
| void rewriteCallComplexInputType(A ty, mlir::Value oper, B &newInTys, |
| C &newOpers) { |
| auto m = specifics->complexArgumentType(ty.getElementType()); |
| auto *ctx = ty.getContext(); |
| auto loc = mlir::UnknownLoc::get(ctx); |
| if (m.size() == 1) { |
| // COMPLEX is a single aggregate |
| auto resTy = std::get<mlir::Type>(m[0]); |
| auto attr = std::get<CodeGenSpecifics::Attributes>(m[0]); |
| auto oldRefTy = ReferenceType::get(ty); |
| if (attr.isByVal()) { |
| auto mem = rewriter->create<fir::AllocaOp>(loc, ty); |
| rewriter->create<fir::StoreOp>(loc, oper, mem); |
| newOpers.push_back(rewriter->create<ConvertOp>(loc, resTy, mem)); |
| } else { |
| auto mem = rewriter->create<fir::AllocaOp>(loc, resTy); |
| auto cast = rewriter->create<ConvertOp>(loc, oldRefTy, mem); |
| rewriter->create<fir::StoreOp>(loc, oper, cast); |
| newOpers.push_back(rewriter->create<fir::LoadOp>(loc, mem)); |
| } |
| newInTys.push_back(resTy); |
| } else { |
| assert(m.size() == 2); |
| // COMPLEX is split into 2 separate arguments |
| for (auto e : llvm::enumerate(m)) { |
| auto &tup = e.value(); |
| auto ty = std::get<mlir::Type>(tup); |
| auto index = e.index(); |
| auto idx = rewriter->getIntegerAttr(rewriter->getIndexType(), index); |
| auto val = rewriter->create<ExtractValueOp>( |
| loc, ty, oper, rewriter->getArrayAttr(idx)); |
| newInTys.push_back(ty); |
| newOpers.push_back(val); |
| } |
| } |
| } |
| |
| // Convert fir.call and fir.dispatch Ops. |
| template <typename A> |
| void convertCallOp(A callOp) { |
| auto fnTy = callOp.getFunctionType(); |
| auto loc = callOp.getLoc(); |
| rewriter->setInsertionPoint(callOp); |
| llvm::SmallVector<mlir::Type> newResTys; |
| llvm::SmallVector<mlir::Type> newInTys; |
| llvm::SmallVector<mlir::Value> newOpers; |
| |
| // If the call is indirect, the first argument must still be the function |
| // to call. |
| int dropFront = 0; |
| if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { |
| if (!callOp.callee().hasValue()) { |
| newInTys.push_back(fnTy.getInput(0)); |
| newOpers.push_back(callOp.getOperand(0)); |
| dropFront = 1; |
| } |
| } |
| |
| // Determine the rewrite function, `wrap`, for the result value. |
| llvm::Optional<std::function<mlir::Value(mlir::Operation *)>> wrap; |
| if (fnTy.getResults().size() == 1) { |
| mlir::Type ty = fnTy.getResult(0); |
| llvm::TypeSwitch<mlir::Type>(ty) |
| .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { |
| wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, |
| newOpers); |
| }) |
| .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { |
| wrap = rewriteCallComplexResultType(cmplx, newResTys, newInTys, |
| newOpers); |
| }) |
| .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); |
| } else if (fnTy.getResults().size() > 1) { |
| TODO(loc, "multiple results not supported yet"); |
| } |
| |
| llvm::SmallVector<mlir::Type> trailingInTys; |
| llvm::SmallVector<mlir::Value> trailingOpers; |
| for (auto e : llvm::enumerate( |
| llvm::zip(fnTy.getInputs().drop_front(dropFront), |
| callOp.getOperands().drop_front(dropFront)))) { |
| mlir::Type ty = std::get<0>(e.value()); |
| mlir::Value oper = std::get<1>(e.value()); |
| unsigned index = e.index(); |
| llvm::TypeSwitch<mlir::Type>(ty) |
| .template Case<BoxCharType>([&](BoxCharType boxTy) { |
| bool sret; |
| if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { |
| sret = callOp.callee() && |
| functionArgIsSRet(index, |
| getModule().lookupSymbol<mlir::FuncOp>( |
| *callOp.callee())); |
| } else { |
| // TODO: dispatch case; how do we put arguments on a call? |
| // We cannot put both an sret and the dispatch object first. |
| sret = false; |
| TODO(loc, "dispatch + sret not supported yet"); |
| } |
| auto m = specifics->boxcharArgumentType(boxTy.getEleTy(), sret); |
| auto unbox = |
| rewriter->create<UnboxCharOp>(loc, std::get<mlir::Type>(m[0]), |
| std::get<mlir::Type>(m[1]), oper); |
| // unboxed CHARACTER arguments |
| for (auto e : llvm::enumerate(m)) { |
| unsigned idx = e.index(); |
| auto attr = std::get<CodeGenSpecifics::Attributes>(e.value()); |
| auto argTy = std::get<mlir::Type>(e.value()); |
| if (attr.isAppend()) { |
| trailingInTys.push_back(argTy); |
| trailingOpers.push_back(unbox.getResult(idx)); |
| } else { |
| newInTys.push_back(argTy); |
| newOpers.push_back(unbox.getResult(idx)); |
| } |
| } |
| }) |
| .template Case<fir::ComplexType>([&](fir::ComplexType cmplx) { |
| rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); |
| }) |
| .template Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { |
| rewriteCallComplexInputType(cmplx, oper, newInTys, newOpers); |
| }) |
| .Default([&](mlir::Type ty) { |
| newInTys.push_back(ty); |
| newOpers.push_back(oper); |
| }); |
| } |
| newInTys.insert(newInTys.end(), trailingInTys.begin(), trailingInTys.end()); |
| newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); |
| if constexpr (std::is_same_v<std::decay_t<A>, fir::CallOp>) { |
| fir::CallOp newCall; |
| if (callOp.callee().hasValue()) { |
| newCall = rewriter->create<A>(loc, callOp.callee().getValue(), |
| newResTys, newOpers); |
| } else { |
| // Force new type on the input operand. |
| newOpers[0].setType(mlir::FunctionType::get( |
| callOp.getContext(), |
| mlir::TypeRange{newInTys}.drop_front(dropFront), newResTys)); |
| newCall = rewriter->create<A>(loc, newResTys, newOpers); |
| } |
| LLVM_DEBUG(llvm::dbgs() << "replacing call with " << newCall << '\n'); |
| if (wrap.hasValue()) |
| replaceOp(callOp, (*wrap)(newCall.getOperation())); |
| else |
| replaceOp(callOp, newCall.getResults()); |
| } else { |
| // A is fir::DispatchOp |
| TODO(loc, "dispatch not implemented"); |
| } |
| } |
| |
| // Result type fixup for fir::ComplexType and mlir::ComplexType |
| template <typename A, typename B> |
| void lowerComplexSignatureRes(A cmplx, B &newResTys, B &newInTys) { |
| if (noComplexConversion) { |
| newResTys.push_back(cmplx); |
| } else { |
| for (auto &tup : specifics->complexReturnType(cmplx.getElementType())) { |
| auto argTy = std::get<mlir::Type>(tup); |
| if (std::get<CodeGenSpecifics::Attributes>(tup).isSRet()) |
| newInTys.push_back(argTy); |
| else |
| newResTys.push_back(argTy); |
| } |
| } |
| } |
| |
| // Argument type fixup for fir::ComplexType and mlir::ComplexType |
| template <typename A, typename B> |
| void lowerComplexSignatureArg(A cmplx, B &newInTys) { |
| if (noComplexConversion) |
| newInTys.push_back(cmplx); |
| else |
| for (auto &tup : specifics->complexArgumentType(cmplx.getElementType())) |
| newInTys.push_back(std::get<mlir::Type>(tup)); |
| } |
| |
| /// Convert the type signatures on all the functions present in the module. |
| /// As the type signature is being changed, this must also update the |
| /// function itself to use any new arguments, etc. |
| mlir::LogicalResult convertTypes(mlir::ModuleOp mod) { |
| for (auto fn : mod.getOps<mlir::FuncOp>()) |
| convertSignature(fn); |
| return mlir::success(); |
| } |
| |
| /// If the signature does not need any special target-specific converions, |
| /// then it is considered portable for any target, and this function will |
| /// return `true`. Otherwise, the signature is not portable and `false` is |
| /// returned. |
| bool hasPortableSignature(mlir::Type signature) { |
| assert(signature.isa<mlir::FunctionType>()); |
| auto func = signature.dyn_cast<mlir::FunctionType>(); |
| for (auto ty : func.getResults()) |
| if ((ty.isa<BoxCharType>() && !noCharacterConversion) || |
| (isa_complex(ty) && !noComplexConversion)) { |
| LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); |
| return false; |
| } |
| for (auto ty : func.getInputs()) |
| if ((ty.isa<BoxCharType>() && !noCharacterConversion) || |
| (isa_complex(ty) && !noComplexConversion)) { |
| LLVM_DEBUG(llvm::dbgs() << "rewrite " << signature << " for target\n"); |
| return false; |
| } |
| return true; |
| } |
| |
| /// Rewrite the signatures and body of the `FuncOp`s in the module for |
| /// the immediately subsequent target code gen. |
| void convertSignature(mlir::FuncOp func) { |
| auto funcTy = func.getType().cast<mlir::FunctionType>(); |
| if (hasPortableSignature(funcTy)) |
| return; |
| llvm::SmallVector<mlir::Type> newResTys; |
| llvm::SmallVector<mlir::Type> newInTys; |
| llvm::SmallVector<FixupTy> fixups; |
| |
| // Convert return value(s) |
| for (auto ty : funcTy.getResults()) |
| llvm::TypeSwitch<mlir::Type>(ty) |
| .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { |
| if (noComplexConversion) |
| newResTys.push_back(cmplx); |
| else |
| doComplexReturn(func, cmplx, newResTys, newInTys, fixups); |
| }) |
| .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { |
| if (noComplexConversion) |
| newResTys.push_back(cmplx); |
| else |
| doComplexReturn(func, cmplx, newResTys, newInTys, fixups); |
| }) |
| .Default([&](mlir::Type ty) { newResTys.push_back(ty); }); |
| |
| // Convert arguments |
| llvm::SmallVector<mlir::Type> trailingTys; |
| for (auto e : llvm::enumerate(funcTy.getInputs())) { |
| auto ty = e.value(); |
| unsigned index = e.index(); |
| llvm::TypeSwitch<mlir::Type>(ty) |
| .Case<BoxCharType>([&](BoxCharType boxTy) { |
| if (noCharacterConversion) { |
| newInTys.push_back(boxTy); |
| } else { |
| // Convert a CHARACTER argument type. This can involve separating |
| // the pointer and the LEN into two arguments and moving the LEN |
| // argument to the end of the arg list. |
| bool sret = functionArgIsSRet(index, func); |
| for (auto e : llvm::enumerate(specifics->boxcharArgumentType( |
| boxTy.getEleTy(), sret))) { |
| auto &tup = e.value(); |
| auto index = e.index(); |
| auto attr = std::get<CodeGenSpecifics::Attributes>(tup); |
| auto argTy = std::get<mlir::Type>(tup); |
| if (attr.isAppend()) { |
| trailingTys.push_back(argTy); |
| } else { |
| if (sret) { |
| fixups.emplace_back(FixupTy::Codes::CharPair, |
| newInTys.size(), index); |
| } else { |
| fixups.emplace_back(FixupTy::Codes::Trailing, |
| newInTys.size(), trailingTys.size()); |
| } |
| newInTys.push_back(argTy); |
| } |
| } |
| } |
| }) |
| .Case<fir::ComplexType>([&](fir::ComplexType cmplx) { |
| if (noComplexConversion) |
| newInTys.push_back(cmplx); |
| else |
| doComplexArg(func, cmplx, newInTys, fixups); |
| }) |
| .Case<mlir::ComplexType>([&](mlir::ComplexType cmplx) { |
| if (noComplexConversion) |
| newInTys.push_back(cmplx); |
| else |
| doComplexArg(func, cmplx, newInTys, fixups); |
| }) |
| .Default([&](mlir::Type ty) { newInTys.push_back(ty); }); |
| } |
| |
| if (!func.empty()) { |
| // If the function has a body, then apply the fixups to the arguments and |
| // return ops as required. These fixups are done in place. |
| auto loc = func.getLoc(); |
| const auto fixupSize = fixups.size(); |
| const auto oldArgTys = func.getType().getInputs(); |
| int offset = 0; |
| for (std::remove_const_t<decltype(fixupSize)> i = 0; i < fixupSize; ++i) { |
| const auto &fixup = fixups[i]; |
| switch (fixup.code) { |
| case FixupTy::Codes::ArgumentAsLoad: { |
| // Argument was pass-by-value, but is now pass-by-reference and |
| // possibly with a different element type. |
| auto newArg = |
| func.front().insertArgument(fixup.index, newInTys[fixup.index]); |
| rewriter->setInsertionPointToStart(&func.front()); |
| auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); |
| auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, newArg); |
| auto load = rewriter->create<fir::LoadOp>(loc, cast); |
| func.getArgument(fixup.index + 1).replaceAllUsesWith(load); |
| func.front().eraseArgument(fixup.index + 1); |
| } break; |
| case FixupTy::Codes::ArgumentType: { |
| // Argument is pass-by-value, but its type has likely been modified to |
| // suit the target ABI convention. |
| auto newArg = |
| func.front().insertArgument(fixup.index, newInTys[fixup.index]); |
| rewriter->setInsertionPointToStart(&func.front()); |
| auto mem = |
| rewriter->create<fir::AllocaOp>(loc, newInTys[fixup.index]); |
| rewriter->create<fir::StoreOp>(loc, newArg, mem); |
| auto oldArgTy = ReferenceType::get(oldArgTys[fixup.index - offset]); |
| auto cast = rewriter->create<ConvertOp>(loc, oldArgTy, mem); |
| mlir::Value load = rewriter->create<fir::LoadOp>(loc, cast); |
| func.getArgument(fixup.index + 1).replaceAllUsesWith(load); |
| func.front().eraseArgument(fixup.index + 1); |
| LLVM_DEBUG(llvm::dbgs() |
| << "old argument: " << oldArgTy.getEleTy() |
| << ", repl: " << load << ", new argument: " |
| << func.getArgument(fixup.index).getType() << '\n'); |
| } break; |
| case FixupTy::Codes::CharPair: { |
| // The FIR boxchar argument has been split into a pair of distinct |
| // arguments that are in juxtaposition to each other. |
| auto newArg = |
| func.front().insertArgument(fixup.index, newInTys[fixup.index]); |
| if (fixup.second == 1) { |
| rewriter->setInsertionPointToStart(&func.front()); |
| auto boxTy = oldArgTys[fixup.index - offset - fixup.second]; |
| auto box = rewriter->create<EmboxCharOp>( |
| loc, boxTy, func.front().getArgument(fixup.index - 1), newArg); |
| func.getArgument(fixup.index + 1).replaceAllUsesWith(box); |
| func.front().eraseArgument(fixup.index + 1); |
| offset++; |
| } |
| } break; |
| case FixupTy::Codes::ReturnAsStore: { |
| // The value being returned is now being returned in memory (callee |
| // stack space) through a hidden reference argument. |
| auto newArg = |
| func.front().insertArgument(fixup.index, newInTys[fixup.index]); |
| offset++; |
| func.walk([&](mlir::ReturnOp ret) { |
| rewriter->setInsertionPoint(ret); |
| auto oldOper = ret.getOperand(0); |
| auto oldOperTy = ReferenceType::get(oldOper.getType()); |
| auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, newArg); |
| rewriter->create<fir::StoreOp>(loc, oldOper, cast); |
| rewriter->create<mlir::ReturnOp>(loc); |
| ret.erase(); |
| }); |
| } break; |
| case FixupTy::Codes::ReturnType: { |
| // The function is still returning a value, but its type has likely |
| // changed to suit the target ABI convention. |
| func.walk([&](mlir::ReturnOp ret) { |
| rewriter->setInsertionPoint(ret); |
| auto oldOper = ret.getOperand(0); |
| auto oldOperTy = ReferenceType::get(oldOper.getType()); |
| auto mem = |
| rewriter->create<fir::AllocaOp>(loc, newResTys[fixup.index]); |
| auto cast = rewriter->create<ConvertOp>(loc, oldOperTy, mem); |
| rewriter->create<fir::StoreOp>(loc, oldOper, cast); |
| mlir::Value load = rewriter->create<fir::LoadOp>(loc, mem); |
| rewriter->create<mlir::ReturnOp>(loc, load); |
| ret.erase(); |
| }); |
| } break; |
| case FixupTy::Codes::Split: { |
| // The FIR argument has been split into a pair of distinct arguments |
| // that are in juxtaposition to each other. (For COMPLEX value.) |
| auto newArg = |
| func.front().insertArgument(fixup.index, newInTys[fixup.index]); |
| if (fixup.second == 1) { |
| rewriter->setInsertionPointToStart(&func.front()); |
| auto cplxTy = oldArgTys[fixup.index - offset - fixup.second]; |
| auto undef = rewriter->create<UndefOp>(loc, cplxTy); |
| auto zero = rewriter->getIntegerAttr(rewriter->getIndexType(), 0); |
| auto one = rewriter->getIntegerAttr(rewriter->getIndexType(), 1); |
| auto cplx1 = rewriter->create<InsertValueOp>( |
| loc, cplxTy, undef, func.front().getArgument(fixup.index - 1), |
| rewriter->getArrayAttr(zero)); |
| auto cplx = rewriter->create<InsertValueOp>( |
| loc, cplxTy, cplx1, newArg, rewriter->getArrayAttr(one)); |
| func.getArgument(fixup.index + 1).replaceAllUsesWith(cplx); |
| func.front().eraseArgument(fixup.index + 1); |
| offset++; |
| } |
| } break; |
| case FixupTy::Codes::Trailing: { |
| // The FIR argument has been split into a pair of distinct arguments. |
| // The first part of the pair appears in the original argument |
| // position. The second part of the pair is appended after all the |
| // original arguments. (Boxchar arguments.) |
| auto newBufArg = |
| func.front().insertArgument(fixup.index, newInTys[fixup.index]); |
| auto newLenArg = func.front().addArgument(trailingTys[fixup.second]); |
| auto boxTy = oldArgTys[fixup.index - offset]; |
| rewriter->setInsertionPointToStart(&func.front()); |
| auto box = |
| rewriter->create<EmboxCharOp>(loc, boxTy, newBufArg, newLenArg); |
| func.getArgument(fixup.index + 1).replaceAllUsesWith(box); |
| func.front().eraseArgument(fixup.index + 1); |
| } break; |
| } |
| } |
| } |
| |
| // Set the new type and finalize the arguments, etc. |
| newInTys.insert(newInTys.end(), trailingTys.begin(), trailingTys.end()); |
| auto newFuncTy = |
| mlir::FunctionType::get(func.getContext(), newInTys, newResTys); |
| LLVM_DEBUG(llvm::dbgs() << "new func: " << newFuncTy << '\n'); |
| func.setType(newFuncTy); |
| |
| for (auto &fixup : fixups) |
| if (fixup.finalizer) |
| (*fixup.finalizer)(func); |
| } |
| |
| inline bool functionArgIsSRet(unsigned index, mlir::FuncOp func) { |
| if (auto attr = func.getArgAttrOfType<mlir::UnitAttr>(index, "llvm.sret")) |
| return true; |
| return false; |
| } |
| |
| /// Convert a complex return value. This can involve converting the return |
| /// value to a "hidden" first argument or packing the complex into a wide |
| /// GPR. |
| template <typename A, typename B, typename C> |
| void doComplexReturn(mlir::FuncOp func, A cmplx, B &newResTys, B &newInTys, |
| C &fixups) { |
| if (noComplexConversion) { |
| newResTys.push_back(cmplx); |
| return; |
| } |
| auto m = specifics->complexReturnType(cmplx.getElementType()); |
| assert(m.size() == 1); |
| auto &tup = m[0]; |
| auto attr = std::get<CodeGenSpecifics::Attributes>(tup); |
| auto argTy = std::get<mlir::Type>(tup); |
| if (attr.isSRet()) { |
| unsigned argNo = newInTys.size(); |
| fixups.emplace_back( |
| FixupTy::Codes::ReturnAsStore, argNo, [=](mlir::FuncOp func) { |
| func.setArgAttr(argNo, "llvm.sret", rewriter->getUnitAttr()); |
| }); |
| newInTys.push_back(argTy); |
| return; |
| } |
| fixups.emplace_back(FixupTy::Codes::ReturnType, newResTys.size()); |
| newResTys.push_back(argTy); |
| } |
| |
| /// Convert a complex argument value. This can involve storing the value to |
| /// a temporary memory location or factoring the value into two distinct |
| /// arguments. |
| template <typename A, typename B, typename C> |
| void doComplexArg(mlir::FuncOp func, A cmplx, B &newInTys, C &fixups) { |
| if (noComplexConversion) { |
| newInTys.push_back(cmplx); |
| return; |
| } |
| auto m = specifics->complexArgumentType(cmplx.getElementType()); |
| const auto fixupCode = |
| m.size() > 1 ? FixupTy::Codes::Split : FixupTy::Codes::ArgumentType; |
| for (auto e : llvm::enumerate(m)) { |
| auto &tup = e.value(); |
| auto index = e.index(); |
| auto attr = std::get<CodeGenSpecifics::Attributes>(tup); |
| auto argTy = std::get<mlir::Type>(tup); |
| auto argNo = newInTys.size(); |
| if (attr.isByVal()) { |
| if (auto align = attr.getAlignment()) |
| fixups.emplace_back( |
| FixupTy::Codes::ArgumentAsLoad, argNo, [=](mlir::FuncOp func) { |
| func.setArgAttr(argNo, "llvm.byval", rewriter->getUnitAttr()); |
| func.setArgAttr(argNo, "llvm.align", |
| rewriter->getIntegerAttr( |
| rewriter->getIntegerType(32), align)); |
| }); |
| else |
| fixups.emplace_back(FixupTy::Codes::ArgumentAsLoad, newInTys.size(), |
| [=](mlir::FuncOp func) { |
| func.setArgAttr(argNo, "llvm.byval", |
| rewriter->getUnitAttr()); |
| }); |
| } else { |
| if (auto align = attr.getAlignment()) |
| fixups.emplace_back(fixupCode, argNo, index, [=](mlir::FuncOp func) { |
| func.setArgAttr( |
| argNo, "llvm.align", |
| rewriter->getIntegerAttr(rewriter->getIntegerType(32), align)); |
| }); |
| else |
| fixups.emplace_back(fixupCode, argNo, index); |
| } |
| newInTys.push_back(argTy); |
| } |
| } |
| |
| private: |
| // Replace `op` and remove it. |
| void replaceOp(mlir::Operation *op, mlir::ValueRange newValues) { |
| op->replaceAllUsesWith(newValues); |
| op->dropAllReferences(); |
| op->erase(); |
| } |
| |
| inline void setMembers(CodeGenSpecifics *s, mlir::OpBuilder *r) { |
| specifics = s; |
| rewriter = r; |
| } |
| |
| inline void clearMembers() { setMembers(nullptr, nullptr); } |
| |
| CodeGenSpecifics *specifics{}; |
| mlir::OpBuilder *rewriter; |
| }; // namespace |
| } // namespace |
| |
| std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> |
| fir::createFirTargetRewritePass(const TargetRewriteOptions &options) { |
| return std::make_unique<TargetRewrite>(options); |
| } |