blob: cc8800df0095fa8f041a3c003222dd59e2b66441 [file] [log] [blame]
//===- LLVMToLLVMIRTranslation.cpp - Translate LLVM dialect to LLVM IR ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a translation between the MLIR LLVM dialect and LLVM IR.
//
//===----------------------------------------------------------------------===//
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/MatrixBuilder.h"
#include "llvm/IR/Operator.h"
using namespace mlir;
using namespace mlir::LLVM;
using mlir::LLVM::detail::getLLVMConstant;
#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
/// Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
static llvm::CmpInst::Predicate getLLVMCmpPredicate(ICmpPredicate p) {
switch (p) {
case LLVM::ICmpPredicate::eq:
return llvm::CmpInst::Predicate::ICMP_EQ;
case LLVM::ICmpPredicate::ne:
return llvm::CmpInst::Predicate::ICMP_NE;
case LLVM::ICmpPredicate::slt:
return llvm::CmpInst::Predicate::ICMP_SLT;
case LLVM::ICmpPredicate::sle:
return llvm::CmpInst::Predicate::ICMP_SLE;
case LLVM::ICmpPredicate::sgt:
return llvm::CmpInst::Predicate::ICMP_SGT;
case LLVM::ICmpPredicate::sge:
return llvm::CmpInst::Predicate::ICMP_SGE;
case LLVM::ICmpPredicate::ult:
return llvm::CmpInst::Predicate::ICMP_ULT;
case LLVM::ICmpPredicate::ule:
return llvm::CmpInst::Predicate::ICMP_ULE;
case LLVM::ICmpPredicate::ugt:
return llvm::CmpInst::Predicate::ICMP_UGT;
case LLVM::ICmpPredicate::uge:
return llvm::CmpInst::Predicate::ICMP_UGE;
}
llvm_unreachable("incorrect comparison predicate");
}
static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) {
switch (p) {
case LLVM::FCmpPredicate::_false:
return llvm::CmpInst::Predicate::FCMP_FALSE;
case LLVM::FCmpPredicate::oeq:
return llvm::CmpInst::Predicate::FCMP_OEQ;
case LLVM::FCmpPredicate::ogt:
return llvm::CmpInst::Predicate::FCMP_OGT;
case LLVM::FCmpPredicate::oge:
return llvm::CmpInst::Predicate::FCMP_OGE;
case LLVM::FCmpPredicate::olt:
return llvm::CmpInst::Predicate::FCMP_OLT;
case LLVM::FCmpPredicate::ole:
return llvm::CmpInst::Predicate::FCMP_OLE;
case LLVM::FCmpPredicate::one:
return llvm::CmpInst::Predicate::FCMP_ONE;
case LLVM::FCmpPredicate::ord:
return llvm::CmpInst::Predicate::FCMP_ORD;
case LLVM::FCmpPredicate::ueq:
return llvm::CmpInst::Predicate::FCMP_UEQ;
case LLVM::FCmpPredicate::ugt:
return llvm::CmpInst::Predicate::FCMP_UGT;
case LLVM::FCmpPredicate::uge:
return llvm::CmpInst::Predicate::FCMP_UGE;
case LLVM::FCmpPredicate::ult:
return llvm::CmpInst::Predicate::FCMP_ULT;
case LLVM::FCmpPredicate::ule:
return llvm::CmpInst::Predicate::FCMP_ULE;
case LLVM::FCmpPredicate::une:
return llvm::CmpInst::Predicate::FCMP_UNE;
case LLVM::FCmpPredicate::uno:
return llvm::CmpInst::Predicate::FCMP_UNO;
case LLVM::FCmpPredicate::_true:
return llvm::CmpInst::Predicate::FCMP_TRUE;
}
llvm_unreachable("incorrect comparison predicate");
}
static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) {
switch (op) {
case LLVM::AtomicBinOp::xchg:
return llvm::AtomicRMWInst::BinOp::Xchg;
case LLVM::AtomicBinOp::add:
return llvm::AtomicRMWInst::BinOp::Add;
case LLVM::AtomicBinOp::sub:
return llvm::AtomicRMWInst::BinOp::Sub;
case LLVM::AtomicBinOp::_and:
return llvm::AtomicRMWInst::BinOp::And;
case LLVM::AtomicBinOp::nand:
return llvm::AtomicRMWInst::BinOp::Nand;
case LLVM::AtomicBinOp::_or:
return llvm::AtomicRMWInst::BinOp::Or;
case LLVM::AtomicBinOp::_xor:
return llvm::AtomicRMWInst::BinOp::Xor;
case LLVM::AtomicBinOp::max:
return llvm::AtomicRMWInst::BinOp::Max;
case LLVM::AtomicBinOp::min:
return llvm::AtomicRMWInst::BinOp::Min;
case LLVM::AtomicBinOp::umax:
return llvm::AtomicRMWInst::BinOp::UMax;
case LLVM::AtomicBinOp::umin:
return llvm::AtomicRMWInst::BinOp::UMin;
case LLVM::AtomicBinOp::fadd:
return llvm::AtomicRMWInst::BinOp::FAdd;
case LLVM::AtomicBinOp::fsub:
return llvm::AtomicRMWInst::BinOp::FSub;
}
llvm_unreachable("incorrect atomic binary operator");
}
static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) {
switch (ordering) {
case LLVM::AtomicOrdering::not_atomic:
return llvm::AtomicOrdering::NotAtomic;
case LLVM::AtomicOrdering::unordered:
return llvm::AtomicOrdering::Unordered;
case LLVM::AtomicOrdering::monotonic:
return llvm::AtomicOrdering::Monotonic;
case LLVM::AtomicOrdering::acquire:
return llvm::AtomicOrdering::Acquire;
case LLVM::AtomicOrdering::release:
return llvm::AtomicOrdering::Release;
case LLVM::AtomicOrdering::acq_rel:
return llvm::AtomicOrdering::AcquireRelease;
case LLVM::AtomicOrdering::seq_cst:
return llvm::AtomicOrdering::SequentiallyConsistent;
}
llvm_unreachable("incorrect atomic ordering");
}
static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
using llvmFMF = llvm::FastMathFlags;
using FuncT = void (llvmFMF::*)(bool);
const std::pair<FastmathFlags, FuncT> handlers[] = {
// clang-format off
{FastmathFlags::nnan, &llvmFMF::setNoNaNs},
{FastmathFlags::ninf, &llvmFMF::setNoInfs},
{FastmathFlags::nsz, &llvmFMF::setNoSignedZeros},
{FastmathFlags::arcp, &llvmFMF::setAllowReciprocal},
{FastmathFlags::contract, &llvmFMF::setAllowContract},
{FastmathFlags::afn, &llvmFMF::setApproxFunc},
{FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
{FastmathFlags::fast, &llvmFMF::setFast},
// clang-format on
};
llvm::FastMathFlags ret;
auto fmf = op.getFastmathFlags();
for (auto it : handlers)
if (bitEnumContains(fmf, it.first))
(ret.*(it.second))(true);
return ret;
}
/// Returns an LLVM metadata node corresponding to a loop option. This metadata
/// is attached to an llvm.loop node.
static llvm::MDNode *getLoopOptionMetadata(llvm::LLVMContext &ctx,
LoopOptionCase option,
int64_t value) {
StringRef name;
llvm::Constant *cstValue = nullptr;
switch (option) {
case LoopOptionCase::disable_licm:
name = "llvm.licm.disable";
cstValue = llvm::ConstantInt::getBool(ctx, value);
break;
case LoopOptionCase::disable_unroll:
name = "llvm.loop.unroll.disable";
cstValue = llvm::ConstantInt::getBool(ctx, value);
break;
case LoopOptionCase::interleave_count:
name = "llvm.loop.interleave.count";
cstValue = llvm::ConstantInt::get(
llvm::IntegerType::get(ctx, /*NumBits=*/32), value);
break;
case LoopOptionCase::disable_pipeline:
name = "llvm.loop.pipeline.disable";
cstValue = llvm::ConstantInt::getBool(ctx, value);
break;
case LoopOptionCase::pipeline_initiation_interval:
name = "llvm.loop.pipeline.initiationinterval";
cstValue = llvm::ConstantInt::get(
llvm::IntegerType::get(ctx, /*NumBits=*/32), value);
break;
}
return llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name),
llvm::ConstantAsMetadata::get(cstValue)});
}
static void setLoopMetadata(Operation &opInst, llvm::Instruction &llvmInst,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
if (Attribute attr = opInst.getAttr(LLVMDialect::getLoopAttrName())) {
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::MDNode *loopMD = moduleTranslation.lookupLoopOptionsMetadata(attr);
if (!loopMD) {
llvm::LLVMContext &ctx = module->getContext();
SmallVector<llvm::Metadata *> loopOptions;
// Reserve operand 0 for loop id self reference.
auto dummy = llvm::MDNode::getTemporary(ctx, llvm::None);
loopOptions.push_back(dummy.get());
auto loopAttr = attr.cast<DictionaryAttr>();
auto parallelAccessGroup =
loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName());
if (parallelAccessGroup.hasValue()) {
SmallVector<llvm::Metadata *> parallelAccess;
parallelAccess.push_back(
llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
for (SymbolRefAttr accessGroupRef : parallelAccessGroup->getValue()
.cast<ArrayAttr>()
.getAsRange<SymbolRefAttr>())
parallelAccess.push_back(
moduleTranslation.getAccessGroup(opInst, accessGroupRef));
loopOptions.push_back(llvm::MDNode::get(ctx, parallelAccess));
}
if (auto loopOptionsAttr = loopAttr.getAs<LoopOptionsAttr>(
LLVMDialect::getLoopOptionsAttrName())) {
for (auto option : loopOptionsAttr.getOptions())
loopOptions.push_back(
getLoopOptionMetadata(ctx, option.first, option.second));
}
// Create loop options and set the first operand to itself.
loopMD = llvm::MDNode::get(ctx, loopOptions);
loopMD->replaceOperandWith(0, loopMD);
// Store a map from this Attribute to the LLVM metadata in case we
// encounter it again.
moduleTranslation.mapLoopOptionsMetadata(attr, loopMD);
}
llvmInst.setMetadata(module->getMDKindID("llvm.loop"), loopMD);
}
}
static LogicalResult
convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
auto extractPosition = [](ArrayAttr attr) {
SmallVector<unsigned, 4> position;
position.reserve(attr.size());
for (Attribute v : attr)
position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue());
return position;
};
llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
builder.setFastMathFlags(getFastmathFlags(fmf));
#include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
// Emit function calls. If the "callee" attribute is present, this is a
// direct function call and we also need to look up the remapped function
// itself. Otherwise, this is an indirect call and the callee is the first
// operand, look it up as a normal value. Return the llvm::Value representing
// the function result, which may be of llvm::VoidTy type.
auto convertCall = [&](Operation &op) -> llvm::Value * {
auto operands = moduleTranslation.lookupValues(op.getOperands());
ArrayRef<llvm::Value *> operandsRef(operands);
if (auto attr = op.getAttrOfType<FlatSymbolRefAttr>("callee"))
return builder.CreateCall(
moduleTranslation.lookupFunction(attr.getValue()), operandsRef);
auto *calleePtrType =
cast<llvm::PointerType>(operandsRef.front()->getType());
auto *calleeType =
cast<llvm::FunctionType>(calleePtrType->getElementType());
return builder.CreateCall(calleeType, operandsRef.front(),
operandsRef.drop_front());
};
// Emit calls. If the called function has a result, remap the corresponding
// value. Note that LLVM IR dialect CallOp has either 0 or 1 result.
if (isa<LLVM::CallOp>(opInst)) {
llvm::Value *result = convertCall(opInst);
if (opInst.getNumResults() != 0) {
moduleTranslation.mapValue(opInst.getResult(0), result);
return success();
}
// Check that LLVM call returns void for 0-result functions.
return success(result->getType()->isVoidTy());
}
if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
// TODO: refactor function type creation which usually occurs in std-LLVM
// conversion.
SmallVector<Type, 8> operandTypes;
operandTypes.reserve(inlineAsmOp.getOperands().size());
for (auto t : inlineAsmOp.getOperands().getTypes())
operandTypes.push_back(t);
Type resultType;
if (inlineAsmOp.getNumResults() == 0) {
resultType = LLVM::LLVMVoidType::get(&moduleTranslation.getContext());
} else {
assert(inlineAsmOp.getNumResults() == 1);
resultType = inlineAsmOp.getResultTypes()[0];
}
auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
llvm::InlineAsm *inlineAsmInst =
inlineAsmOp.getAsmDialect().hasValue()
? llvm::InlineAsm::get(
static_cast<llvm::FunctionType *>(
moduleTranslation.convertType(ft)),
inlineAsmOp.getAsmString(), inlineAsmOp.getConstraints(),
inlineAsmOp.getHasSideEffects(),
inlineAsmOp.getIsAlignStack(),
convertAsmDialectToLLVM(*inlineAsmOp.getAsmDialect()))
: llvm::InlineAsm::get(static_cast<llvm::FunctionType *>(
moduleTranslation.convertType(ft)),
inlineAsmOp.getAsmString(),
inlineAsmOp.getConstraints(),
inlineAsmOp.getHasSideEffects(),
inlineAsmOp.getIsAlignStack());
llvm::Value *result = builder.CreateCall(
inlineAsmInst,
moduleTranslation.lookupValues(inlineAsmOp.getOperands()));
if (opInst.getNumResults() != 0)
moduleTranslation.mapValue(opInst.getResult(0), result);
return success();
}
if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
ArrayRef<llvm::Value *> operandsRef(operands);
if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
builder.CreateInvoke(moduleTranslation.lookupFunction(attr.getValue()),
moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
operandsRef);
} else {
auto *calleePtrType =
cast<llvm::PointerType>(operandsRef.front()->getType());
auto *calleeType =
cast<llvm::FunctionType>(calleePtrType->getElementType());
builder.CreateInvoke(calleeType, operandsRef.front(),
moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
operandsRef.drop_front());
}
return success();
}
if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
llvm::Type *ty = moduleTranslation.convertType(lpOp.getType());
llvm::LandingPadInst *lpi =
builder.CreateLandingPad(ty, lpOp.getNumOperands());
// Add clauses
for (llvm::Value *operand :
moduleTranslation.lookupValues(lpOp.getOperands())) {
// All operands should be constant - checked by verifier
if (auto *constOperand = dyn_cast<llvm::Constant>(operand))
lpi->addClause(constOperand);
}
moduleTranslation.mapValue(lpOp.getResult(), lpi);
return success();
}
// Emit branches. We need to look up the remapped blocks and ignore the block
// arguments that were transformed into PHI nodes.
if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
llvm::BranchInst *branch =
builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor()));
moduleTranslation.mapBranch(&opInst, branch);
setLoopMetadata(opInst, *branch, builder, moduleTranslation);
return success();
}
if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
llvm::MDNode *branchWeights = nullptr;
if (auto weights = condbrOp.getBranchWeights()) {
// Map weight attributes to LLVM metadata.
auto weightValues = weights->getValues<APInt>();
auto trueWeight = weightValues[0].getSExtValue();
auto falseWeight = weightValues[1].getSExtValue();
branchWeights =
llvm::MDBuilder(moduleTranslation.getLLVMContext())
.createBranchWeights(static_cast<uint32_t>(trueWeight),
static_cast<uint32_t>(falseWeight));
}
llvm::BranchInst *branch = builder.CreateCondBr(
moduleTranslation.lookupValue(condbrOp.getOperand(0)),
moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)), branchWeights);
moduleTranslation.mapBranch(&opInst, branch);
setLoopMetadata(opInst, *branch, builder, moduleTranslation);
return success();
}
if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
llvm::MDNode *branchWeights = nullptr;
if (auto weights = switchOp.getBranchWeights()) {
llvm::SmallVector<uint32_t> weightValues;
weightValues.reserve(weights->size());
for (llvm::APInt weight : weights->cast<DenseIntElementsAttr>())
weightValues.push_back(weight.getLimitedValue());
branchWeights = llvm::MDBuilder(moduleTranslation.getLLVMContext())
.createBranchWeights(weightValues);
}
llvm::SwitchInst *switchInst = builder.CreateSwitch(
moduleTranslation.lookupValue(switchOp.getValue()),
moduleTranslation.lookupBlock(switchOp.getDefaultDestination()),
switchOp.getCaseDestinations().size(), branchWeights);
auto *ty = llvm::cast<llvm::IntegerType>(
moduleTranslation.convertType(switchOp.getValue().getType()));
for (auto i :
llvm::zip(switchOp.getCaseValues()->cast<DenseIntElementsAttr>(),
switchOp.getCaseDestinations()))
switchInst->addCase(
llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
moduleTranslation.lookupBlock(std::get<1>(i)));
moduleTranslation.mapBranch(&opInst, switchInst);
return success();
}
// Emit addressof. We need to look up the global value referenced by the
// operation and store it in the MLIR-to-LLVM value mapping. This does not
// emit any LLVM instruction.
if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
LLVM::GlobalOp global = addressOfOp.getGlobal();
LLVM::LLVMFuncOp function = addressOfOp.getFunction();
// The verifier should not have allowed this.
assert((global || function) &&
"referencing an undefined global or function");
moduleTranslation.mapValue(
addressOfOp.getResult(),
global ? moduleTranslation.lookupGlobal(global)
: moduleTranslation.lookupFunction(function.getName()));
return success();
}
return failure();
}
namespace {
/// Implementation of the dialect interface that converts operations belonging
/// to the LLVM dialect to LLVM IR.
class LLVMDialectLLVMIRTranslationInterface
: public LLVMTranslationDialectInterface {
public:
using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
/// Translates the given operation to LLVM IR using the provided IR builder
/// and saving the state in `moduleTranslation`.
LogicalResult
convertOperation(Operation *op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) const final {
return convertOperationImpl(*op, builder, moduleTranslation);
}
};
} // end namespace
void mlir::registerLLVMDialectTranslation(DialectRegistry &registry) {
registry.insert<LLVM::LLVMDialect>();
registry.addDialectInterface<LLVM::LLVMDialect,
LLVMDialectLLVMIRTranslationInterface>();
}
void mlir::registerLLVMDialectTranslation(MLIRContext &context) {
DialectRegistry registry;
registerLLVMDialectTranslation(registry);
context.appendDialectRegistry(registry);
}