| //===- LLVMToLLVMIRTranslation.cpp - Translate LLVM dialect to LLVM IR ----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements a translation between the MLIR LLVM dialect and LLVM IR. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Target/LLVMIR/ModuleTranslation.h" |
| |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/InlineAsm.h" |
| #include "llvm/IR/MDBuilder.h" |
| #include "llvm/IR/MatrixBuilder.h" |
| #include "llvm/IR/Operator.h" |
| |
| using namespace mlir; |
| using namespace mlir::LLVM; |
| using mlir::LLVM::detail::getLLVMConstant; |
| |
| #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc" |
| |
| static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) { |
| using llvmFMF = llvm::FastMathFlags; |
| using FuncT = void (llvmFMF::*)(bool); |
| const std::pair<FastmathFlags, FuncT> handlers[] = { |
| // clang-format off |
| {FastmathFlags::nnan, &llvmFMF::setNoNaNs}, |
| {FastmathFlags::ninf, &llvmFMF::setNoInfs}, |
| {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros}, |
| {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal}, |
| {FastmathFlags::contract, &llvmFMF::setAllowContract}, |
| {FastmathFlags::afn, &llvmFMF::setApproxFunc}, |
| {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc}, |
| // clang-format on |
| }; |
| llvm::FastMathFlags ret; |
| ::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue(); |
| for (auto it : handlers) |
| if (bitEnumContainsAll(fmfMlir, it.first)) |
| (ret.*(it.second))(true); |
| return ret; |
| } |
| |
| /// Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices. |
| static SmallVector<unsigned> extractPosition(ArrayRef<int64_t> indices) { |
| SmallVector<unsigned> position; |
| llvm::append_range(position, indices); |
| return position; |
| } |
| |
| /// Convert an LLVM type to a string for printing in diagnostics. |
| static std::string diagStr(const llvm::Type *type) { |
| std::string str; |
| llvm::raw_string_ostream os(str); |
| type->print(os); |
| return str; |
| } |
| |
| /// Get the declaration of an overloaded llvm intrinsic. First we get the |
| /// overloaded argument types and/or result type from the CallIntrinsicOp, and |
| /// then use those to get the correct declaration of the overloaded intrinsic. |
| static FailureOr<llvm::Function *> |
| getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id, |
| llvm::Module *module, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| SmallVector<llvm::Type *, 8> allArgTys; |
| for (Type type : op->getOperandTypes()) |
| allArgTys.push_back(moduleTranslation.convertType(type)); |
| |
| llvm::Type *resTy; |
| if (op.getNumResults() == 0) |
| resTy = llvm::Type::getVoidTy(module->getContext()); |
| else |
| resTy = moduleTranslation.convertType(op.getResult(0).getType()); |
| |
| // ATM we do not support variadic intrinsics. |
| llvm::FunctionType *ft = llvm::FunctionType::get(resTy, allArgTys, false); |
| |
| SmallVector<llvm::Intrinsic::IITDescriptor, 8> table; |
| getIntrinsicInfoTableEntries(id, table); |
| ArrayRef<llvm::Intrinsic::IITDescriptor> tableRef = table; |
| |
| SmallVector<llvm::Type *, 8> overloadedArgTys; |
| if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef, |
| overloadedArgTys) != |
| llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) { |
| return mlir::emitError(op.getLoc(), "call intrinsic signature ") |
| << diagStr(ft) << " to overloaded intrinsic " << op.getIntrinAttr() |
| << " does not match any of the overloads"; |
| } |
| |
| ArrayRef<llvm::Type *> overloadedArgTysRef = overloadedArgTys; |
| return llvm::Intrinsic::getOrInsertDeclaration(module, id, |
| overloadedArgTysRef); |
| } |
| |
| static llvm::OperandBundleDef |
| convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| std::vector<llvm::Value *> operands; |
| operands.reserve(bundleOperands.size()); |
| for (Value bundleArg : bundleOperands) |
| operands.push_back(moduleTranslation.lookupValue(bundleArg)); |
| return llvm::OperandBundleDef(bundleTag.str(), std::move(operands)); |
| } |
| |
| static SmallVector<llvm::OperandBundleDef> |
| convertOperandBundles(OperandRangeRange bundleOperands, ArrayAttr bundleTags, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| SmallVector<llvm::OperandBundleDef> bundles; |
| bundles.reserve(bundleOperands.size()); |
| |
| for (auto [operands, tagAttr] : llvm::zip_equal(bundleOperands, bundleTags)) { |
| StringRef tag = cast<StringAttr>(tagAttr).getValue(); |
| bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation)); |
| } |
| return bundles; |
| } |
| |
| static SmallVector<llvm::OperandBundleDef> |
| convertOperandBundles(OperandRangeRange bundleOperands, |
| std::optional<ArrayAttr> bundleTags, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| if (!bundleTags) |
| return {}; |
| return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation); |
| } |
| |
| static LogicalResult |
| convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray, |
| ArrayAttr resAttrsArray, llvm::CallBase *call, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| if (argAttrsArray) { |
| for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) { |
| if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr); |
| !argAttrs.empty()) { |
| FailureOr<llvm::AttrBuilder> attrBuilder = |
| moduleTranslation.convertParameterAttrs(loc, argAttrs); |
| if (failed(attrBuilder)) |
| return failure(); |
| call->addParamAttrs(argIdx, *attrBuilder); |
| } |
| } |
| } |
| |
| if (resAttrsArray && resAttrsArray.size() > 0) { |
| if (resAttrsArray.size() != 1) |
| return mlir::emitError(loc, "llvm.func cannot have multiple results"); |
| if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]); |
| !resAttrs.empty()) { |
| FailureOr<llvm::AttrBuilder> attrBuilder = |
| moduleTranslation.convertParameterAttrs(loc, resAttrs); |
| if (failed(attrBuilder)) |
| return failure(); |
| call->addRetAttrs(*attrBuilder); |
| } |
| } |
| return success(); |
| } |
| |
| static LogicalResult |
| convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| return convertParameterAndResultAttrs( |
| callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call, |
| moduleTranslation); |
| } |
| |
| /// Builder for LLVM_CallIntrinsicOp |
| static LogicalResult |
| convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| llvm::Module *module = builder.GetInsertBlock()->getModule(); |
| llvm::Intrinsic::ID id = |
| llvm::Intrinsic::lookupIntrinsicID(op.getIntrinAttr()); |
| if (!id) |
| return mlir::emitError(op.getLoc(), "could not find LLVM intrinsic: ") |
| << op.getIntrinAttr(); |
| |
| llvm::Function *fn = nullptr; |
| if (llvm::Intrinsic::isOverloaded(id)) { |
| auto fnOrFailure = |
| getOverloadedDeclaration(op, id, module, moduleTranslation); |
| if (failed(fnOrFailure)) |
| return failure(); |
| fn = *fnOrFailure; |
| } else { |
| fn = llvm::Intrinsic::getOrInsertDeclaration(module, id, {}); |
| } |
| |
| // Check the result type of the call. |
| const llvm::Type *intrinType = |
| op.getNumResults() == 0 |
| ? llvm::Type::getVoidTy(module->getContext()) |
| : moduleTranslation.convertType(op.getResultTypes().front()); |
| if (intrinType != fn->getReturnType()) { |
| return mlir::emitError(op.getLoc(), "intrinsic call returns ") |
| << diagStr(intrinType) << " but " << op.getIntrinAttr() |
| << " actually returns " << diagStr(fn->getReturnType()); |
| } |
| |
| // Check the argument types of the call. If the function is variadic, check |
| // the subrange of required arguments. |
| if (!fn->getFunctionType()->isVarArg() && |
| op.getArgs().size() != fn->arg_size()) { |
| return mlir::emitError(op.getLoc(), "intrinsic call has ") |
| << op.getArgs().size() << " operands but " << op.getIntrinAttr() |
| << " expects " << fn->arg_size(); |
| } |
| if (fn->getFunctionType()->isVarArg() && |
| op.getArgs().size() < fn->arg_size()) { |
| return mlir::emitError(op.getLoc(), "intrinsic call has ") |
| << op.getArgs().size() << " operands but variadic " |
| << op.getIntrinAttr() << " expects at least " << fn->arg_size(); |
| } |
| // Check the arguments up to the number the function requires. |
| for (unsigned i = 0, e = fn->arg_size(); i != e; ++i) { |
| const llvm::Type *expected = fn->getArg(i)->getType(); |
| const llvm::Type *actual = |
| moduleTranslation.convertType(op.getOperandTypes()[i]); |
| if (actual != expected) { |
| return mlir::emitError(op.getLoc(), "intrinsic call operand #") |
| << i << " has type " << diagStr(actual) << " but " |
| << op.getIntrinAttr() << " expects " << diagStr(expected); |
| } |
| } |
| |
| FastmathFlagsInterface itf = op; |
| builder.setFastMathFlags(getFastmathFlags(itf)); |
| |
| auto *inst = builder.CreateCall( |
| fn, moduleTranslation.lookupValues(op.getArgs()), |
| convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(), |
| moduleTranslation)); |
| |
| if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(), |
| op.getResAttrsAttr(), inst, |
| moduleTranslation))) |
| return failure(); |
| |
| if (op.getNumResults() == 1) |
| moduleTranslation.mapValue(op->getResults().front()) = inst; |
| return success(); |
| } |
| |
| static void convertLinkerOptionsOp(ArrayAttr options, |
| llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); |
| llvm::LLVMContext &context = llvmModule->getContext(); |
| llvm::NamedMDNode *linkerMDNode = |
| llvmModule->getOrInsertNamedMetadata("llvm.linker.options"); |
| SmallVector<llvm::Metadata *> MDNodes; |
| MDNodes.reserve(options.size()); |
| for (auto s : options.getAsRange<StringAttr>()) { |
| auto *MDNode = llvm::MDString::get(context, s.getValue()); |
| MDNodes.push_back(MDNode); |
| } |
| |
| auto *listMDNode = llvm::MDTuple::get(context, MDNodes); |
| linkerMDNode->addOperand(listMDNode); |
| } |
| |
| static llvm::Metadata * |
| convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr, |
| llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| llvm::LLVMContext &context = builder.getContext(); |
| llvm::MDBuilder mdb(context); |
| SmallVector<llvm::Metadata *> nodes; |
| |
| if (key == LLVMDialect::getModuleFlagKeyCGProfileName()) { |
| for (auto entry : arrayAttr.getAsRange<ModuleFlagCGProfileEntryAttr>()) { |
| llvm::Function *fromFn = |
| moduleTranslation.lookupFunction(entry.getFrom().getValue()); |
| llvm::Function *toFn = |
| moduleTranslation.lookupFunction(entry.getTo().getValue()); |
| llvm::Metadata *vals[] = { |
| llvm::ValueAsMetadata::get(fromFn), llvm::ValueAsMetadata::get(toFn), |
| mdb.createConstant(llvm::ConstantInt::get( |
| llvm::Type::getInt64Ty(context), entry.getCount()))}; |
| nodes.push_back(llvm::MDNode::get(context, vals)); |
| } |
| return llvm::MDTuple::getDistinct(context, nodes); |
| } |
| return nullptr; |
| } |
| |
| static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); |
| for (auto flagAttr : flags.getAsRange<ModuleFlagAttr>()) { |
| llvm::Metadata *valueMetadata = |
| llvm::TypeSwitch<Attribute, llvm::Metadata *>(flagAttr.getValue()) |
| .Case<StringAttr>([&](auto strAttr) { |
| return llvm::MDString::get(builder.getContext(), |
| strAttr.getValue()); |
| }) |
| .Case<IntegerAttr>([&](auto intAttr) { |
| return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get( |
| llvm::Type::getInt32Ty(builder.getContext()), |
| intAttr.getInt())); |
| }) |
| .Case<ArrayAttr>([&](auto arrayAttr) { |
| return convertModuleFlagValue(flagAttr.getKey().getValue(), |
| arrayAttr, builder, |
| moduleTranslation); |
| }) |
| .Default([](auto) { return nullptr; }); |
| |
| assert(valueMetadata && "expected valid metadata"); |
| llvmModule->addModuleFlag( |
| convertModFlagBehaviorToLLVM(flagAttr.getBehavior()), |
| flagAttr.getKey().getValue(), valueMetadata); |
| } |
| } |
| |
| static LogicalResult |
| convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) { |
| |
| llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder); |
| if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst)) |
| builder.setFastMathFlags(getFastmathFlags(fmf)); |
| |
| #include "mlir/Dialect/LLVMIR/LLVMConversions.inc" |
| #include "mlir/Dialect/LLVMIR/LLVMIntrinsicConversions.inc" |
| |
| // Emit function calls. If the "callee" attribute is present, this is a |
| // direct function call and we also need to look up the remapped function |
| // itself. Otherwise, this is an indirect call and the callee is the first |
| // operand, look it up as a normal value. |
| if (auto callOp = dyn_cast<LLVM::CallOp>(opInst)) { |
| auto operands = moduleTranslation.lookupValues(callOp.getCalleeOperands()); |
| SmallVector<llvm::OperandBundleDef> opBundles = |
| convertOperandBundles(callOp.getOpBundleOperands(), |
| callOp.getOpBundleTags(), moduleTranslation); |
| ArrayRef<llvm::Value *> operandsRef(operands); |
| llvm::CallInst *call; |
| if (auto attr = callOp.getCalleeAttr()) { |
| call = |
| builder.CreateCall(moduleTranslation.lookupFunction(attr.getValue()), |
| operandsRef, opBundles); |
| } else { |
| llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>( |
| moduleTranslation.convertType(callOp.getCalleeFunctionType())); |
| call = builder.CreateCall(calleeType, operandsRef.front(), |
| operandsRef.drop_front(), opBundles); |
| } |
| call->setCallingConv(convertCConvToLLVM(callOp.getCConv())); |
| call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind())); |
| if (callOp.getConvergentAttr()) |
| call->addFnAttr(llvm::Attribute::Convergent); |
| if (callOp.getNoUnwindAttr()) |
| call->addFnAttr(llvm::Attribute::NoUnwind); |
| if (callOp.getWillReturnAttr()) |
| call->addFnAttr(llvm::Attribute::WillReturn); |
| if (callOp.getNoInlineAttr()) |
| call->addFnAttr(llvm::Attribute::NoInline); |
| if (callOp.getAlwaysInlineAttr()) |
| call->addFnAttr(llvm::Attribute::AlwaysInline); |
| if (callOp.getInlineHintAttr()) |
| call->addFnAttr(llvm::Attribute::InlineHint); |
| |
| if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation))) |
| return failure(); |
| |
| if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) { |
| llvm::MemoryEffects memEffects = |
| llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem, |
| convertModRefInfoToLLVM(memAttr.getArgMem())) | |
| llvm::MemoryEffects( |
| llvm::MemoryEffects::Location::InaccessibleMem, |
| convertModRefInfoToLLVM(memAttr.getInaccessibleMem())) | |
| llvm::MemoryEffects(llvm::MemoryEffects::Location::Other, |
| convertModRefInfoToLLVM(memAttr.getOther())); |
| call->setMemoryEffects(memEffects); |
| } |
| |
| moduleTranslation.setAccessGroupsMetadata(callOp, call); |
| moduleTranslation.setAliasScopeMetadata(callOp, call); |
| moduleTranslation.setTBAAMetadata(callOp, call); |
| // If the called function has a result, remap the corresponding value. Note |
| // that LLVM IR dialect CallOp has either 0 or 1 result. |
| if (opInst.getNumResults() != 0) |
| moduleTranslation.mapValue(opInst.getResult(0), call); |
| // Check that LLVM call returns void for 0-result functions. |
| else if (!call->getType()->isVoidTy()) |
| return failure(); |
| moduleTranslation.mapCall(callOp, call); |
| return success(); |
| } |
| |
| if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) { |
| // TODO: refactor function type creation which usually occurs in std-LLVM |
| // conversion. |
| SmallVector<Type, 8> operandTypes; |
| llvm::append_range(operandTypes, inlineAsmOp.getOperands().getTypes()); |
| |
| Type resultType; |
| if (inlineAsmOp.getNumResults() == 0) { |
| resultType = LLVM::LLVMVoidType::get(&moduleTranslation.getContext()); |
| } else { |
| assert(inlineAsmOp.getNumResults() == 1); |
| resultType = inlineAsmOp.getResultTypes()[0]; |
| } |
| auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes); |
| llvm::InlineAsm *inlineAsmInst = |
| inlineAsmOp.getAsmDialect() |
| ? llvm::InlineAsm::get( |
| static_cast<llvm::FunctionType *>( |
| moduleTranslation.convertType(ft)), |
| inlineAsmOp.getAsmString(), inlineAsmOp.getConstraints(), |
| inlineAsmOp.getHasSideEffects(), |
| inlineAsmOp.getIsAlignStack(), |
| convertAsmDialectToLLVM(*inlineAsmOp.getAsmDialect())) |
| : llvm::InlineAsm::get(static_cast<llvm::FunctionType *>( |
| moduleTranslation.convertType(ft)), |
| inlineAsmOp.getAsmString(), |
| inlineAsmOp.getConstraints(), |
| inlineAsmOp.getHasSideEffects(), |
| inlineAsmOp.getIsAlignStack()); |
| llvm::CallInst *inst = builder.CreateCall( |
| inlineAsmInst, |
| moduleTranslation.lookupValues(inlineAsmOp.getOperands())); |
| if (auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) { |
| llvm::AttributeList attrList; |
| for (const auto &it : llvm::enumerate(*maybeOperandAttrs)) { |
| Attribute attr = it.value(); |
| if (!attr) |
| continue; |
| DictionaryAttr dAttr = cast<DictionaryAttr>(attr); |
| TypeAttr tAttr = |
| cast<TypeAttr>(dAttr.get(InlineAsmOp::getElementTypeAttrName())); |
| llvm::AttrBuilder b(moduleTranslation.getLLVMContext()); |
| llvm::Type *ty = moduleTranslation.convertType(tAttr.getValue()); |
| b.addTypeAttr(llvm::Attribute::ElementType, ty); |
| // shift to account for the returned value (this is always 1 aggregate |
| // value in LLVM). |
| int shift = (opInst.getNumResults() > 0) ? 1 : 0; |
| attrList = attrList.addAttributesAtIndex( |
| moduleTranslation.getLLVMContext(), it.index() + shift, b); |
| } |
| inst->setAttributes(attrList); |
| } |
| |
| if (opInst.getNumResults() != 0) |
| moduleTranslation.mapValue(opInst.getResult(0), inst); |
| return success(); |
| } |
| |
| if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) { |
| auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands()); |
| SmallVector<llvm::OperandBundleDef> opBundles = |
| convertOperandBundles(invOp.getOpBundleOperands(), |
| invOp.getOpBundleTags(), moduleTranslation); |
| ArrayRef<llvm::Value *> operandsRef(operands); |
| llvm::InvokeInst *result; |
| if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) { |
| result = builder.CreateInvoke( |
| moduleTranslation.lookupFunction(attr.getValue()), |
| moduleTranslation.lookupBlock(invOp.getSuccessor(0)), |
| moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef, |
| opBundles); |
| } else { |
| llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>( |
| moduleTranslation.convertType(invOp.getCalleeFunctionType())); |
| result = builder.CreateInvoke( |
| calleeType, operandsRef.front(), |
| moduleTranslation.lookupBlock(invOp.getSuccessor(0)), |
| moduleTranslation.lookupBlock(invOp.getSuccessor(1)), |
| operandsRef.drop_front(), opBundles); |
| } |
| result->setCallingConv(convertCConvToLLVM(invOp.getCConv())); |
| if (failed( |
| convertParameterAndResultAttrs(invOp, result, moduleTranslation))) |
| return failure(); |
| moduleTranslation.mapBranch(invOp, result); |
| // InvokeOp can only have 0 or 1 result |
| if (invOp->getNumResults() != 0) { |
| moduleTranslation.mapValue(opInst.getResult(0), result); |
| return success(); |
| } |
| return success(result->getType()->isVoidTy()); |
| } |
| |
| if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) { |
| llvm::Type *ty = moduleTranslation.convertType(lpOp.getType()); |
| llvm::LandingPadInst *lpi = |
| builder.CreateLandingPad(ty, lpOp.getNumOperands()); |
| lpi->setCleanup(lpOp.getCleanup()); |
| |
| // Add clauses |
| for (llvm::Value *operand : |
| moduleTranslation.lookupValues(lpOp.getOperands())) { |
| // All operands should be constant - checked by verifier |
| if (auto *constOperand = dyn_cast<llvm::Constant>(operand)) |
| lpi->addClause(constOperand); |
| } |
| moduleTranslation.mapValue(lpOp.getResult(), lpi); |
| return success(); |
| } |
| |
| // Emit branches. We need to look up the remapped blocks and ignore the |
| // block arguments that were transformed into PHI nodes. |
| if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) { |
| llvm::BranchInst *branch = |
| builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor())); |
| moduleTranslation.mapBranch(&opInst, branch); |
| moduleTranslation.setLoopMetadata(&opInst, branch); |
| return success(); |
| } |
| if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) { |
| llvm::BranchInst *branch = builder.CreateCondBr( |
| moduleTranslation.lookupValue(condbrOp.getOperand(0)), |
| moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)), |
| moduleTranslation.lookupBlock(condbrOp.getSuccessor(1))); |
| moduleTranslation.mapBranch(&opInst, branch); |
| moduleTranslation.setLoopMetadata(&opInst, branch); |
| return success(); |
| } |
| if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) { |
| llvm::SwitchInst *switchInst = builder.CreateSwitch( |
| moduleTranslation.lookupValue(switchOp.getValue()), |
| moduleTranslation.lookupBlock(switchOp.getDefaultDestination()), |
| switchOp.getCaseDestinations().size()); |
| |
| // Handle switch with zero cases. |
| if (!switchOp.getCaseValues()) |
| return success(); |
| |
| auto *ty = llvm::cast<llvm::IntegerType>( |
| moduleTranslation.convertType(switchOp.getValue().getType())); |
| for (auto i : |
| llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()), |
| switchOp.getCaseDestinations())) |
| switchInst->addCase( |
| llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()), |
| moduleTranslation.lookupBlock(std::get<1>(i))); |
| |
| moduleTranslation.mapBranch(&opInst, switchInst); |
| return success(); |
| } |
| if (auto indBrOp = dyn_cast<LLVM::IndirectBrOp>(opInst)) { |
| llvm::IndirectBrInst *indBr = builder.CreateIndirectBr( |
| moduleTranslation.lookupValue(indBrOp.getAddr()), |
| indBrOp->getNumSuccessors()); |
| for (auto *succ : indBrOp.getSuccessors()) |
| indBr->addDestination(moduleTranslation.lookupBlock(succ)); |
| moduleTranslation.mapBranch(&opInst, indBr); |
| return success(); |
| } |
| |
| // Emit addressof. We need to look up the global value referenced by the |
| // operation and store it in the MLIR-to-LLVM value mapping. This does not |
| // emit any LLVM instruction. |
| if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) { |
| LLVM::GlobalOp global = |
| addressOfOp.getGlobal(moduleTranslation.symbolTable()); |
| LLVM::LLVMFuncOp function = |
| addressOfOp.getFunction(moduleTranslation.symbolTable()); |
| LLVM::AliasOp alias = addressOfOp.getAlias(moduleTranslation.symbolTable()); |
| |
| // The verifier should not have allowed this. |
| assert((global || function || alias) && |
| "referencing an undefined global, function, or alias"); |
| |
| llvm::Value *llvmValue = nullptr; |
| if (global) |
| llvmValue = moduleTranslation.lookupGlobal(global); |
| else if (alias) |
| llvmValue = moduleTranslation.lookupAlias(alias); |
| else |
| llvmValue = moduleTranslation.lookupFunction(function.getName()); |
| |
| moduleTranslation.mapValue(addressOfOp.getResult(), llvmValue); |
| return success(); |
| } |
| |
| // Emit dso_local_equivalent. We need to look up the global value referenced |
| // by the operation and store it in the MLIR-to-LLVM value mapping. |
| if (auto dsoLocalEquivalentOp = |
| dyn_cast<LLVM::DSOLocalEquivalentOp>(opInst)) { |
| LLVM::LLVMFuncOp function = |
| dsoLocalEquivalentOp.getFunction(moduleTranslation.symbolTable()); |
| LLVM::AliasOp alias = |
| dsoLocalEquivalentOp.getAlias(moduleTranslation.symbolTable()); |
| |
| // The verifier should not have allowed this. |
| assert((function || alias) && |
| "referencing an undefined function, or alias"); |
| |
| llvm::Value *llvmValue = nullptr; |
| if (alias) |
| llvmValue = moduleTranslation.lookupAlias(alias); |
| else |
| llvmValue = moduleTranslation.lookupFunction(function.getName()); |
| |
| moduleTranslation.mapValue( |
| dsoLocalEquivalentOp.getResult(), |
| llvm::DSOLocalEquivalent::get(cast<llvm::GlobalValue>(llvmValue))); |
| return success(); |
| } |
| |
| // Emit blockaddress. We first need to find the LLVM block referenced by this |
| // operation and then create a LLVM block address for it. |
| if (auto blockAddressOp = dyn_cast<LLVM::BlockAddressOp>(opInst)) { |
| // getBlockTagOp() walks a function to search for block labels. Check |
| // whether it's in cache first. |
| BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr(); |
| BlockTagOp blockTagOp = moduleTranslation.lookupBlockTag(blockAddressAttr); |
| if (!blockTagOp) { |
| blockTagOp = blockAddressOp.getBlockTagOp(); |
| moduleTranslation.mapBlockTag(blockAddressAttr, blockTagOp); |
| } |
| |
| llvm::Value *llvmValue = nullptr; |
| StringRef fnName = blockAddressAttr.getFunction().getValue(); |
| if (llvm::BasicBlock *llvmBlock = |
| moduleTranslation.lookupBlock(blockTagOp->getBlock())) { |
| llvm::Function *llvmFn = moduleTranslation.lookupFunction(fnName); |
| llvmValue = llvm::BlockAddress::get(llvmFn, llvmBlock); |
| } else { |
| // The matching LLVM block is not yet emitted, a placeholder is created |
| // in its place. When the LLVM block is emitted later in translation, |
| // the llvmValue is replaced with the actual llvm::BlockAddress. |
| // A GlobalVariable is chosen as placeholder because in general LLVM |
| // constants are uniqued and are not proper for RAUW, since that could |
| // harm unrelated uses of the constant. |
| llvmValue = new llvm::GlobalVariable( |
| *moduleTranslation.getLLVMModule(), |
| llvm::PointerType::getUnqual(moduleTranslation.getLLVMContext()), |
| /*isConstant=*/true, llvm::GlobalValue::LinkageTypes::ExternalLinkage, |
| /*Initializer=*/nullptr, |
| Twine("__mlir_block_address_") |
| .concat(Twine(fnName)) |
| .concat(Twine((uint64_t)blockAddressOp.getOperation()))); |
| moduleTranslation.mapUnresolvedBlockAddress(blockAddressOp, llvmValue); |
| } |
| |
| moduleTranslation.mapValue(blockAddressOp.getResult(), llvmValue); |
| return success(); |
| } |
| |
| // Emit block label. If this label is seen before BlockAddressOp is |
| // translated, go ahead and already map it. |
| if (auto blockTagOp = dyn_cast<LLVM::BlockTagOp>(opInst)) { |
| auto funcOp = blockTagOp->getParentOfType<LLVMFuncOp>(); |
| BlockAddressAttr blockAddressAttr = BlockAddressAttr::get( |
| &moduleTranslation.getContext(), |
| FlatSymbolRefAttr::get(&moduleTranslation.getContext(), |
| funcOp.getName()), |
| blockTagOp.getTag()); |
| moduleTranslation.mapBlockTag(blockAddressAttr, blockTagOp); |
| return success(); |
| } |
| |
| return failure(); |
| } |
| |
| namespace { |
| /// Implementation of the dialect interface that converts operations belonging |
| /// to the LLVM dialect to LLVM IR. |
| class LLVMDialectLLVMIRTranslationInterface |
| : public LLVMTranslationDialectInterface { |
| public: |
| using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; |
| |
| /// Translates the given operation to LLVM IR using the provided IR builder |
| /// and saving the state in `moduleTranslation`. |
| LogicalResult |
| convertOperation(Operation *op, llvm::IRBuilderBase &builder, |
| LLVM::ModuleTranslation &moduleTranslation) const final { |
| return convertOperationImpl(*op, builder, moduleTranslation); |
| } |
| }; |
| } // namespace |
| |
| void mlir::registerLLVMDialectTranslation(DialectRegistry ®istry) { |
| registry.insert<LLVM::LLVMDialect>(); |
| registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { |
| dialect->addInterfaces<LLVMDialectLLVMIRTranslationInterface>(); |
| }); |
| } |
| |
| void mlir::registerLLVMDialectTranslation(MLIRContext &context) { |
| DialectRegistry registry; |
| registerLLVMDialectTranslation(registry); |
| context.appendDialectRegistry(registry); |
| } |