| //====- LowerToLLVM.cpp - Lowering from CIR to LLVMIR ---------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements lowering of CIR operations to LLVMIR. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "LowerToLLVM.h" |
| |
| #include <deque> |
| #include <optional> |
| |
| #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| #include "mlir/Dialect/DLTI/DLTI.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinDialect.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/Types.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Pass/PassManager.h" |
| #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" |
| #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" |
| #include "mlir/Target/LLVMIR/Export.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "clang/CIR/Dialect/IR/CIRAttrs.h" |
| #include "clang/CIR/Dialect/IR/CIRDialect.h" |
| #include "clang/CIR/Dialect/Passes.h" |
| #include "clang/CIR/LoweringHelpers.h" |
| #include "clang/CIR/MissingFeatures.h" |
| #include "clang/CIR/Passes.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/IR/Module.h" |
| #include "llvm/Support/ErrorHandling.h" |
| #include "llvm/Support/TimeProfiler.h" |
| |
| using namespace cir; |
| using namespace llvm; |
| |
| namespace cir { |
| namespace direct { |
| |
| //===----------------------------------------------------------------------===// |
| // Helper Methods |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// If the given type is a vector type, return the vector's element type. |
| /// Otherwise return the given type unchanged. |
| // TODO(cir): Return the vector element type instead of the identity type once |
| // we have support for vectors. |
| mlir::Type elementTypeIfVector(mlir::Type type) { |
| assert(!cir::MissingFeatures::vectorType()); |
| return type; |
| } |
| } // namespace |
| |
| /// Given a type converter and a data layout, convert the given type to a type |
| /// that is suitable for memory operations. For example, this can be used to |
| /// lower cir.bool accesses to i8. |
| static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter, |
| mlir::DataLayout const &dataLayout, |
| mlir::Type type) { |
| // TODO(cir): Handle other types similarly to clang's codegen |
| // convertTypeForMemory |
| if (isa<cir::BoolType>(type)) { |
| return mlir::IntegerType::get(type.getContext(), |
| dataLayout.getTypeSizeInBits(type)); |
| } |
| |
| return converter.convertType(type); |
| } |
| |
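| /// Create an integer cast from `src` to `dstTy`: sign- or zero-extend when |
| /// widening (per `isSigned`), truncate when narrowing, and bitcast when the |
| /// widths already match. |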
| static mlir::Value createIntCast(mlir::OpBuilder &bld, mlir::Value src, |
| mlir::IntegerType dstTy, |
| bool isSigned = false) { |
| mlir::Type srcTy = src.getType(); |
| assert(mlir::isa<mlir::IntegerType>(srcTy)); |
| |
| unsigned srcWidth = mlir::cast<mlir::IntegerType>(srcTy).getWidth(); |
| unsigned dstWidth = mlir::cast<mlir::IntegerType>(dstTy).getWidth(); |
| mlir::Location loc = src.getLoc(); |
| |
| if (dstWidth > srcWidth && isSigned) |
| return bld.create<mlir::LLVM::SExtOp>(loc, dstTy, src); |
| if (dstWidth > srcWidth) |
| return bld.create<mlir::LLVM::ZExtOp>(loc, dstTy, src); |
| if (dstWidth < srcWidth) |
| return bld.create<mlir::LLVM::TruncOp>(loc, dstTy, src); |
| return bld.create<mlir::LLVM::BitcastOp>(loc, dstTy, src); |
| } |
| |
| /// Emits the value from memory as expected by its users. Should be called when |
| /// the memory representation of a CIR type is not equal to its scalar |
| /// representation. |
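| /// For example, a cir.bool is stored in memory as an integer of the width |
| /// given by the data layout (typically i8) but used as an i1 in computations, |
| /// so loads of bool values are converted back to i1 here. |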
| static mlir::Value emitFromMemory(mlir::ConversionPatternRewriter &rewriter, |
| mlir::DataLayout const &dataLayout, |
| cir::LoadOp op, mlir::Value value) { |
| |
| // TODO(cir): Handle other types similarly to clang's codegen EmitFromMemory |
| if (auto boolTy = mlir::dyn_cast<cir::BoolType>(op.getResult().getType())) { |
| // Cast the loaded value from its data-layout-sized integer type down to i1. |
| assert(value.getType().isInteger(dataLayout.getTypeSizeInBits(boolTy))); |
| return createIntCast(rewriter, value, rewriter.getI1Type()); |
| } |
| |
| return value; |
| } |
| |
| /// Emits a value to memory with the expected scalar type. Should be called when |
| /// the memory representation of a CIR type is not equal to its scalar |
| /// representation. |
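| /// For example, an i1 cir.bool value is zero-extended to its data-layout |
| /// memory width (typically i8) before being stored. |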
| static mlir::Value emitToMemory(mlir::ConversionPatternRewriter &rewriter, |
| mlir::DataLayout const &dataLayout, |
| mlir::Type origType, mlir::Value value) { |
| |
| // TODO(cir): Handle other types similarly to clang's codegen EmitToMemory |
| if (auto boolTy = mlir::dyn_cast<cir::BoolType>(origType)) { |
| // Zero-extend the i1 value to its memory width (typically i8). |
| mlir::IntegerType memType = |
| rewriter.getIntegerType(dataLayout.getTypeSizeInBits(boolTy)); |
| return createIntCast(rewriter, value, memType); |
| } |
| |
| return value; |
| } |
| |
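| /// Map a CIR global linkage kind to the corresponding LLVM dialect linkage. |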
| mlir::LLVM::Linkage convertLinkage(cir::GlobalLinkageKind linkage) { |
| using CIR = cir::GlobalLinkageKind; |
| using LLVM = mlir::LLVM::Linkage; |
| |
| switch (linkage) { |
| case CIR::AvailableExternallyLinkage: |
| return LLVM::AvailableExternally; |
| case CIR::CommonLinkage: |
| return LLVM::Common; |
| case CIR::ExternalLinkage: |
| return LLVM::External; |
| case CIR::ExternalWeakLinkage: |
| return LLVM::ExternWeak; |
| case CIR::InternalLinkage: |
| return LLVM::Internal; |
| case CIR::LinkOnceAnyLinkage: |
| return LLVM::Linkonce; |
| case CIR::LinkOnceODRLinkage: |
| return LLVM::LinkonceODR; |
| case CIR::PrivateLinkage: |
| return LLVM::Private; |
| case CIR::WeakAnyLinkage: |
| return LLVM::Weak; |
| case CIR::WeakODRLinkage: |
| return LLVM::WeakODR; |
| } |
| llvm_unreachable("Unknown CIR linkage type"); |
| } |
| |
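| /// Cast an LLVM integer value between the given CIR widths: return the value |
| /// unchanged if the widths match, zero- or sign-extend (per `isUnsigned`) when |
| /// widening, and truncate when narrowing. |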
| static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter, |
| mlir::Value llvmSrc, mlir::Type llvmDstIntTy, |
| bool isUnsigned, uint64_t cirSrcWidth, |
| uint64_t cirDstIntWidth) { |
| if (cirSrcWidth == cirDstIntWidth) |
| return llvmSrc; |
| |
| auto loc = llvmSrc.getLoc(); |
| if (cirSrcWidth < cirDstIntWidth) { |
| if (isUnsigned) |
| return rewriter.create<mlir::LLVM::ZExtOp>(loc, llvmDstIntTy, llvmSrc); |
| return rewriter.create<mlir::LLVM::SExtOp>(loc, llvmDstIntTy, llvmSrc); |
| } |
| |
| // Otherwise truncate |
| return rewriter.create<mlir::LLVM::TruncOp>(loc, llvmDstIntTy, llvmSrc); |
| } |
| |
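| /// Visitor that lowers CIR constant attributes to LLVM dialect values, |
| /// materializing the constants at the location of `parentOp`. |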
| class CIRAttrToValue { |
| public: |
| CIRAttrToValue(mlir::Operation *parentOp, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) |
| : parentOp(parentOp), rewriter(rewriter), converter(converter) {} |
| |
| mlir::Value visit(mlir::Attribute attr) { |
| return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr) |
| .Case<cir::IntAttr, cir::FPAttr, cir::ConstArrayAttr, |
| cir::ConstVectorAttr, cir::ConstPtrAttr, cir::ZeroAttr>( |
| [&](auto attrT) { return visitCirAttr(attrT); }) |
| .Default([&](auto attrT) { return mlir::Value(); }); |
| } |
| |
| mlir::Value visitCirAttr(cir::IntAttr intAttr); |
| mlir::Value visitCirAttr(cir::FPAttr fltAttr); |
| mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr); |
| mlir::Value visitCirAttr(cir::ConstArrayAttr attr); |
| mlir::Value visitCirAttr(cir::ConstVectorAttr attr); |
| mlir::Value visitCirAttr(cir::ZeroAttr attr); |
| |
| private: |
| mlir::Operation *parentOp; |
| mlir::ConversionPatternRewriter &rewriter; |
| const mlir::TypeConverter *converter; |
| }; |
| |
| /// Switches on the type of attribute and calls the appropriate conversion. |
| mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, |
| const mlir::Attribute attr, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| CIRAttrToValue valueConverter(parentOp, rewriter, converter); |
| mlir::Value value = valueConverter.visit(attr); |
| if (!value) |
| llvm_unreachable("unhandled attribute type"); |
| return value; |
| } |
| |
| /// IntAttr visitor. |
| mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) { |
| mlir::Location loc = parentOp->getLoc(); |
| return rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, converter->convertType(intAttr.getType()), intAttr.getValue()); |
| } |
| |
| /// ConstPtrAttr visitor. |
| mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) { |
| mlir::Location loc = parentOp->getLoc(); |
| if (ptrAttr.isNullValue()) { |
| return rewriter.create<mlir::LLVM::ZeroOp>( |
| loc, converter->convertType(ptrAttr.getType())); |
| } |
| mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>()); |
| mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())), |
| ptrAttr.getValue().getInt()); |
| return rewriter.create<mlir::LLVM::IntToPtrOp>( |
| loc, converter->convertType(ptrAttr.getType()), ptrVal); |
| } |
| |
| /// FPAttr visitor. |
| mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) { |
| mlir::Location loc = parentOp->getLoc(); |
| return rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); |
| } |
| |
| /// ConstArrayAttr visitor. |
| mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) { |
| mlir::Type llvmTy = converter->convertType(attr.getType()); |
| mlir::Location loc = parentOp->getLoc(); |
| mlir::Value result; |
| |
| if (attr.hasTrailingZeros()) { |
| mlir::Type arrayTy = attr.getType(); |
| result = rewriter.create<mlir::LLVM::ZeroOp>( |
| loc, converter->convertType(arrayTy)); |
| } else { |
| result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy); |
| } |
| |
| // Iteratively lower each constant element of the array. |
| if (auto arrayAttr = mlir::dyn_cast<mlir::ArrayAttr>(attr.getElts())) { |
| for (auto [idx, elt] : llvm::enumerate(arrayAttr)) { |
| mlir::DataLayout dataLayout(parentOp->getParentOfType<mlir::ModuleOp>()); |
| mlir::Value init = visit(elt); |
| result = |
| rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx); |
| } |
| } else { |
| llvm_unreachable("unexpected ConstArrayAttr elements"); |
| } |
| |
| return result; |
| } |
| |
| /// ConstVectorAttr visitor. |
| mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) { |
| const mlir::Type llvmTy = converter->convertType(attr.getType()); |
| const mlir::Location loc = parentOp->getLoc(); |
| |
| SmallVector<mlir::Attribute> mlirValues; |
| for (const mlir::Attribute elementAttr : attr.getElts()) { |
| mlir::Attribute mlirAttr; |
| if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr)) { |
| mlirAttr = rewriter.getIntegerAttr( |
| converter->convertType(intAttr.getType()), intAttr.getValue()); |
| } else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr)) { |
| mlirAttr = rewriter.getFloatAttr( |
| converter->convertType(floatAttr.getType()), floatAttr.getValue()); |
| } else { |
| llvm_unreachable( |
| "vector constant with an element that is neither an int nor a float"); |
| } |
| mlirValues.push_back(mlirAttr); |
| } |
| |
| return rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmTy, |
| mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy), |
| mlirValues)); |
| } |
| |
| /// ZeroAttr visitor. |
| mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) { |
| mlir::Location loc = parentOp->getLoc(); |
| return rewriter.create<mlir::LLVM::ZeroOp>( |
| loc, converter->convertType(attr.getType())); |
| } |
| |
| // This class handles rewriting initializer attributes for types that do not |
| // require region initialization. |
| class GlobalInitAttrRewriter { |
| public: |
| GlobalInitAttrRewriter(mlir::Type type, |
| mlir::ConversionPatternRewriter &rewriter) |
| : llvmType(type), rewriter(rewriter) {} |
| |
| mlir::Attribute visit(mlir::Attribute attr) { |
| return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr) |
| .Case<cir::IntAttr, cir::FPAttr, cir::BoolAttr>( |
| [&](auto attrT) { return visitCirAttr(attrT); }) |
| .Default([&](auto attrT) { return mlir::Attribute(); }); |
| } |
| |
| mlir::Attribute visitCirAttr(cir::IntAttr attr) { |
| return rewriter.getIntegerAttr(llvmType, attr.getValue()); |
| } |
| mlir::Attribute visitCirAttr(cir::FPAttr attr) { |
| return rewriter.getFloatAttr(llvmType, attr.getValue()); |
| } |
| mlir::Attribute visitCirAttr(cir::BoolAttr attr) { |
| return rewriter.getBoolAttr(attr.getValue()); |
| } |
| |
| private: |
| mlir::Type llvmType; |
| mlir::ConversionPatternRewriter &rewriter; |
| }; |
| |
| // This pass requires the CIR to be in a "flat" state. All blocks in each |
| // function must belong to the parent region. Once scopes and control flow |
| // are implemented in CIR, a pass will be run before this one to flatten |
| // the CIR and get it into the state that this pass requires. |
| struct ConvertCIRToLLVMPass |
| : public mlir::PassWrapper<ConvertCIRToLLVMPass, |
| mlir::OperationPass<mlir::ModuleOp>> { |
| void getDependentDialects(mlir::DialectRegistry ®istry) const override { |
| registry.insert<mlir::BuiltinDialect, mlir::DLTIDialect, |
| mlir::LLVM::LLVMDialect, mlir::func::FuncDialect>(); |
| } |
| void runOnOperation() final; |
| |
| void processCIRAttrs(mlir::ModuleOp module); |
| |
| StringRef getDescription() const override { |
| return "Convert the prepared CIR dialect module to LLVM dialect"; |
| } |
| |
| StringRef getArgument() const override { return "cir-flat-to-llvm"; } |
| }; |
| |
| mlir::LogicalResult CIRToLLVMBrCondOpLowering::matchAndRewrite( |
| cir::BrCondOp brOp, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| // When ZExtOp is implemented, we'll need to check if the condition is a |
| // ZExtOp and if so, delete it if it has a single use. |
| assert(!cir::MissingFeatures::zextOp()); |
| |
| mlir::Value i1Condition = adaptor.getCond(); |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::CondBrOp>( |
| brOp, i1Condition, brOp.getDestTrue(), adaptor.getDestOperandsTrue(), |
| brOp.getDestFalse(), adaptor.getDestOperandsFalse()); |
| |
| return mlir::success(); |
| } |
| |
| mlir::Type CIRToLLVMCastOpLowering::convertTy(mlir::Type ty) const { |
| return getTypeConverter()->convertType(ty); |
| } |
| |
| mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite( |
| cir::CastOp castOp, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| // For arithmetic conversions, LLVM IR uses the same instruction to convert |
| // both individual scalars and entire vectors. This lowering pass handles |
| // both situations. |
| |
| switch (castOp.getKind()) { |
| case cir::CastKind::array_to_ptrdecay: { |
| const auto ptrTy = mlir::cast<cir::PointerType>(castOp.getType()); |
| mlir::Value sourceValue = adaptor.getOperands().front(); |
| mlir::Type targetType = convertTy(ptrTy); |
| mlir::Type elementTy = convertTypeForMemory(*getTypeConverter(), dataLayout, |
| ptrTy.getPointee()); |
| llvm::SmallVector<mlir::LLVM::GEPArg> offset{0}; |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( |
| castOp, targetType, elementTy, sourceValue, offset); |
| break; |
| } |
| case cir::CastKind::int_to_bool: { |
| mlir::Value llvmSrcVal = adaptor.getOperands().front(); |
| mlir::Value zeroInt = rewriter.create<mlir::LLVM::ConstantOp>( |
| castOp.getLoc(), llvmSrcVal.getType(), |
| mlir::IntegerAttr::get(llvmSrcVal.getType(), 0)); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( |
| castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroInt); |
| break; |
| } |
| case cir::CastKind::integral: { |
| mlir::Type srcType = castOp.getSrc().getType(); |
| mlir::Type dstType = castOp.getResult().getType(); |
| mlir::Value llvmSrcVal = adaptor.getOperands().front(); |
| mlir::Type llvmDstType = getTypeConverter()->convertType(dstType); |
| cir::IntType srcIntType = |
| mlir::cast<cir::IntType>(elementTypeIfVector(srcType)); |
| cir::IntType dstIntType = |
| mlir::cast<cir::IntType>(elementTypeIfVector(dstType)); |
| rewriter.replaceOp(castOp, getLLVMIntCast(rewriter, llvmSrcVal, llvmDstType, |
| srcIntType.isUnsigned(), |
| srcIntType.getWidth(), |
| dstIntType.getWidth())); |
| break; |
| } |
| case cir::CastKind::floating: { |
| mlir::Value llvmSrcVal = adaptor.getOperands().front(); |
| mlir::Type llvmDstTy = |
| getTypeConverter()->convertType(castOp.getResult().getType()); |
| |
| mlir::Type srcTy = elementTypeIfVector(castOp.getSrc().getType()); |
| mlir::Type dstTy = elementTypeIfVector(castOp.getResult().getType()); |
| |
| if (!mlir::isa<cir::CIRFPTypeInterface>(dstTy) || |
| !mlir::isa<cir::CIRFPTypeInterface>(srcTy)) |
| return castOp.emitError() << "NYI cast from " << srcTy << " to " << dstTy; |
| |
| auto getFloatWidth = [](mlir::Type ty) -> unsigned { |
| return mlir::cast<cir::CIRFPTypeInterface>(ty).getWidth(); |
| }; |
| |
| if (getFloatWidth(srcTy) > getFloatWidth(dstTy)) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FPTruncOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FPExtOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case cir::CastKind::int_to_ptr: { |
| auto dstTy = mlir::cast<cir::PointerType>(castOp.getType()); |
| mlir::Value llvmSrcVal = adaptor.getOperands().front(); |
| mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case cir::CastKind::ptr_to_int: { |
| auto dstTy = mlir::cast<cir::IntType>(castOp.getType()); |
| mlir::Value llvmSrcVal = adaptor.getOperands().front(); |
| mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case cir::CastKind::float_to_bool: { |
| mlir::Value llvmSrcVal = adaptor.getOperands().front(); |
| auto kind = mlir::LLVM::FCmpPredicate::une; |
| |
| // Check if float is not equal to zero. |
| auto zeroFloat = rewriter.create<mlir::LLVM::ConstantOp>( |
| castOp.getLoc(), llvmSrcVal.getType(), |
| mlir::FloatAttr::get(llvmSrcVal.getType(), 0.0)); |
| |
| // Extend comparison result to either bool (C++) or int (C). |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(castOp, kind, llvmSrcVal, |
| zeroFloat); |
| |
| return mlir::success(); |
| } |
| case cir::CastKind::bool_to_int: { |
| auto dstTy = mlir::cast<cir::IntType>(castOp.getType()); |
| mlir::Value llvmSrcVal = adaptor.getOperands().front(); |
| auto llvmSrcTy = mlir::cast<mlir::IntegerType>(llvmSrcVal.getType()); |
| auto llvmDstTy = |
| mlir::cast<mlir::IntegerType>(getTypeConverter()->convertType(dstTy)); |
| if (llvmSrcTy.getWidth() == llvmDstTy.getWidth()) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case cir::CastKind::bool_to_float: { |
| mlir::Type dstTy = castOp.getType(); |
| mlir::Value llvmSrcVal = adaptor.getOperands().front(); |
| mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case cir::CastKind::int_to_float: { |
| mlir::Type dstTy = castOp.getType(); |
| mlir::Value llvmSrcVal = adaptor.getOperands().front(); |
| mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); |
| if (mlir::cast<cir::IntType>(elementTypeIfVector(castOp.getSrc().getType())) |
| .isSigned()) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case cir::CastKind::float_to_int: { |
| mlir::Type dstTy = castOp.getType(); |
| mlir::Value llvmSrcVal = adaptor.getOperands().front(); |
| mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); |
| if (mlir::cast<cir::IntType>( |
| elementTypeIfVector(castOp.getResult().getType())) |
| .isSigned()) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FPToUIOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case cir::CastKind::bitcast: |
| assert(!MissingFeatures::cxxABI()); |
| assert(!MissingFeatures::dataMemberType()); |
| break; |
| case cir::CastKind::ptr_to_bool: { |
| mlir::Value llvmSrcVal = adaptor.getOperands().front(); |
| mlir::Value zeroPtr = rewriter.create<mlir::LLVM::ZeroOp>( |
| castOp.getLoc(), llvmSrcVal.getType()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( |
| castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroPtr); |
| break; |
| } |
| case cir::CastKind::address_space: { |
| mlir::Type dstTy = castOp.getType(); |
| mlir::Value llvmSrcVal = adaptor.getOperands().front(); |
| mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| break; |
| } |
| case cir::CastKind::member_ptr_to_bool: |
| assert(!MissingFeatures::cxxABI()); |
| assert(!MissingFeatures::methodType()); |
| break; |
| default: { |
| return castOp.emitError("Unhandled cast kind: ") |
| << castOp.getKindAttrName(); |
| } |
| } |
| |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite( |
| cir::PtrStrideOp ptrStrideOp, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| |
| const mlir::TypeConverter *tc = getTypeConverter(); |
| const mlir::Type resultTy = tc->convertType(ptrStrideOp.getType()); |
| |
| mlir::Type elementTy = |
| convertTypeForMemory(*tc, dataLayout, ptrStrideOp.getElementTy()); |
| mlir::MLIRContext *ctx = elementTy.getContext(); |
| |
| // void and function types don't really have a layout to use in GEPs, so |
| // use i8 instead. |
| if (mlir::isa<mlir::LLVM::LLVMVoidType>(elementTy) || |
| mlir::isa<mlir::LLVM::LLVMFunctionType>(elementTy)) |
| elementTy = mlir::IntegerType::get(elementTy.getContext(), 8, |
| mlir::IntegerType::Signless); |
| // Zero-extend, sign-extend or truncate the stride to the pointer's index width. |
| mlir::Value index = adaptor.getStride(); |
| const unsigned width = |
| mlir::cast<mlir::IntegerType>(index.getType()).getWidth(); |
| const std::optional<std::uint64_t> layoutWidth = |
| dataLayout.getTypeIndexBitwidth(adaptor.getBase().getType()); |
| |
| mlir::Operation *indexOp = index.getDefiningOp(); |
| if (indexOp && layoutWidth && width != *layoutWidth) { |
| // If the index comes from a subtraction, make sure the extension happens |
| // before it. To achieve that, look at unary minus, which already got |
| // lowered to "sub 0, x". |
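| // For example, for `p - i` with a 32-bit `i` and a 64-bit index width, we |
| // want `sub 0, (ext i to i64)` rather than `ext (sub 0, i) to i64`. |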
| const auto sub = dyn_cast<mlir::LLVM::SubOp>(indexOp); |
| auto unary = dyn_cast_if_present<cir::UnaryOp>( |
| ptrStrideOp.getStride().getDefiningOp()); |
| bool rewriteSub = |
| unary && unary.getKind() == cir::UnaryOpKind::Minus && sub; |
| if (rewriteSub) |
| index = indexOp->getOperand(1); |
| |
| // Handle the cast |
| const auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth); |
| index = getLLVMIntCast(rewriter, index, llvmDstType, |
| ptrStrideOp.getStride().getType().isUnsigned(), |
| width, *layoutWidth); |
| |
| // Re-emit the sub so it applies to the extended/truncated index. |
| if (rewriteSub) { |
| index = rewriter.create<mlir::LLVM::SubOp>( |
| index.getLoc(), index.getType(), |
| rewriter.create<mlir::LLVM::ConstantOp>( |
| index.getLoc(), index.getType(), |
| mlir::IntegerAttr::get(index.getType(), 0)), |
| index); |
| rewriter.eraseOp(sub); |
| } |
| } |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( |
| ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index); |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMAllocaOpLowering::matchAndRewrite( |
| cir::AllocaOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| assert(!cir::MissingFeatures::opAllocaDynAllocSize()); |
| mlir::Value size = rewriter.create<mlir::LLVM::ConstantOp>( |
| op.getLoc(), typeConverter->convertType(rewriter.getIndexType()), |
| rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); |
| mlir::Type elementTy = |
| convertTypeForMemory(*getTypeConverter(), dataLayout, op.getAllocaType()); |
| mlir::Type resultTy = convertTypeForMemory(*getTypeConverter(), dataLayout, |
| op.getResult().getType()); |
| |
| assert(!cir::MissingFeatures::addressSpace()); |
| assert(!cir::MissingFeatures::opAllocaAnnotations()); |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AllocaOp>( |
| op, resultTy, elementTy, size, op.getAlignmentAttr().getInt()); |
| |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMReturnOpLowering::matchAndRewrite( |
| cir::ReturnOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ReturnOp>(op, adaptor.getOperands()); |
| return mlir::LogicalResult::success(); |
| } |
| |
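| /// Shared helper for lowering call-like CIR operations: converts the result |
| /// types, resolves the callee's LLVM function type for direct calls, and |
| /// replaces the operation with an llvm.call. |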
| static mlir::LogicalResult |
| rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter, |
| mlir::FlatSymbolRefAttr calleeAttr) { |
| llvm::SmallVector<mlir::Type, 8> llvmResults; |
| mlir::ValueTypeRange<mlir::ResultRange> cirResults = op->getResultTypes(); |
| |
| if (converter->convertTypes(cirResults, llvmResults).failed()) |
| return mlir::failure(); |
| |
| assert(!cir::MissingFeatures::opCallCallConv()); |
| assert(!cir::MissingFeatures::opCallSideEffect()); |
| |
| mlir::LLVM::LLVMFunctionType llvmFnTy; |
| if (calleeAttr) { // direct call |
| mlir::FunctionOpInterface fn = |
| mlir::SymbolTable::lookupNearestSymbolFrom<mlir::FunctionOpInterface>( |
| op, calleeAttr); |
| assert(fn && "Did not find function for call"); |
| llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>( |
| converter->convertType(fn.getFunctionType())); |
| } else { // indirect call |
| assert(!cir::MissingFeatures::opCallIndirect()); |
| return op->emitError("Indirect calls are NYI"); |
| } |
| |
| assert(!cir::MissingFeatures::opCallLandingPad()); |
| assert(!cir::MissingFeatures::opCallContinueBlock()); |
| assert(!cir::MissingFeatures::opCallCallConv()); |
| assert(!cir::MissingFeatures::opCallSideEffect()); |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(op, llvmFnTy, calleeAttr, |
| callOperands); |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite( |
| cir::CallOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| return rewriteCallOrInvoke(op.getOperation(), adaptor.getOperands(), rewriter, |
| getTypeConverter(), op.getCalleeAttr()); |
| } |
| |
| mlir::LogicalResult CIRToLLVMLoadOpLowering::matchAndRewrite( |
| cir::LoadOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| const mlir::Type llvmTy = convertTypeForMemory( |
| *getTypeConverter(), dataLayout, op.getResult().getType()); |
| assert(!cir::MissingFeatures::opLoadStoreMemOrder()); |
| assert(!cir::MissingFeatures::opLoadStoreAlignment()); |
| unsigned alignment = (unsigned)dataLayout.getTypeABIAlignment(llvmTy); |
| |
| assert(!cir::MissingFeatures::lowerModeOptLevel()); |
| |
| // TODO: nontemporal, syncscope. |
| assert(!cir::MissingFeatures::opLoadStoreVolatile()); |
| mlir::LLVM::LoadOp newLoad = rewriter.create<mlir::LLVM::LoadOp>( |
| op->getLoc(), llvmTy, adaptor.getAddr(), alignment, |
| /*volatile=*/false, /*nontemporal=*/false, |
| /*invariant=*/false, /*invariantGroup=*/false, |
| mlir::LLVM::AtomicOrdering::not_atomic); |
| |
| // Convert adapted result to its original type if needed. |
| mlir::Value result = |
| emitFromMemory(rewriter, dataLayout, op, newLoad.getResult()); |
| rewriter.replaceOp(op, result); |
| assert(!cir::MissingFeatures::opLoadStoreTbaa()); |
| return mlir::LogicalResult::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMStoreOpLowering::matchAndRewrite( |
| cir::StoreOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| assert(!cir::MissingFeatures::opLoadStoreMemOrder()); |
| assert(!cir::MissingFeatures::opLoadStoreAlignment()); |
| const mlir::Type llvmTy = |
| getTypeConverter()->convertType(op.getValue().getType()); |
| unsigned alignment = (unsigned)dataLayout.getTypeABIAlignment(llvmTy); |
| |
| assert(!cir::MissingFeatures::lowerModeOptLevel()); |
| |
| // Convert adapted value to its memory type if needed. |
| mlir::Value value = emitToMemory(rewriter, dataLayout, |
| op.getValue().getType(), adaptor.getValue()); |
| // TODO: nontemporal, syncscope. |
| assert(!cir::MissingFeatures::opLoadStoreVolatile()); |
| mlir::LLVM::StoreOp storeOp = rewriter.create<mlir::LLVM::StoreOp>( |
| op->getLoc(), value, adaptor.getAddr(), alignment, /*volatile=*/false, |
| /*nontemporal=*/false, /*invariantGroup=*/false, |
| mlir::LLVM::AtomicOrdering::not_atomic); |
| rewriter.replaceOp(op, storeOp); |
| assert(!cir::MissingFeatures::opLoadStoreTbaa()); |
| return mlir::LogicalResult::success(); |
| } |
| |
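| /// Recursively check whether a constant array attribute (or any nested |
| /// constant array) has trailing zeros. |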
| bool hasTrailingZeros(cir::ConstArrayAttr attr) { |
| auto array = mlir::dyn_cast<mlir::ArrayAttr>(attr.getElts()); |
| return attr.hasTrailingZeros() || |
| (array && std::any_of(array.begin(), array.end(), [](auto elt) { |
| auto ar = dyn_cast<cir::ConstArrayAttr>(elt); |
| return ar && hasTrailingZeros(ar); |
| })); |
| } |
| |
| mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite( |
| cir::ConstantOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| mlir::Attribute attr = op.getValue(); |
| |
| if (mlir::isa<mlir::IntegerType>(op.getType())) { |
| // Verified cir.const operations cannot actually be of these types, but the |
| // lowering pass may generate temporary cir.const operations with these |
| // types. This is OK since MLIR allows unverified operations to be alive |
| // during a pass as long as they don't live past the end of the pass. |
| attr = op.getValue(); |
| } else if (mlir::isa<cir::BoolType>(op.getType())) { |
| int value = mlir::cast<cir::BoolAttr>(op.getValue()).getValue(); |
| attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()), |
| value); |
| } else if (mlir::isa<cir::IntType>(op.getType())) { |
| assert(!cir::MissingFeatures::opGlobalViewAttr()); |
| |
| attr = rewriter.getIntegerAttr( |
| typeConverter->convertType(op.getType()), |
| mlir::cast<cir::IntAttr>(op.getValue()).getValue()); |
| } else if (mlir::isa<cir::CIRFPTypeInterface>(op.getType())) { |
| attr = rewriter.getFloatAttr( |
| typeConverter->convertType(op.getType()), |
| mlir::cast<cir::FPAttr>(op.getValue()).getValue()); |
| } else if (mlir::isa<cir::PointerType>(op.getType())) { |
| // Optimize with dedicated LLVM op for null pointers. |
| if (mlir::isa<cir::ConstPtrAttr>(op.getValue())) { |
| if (mlir::cast<cir::ConstPtrAttr>(op.getValue()).isNullValue()) { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ZeroOp>( |
| op, typeConverter->convertType(op.getType())); |
| return mlir::success(); |
| } |
| } |
| assert(!cir::MissingFeatures::opGlobalViewAttr()); |
| attr = op.getValue(); |
| } else if (const auto arrTy = mlir::dyn_cast<cir::ArrayType>(op.getType())) { |
| const auto constArr = mlir::dyn_cast<cir::ConstArrayAttr>(op.getValue()); |
| if (!constArr && !isa<cir::ZeroAttr, cir::UndefAttr>(op.getValue())) |
| return op.emitError() << "array does not have a constant initializer"; |
| |
| std::optional<mlir::Attribute> denseAttr; |
| if (constArr && hasTrailingZeros(constArr)) { |
| const mlir::Value newOp = |
| lowerCirAttrAsValue(op, constArr, rewriter, getTypeConverter()); |
| rewriter.replaceOp(op, newOp); |
| return mlir::success(); |
| } else if (constArr && |
| (denseAttr = lowerConstArrayAttr(constArr, typeConverter))) { |
| attr = denseAttr.value(); |
| } else { |
| const mlir::Value initVal = |
| lowerCirAttrAsValue(op, op.getValue(), rewriter, typeConverter); |
| rewriter.replaceAllUsesWith(op, initVal); |
| rewriter.eraseOp(op); |
| return mlir::success(); |
| } |
| } else { |
| return op.emitError() << "unsupported constant type " << op.getType(); |
| } |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( |
| op, getTypeConverter()->convertType(op.getType()), attr); |
| |
| return mlir::success(); |
| } |
| |
| /// Convert the `cir.func` attributes to `llvm.func` attributes. |
| /// Only retain those attributes that are not constructed by |
| /// `LLVMFuncOp::build`. If `filterArgAndResAttrs` is set, also filter out |
| /// argument and result attributes. |
| void CIRToLLVMFuncOpLowering::lowerFuncAttributes( |
| cir::FuncOp func, bool filterArgAndResAttrs, |
| SmallVectorImpl<mlir::NamedAttribute> &result) const { |
| assert(!cir::MissingFeatures::opFuncCallingConv()); |
| for (mlir::NamedAttribute attr : func->getAttrs()) { |
| if (attr.getName() == mlir::SymbolTable::getSymbolAttrName() || |
| attr.getName() == func.getFunctionTypeAttrName() || |
| attr.getName() == getLinkageAttrNameString() || |
| (filterArgAndResAttrs && |
| (attr.getName() == func.getArgAttrsAttrName() || |
| attr.getName() == func.getResAttrsAttrName()))) |
| continue; |
| |
| assert(!cir::MissingFeatures::opFuncExtraAttrs()); |
| result.push_back(attr); |
| } |
| } |
| |
| mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewrite( |
| cir::FuncOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| |
| cir::FuncType fnType = op.getFunctionType(); |
| assert(!cir::MissingFeatures::opFuncDsolocal()); |
| bool isDsoLocal = false; |
| mlir::TypeConverter::SignatureConversion signatureConversion( |
| fnType.getNumInputs()); |
| |
| for (const auto &argType : llvm::enumerate(fnType.getInputs())) { |
| mlir::Type convertedType = typeConverter->convertType(argType.value()); |
| if (!convertedType) |
| return mlir::failure(); |
| signatureConversion.addInputs(argType.index(), convertedType); |
| } |
| |
| mlir::Type resultType = |
| getTypeConverter()->convertType(fnType.getReturnType()); |
| |
| // Create the LLVM function operation. |
| mlir::Type llvmFnTy = mlir::LLVM::LLVMFunctionType::get( |
| resultType ? resultType : mlir::LLVM::LLVMVoidType::get(getContext()), |
| signatureConversion.getConvertedTypes(), |
| /*isVarArg=*/fnType.isVarArg()); |
| // LLVMFuncOp expects a single FileLineCol location instead of a fused |
| // location. |
| mlir::Location loc = op.getLoc(); |
| if (mlir::FusedLoc fusedLoc = mlir::dyn_cast<mlir::FusedLoc>(loc)) |
| loc = fusedLoc.getLocations()[0]; |
| assert((mlir::isa<mlir::FileLineColLoc>(loc) || |
| mlir::isa<mlir::UnknownLoc>(loc)) && |
| "expected single location or unknown location here"); |
| |
| assert(!cir::MissingFeatures::opFuncLinkage()); |
| mlir::LLVM::Linkage linkage = mlir::LLVM::Linkage::External; |
| assert(!cir::MissingFeatures::opFuncCallingConv()); |
| mlir::LLVM::CConv cconv = mlir::LLVM::CConv::C; |
| SmallVector<mlir::NamedAttribute, 4> attributes; |
| lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes); |
| |
| mlir::LLVM::LLVMFuncOp fn = rewriter.create<mlir::LLVM::LLVMFuncOp>( |
| loc, op.getName(), llvmFnTy, linkage, isDsoLocal, cconv, |
| mlir::SymbolRefAttr(), attributes); |
| |
| assert(!cir::MissingFeatures::opFuncVisibility()); |
| |
| rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end()); |
| if (failed(rewriter.convertRegionTypes(&fn.getBody(), *typeConverter, |
| &signatureConversion))) |
| return mlir::failure(); |
| |
| rewriter.eraseOp(op); |
| |
| return mlir::LogicalResult::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMGetGlobalOpLowering::matchAndRewrite( |
| cir::GetGlobalOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| // FIXME(cir): Premature DCE to avoid lowering stuff we're not using. |
| // CIRGen should mitigate this and not emit the get_global. |
| if (op->getUses().empty()) { |
| rewriter.eraseOp(op); |
| return mlir::success(); |
| } |
| |
| mlir::Type type = getTypeConverter()->convertType(op.getType()); |
| mlir::Operation *newop = |
| rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), type, op.getName()); |
| |
| assert(!cir::MissingFeatures::opGlobalThreadLocal()); |
| |
| rewriter.replaceOp(op, newop); |
| return mlir::success(); |
| } |
| |
| /// Replace CIR global with a region initialized LLVM global and update |
| /// insertion point to the end of the initializer block. |
| void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp( |
| cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const { |
| const mlir::Type llvmType = |
| convertTypeForMemory(*getTypeConverter(), dataLayout, op.getSymType()); |
| |
| // FIXME: These default values are placeholders until the equivalent |
| // attributes are available on cir.global ops. This duplicates code |
| // in CIRToLLVMGlobalOpLowering::matchAndRewrite() but that will go |
| // away when the placeholders are no longer needed. |
| assert(!cir::MissingFeatures::opGlobalConstant()); |
| const bool isConst = false; |
| assert(!cir::MissingFeatures::addressSpace()); |
| const unsigned addrSpace = 0; |
| assert(!cir::MissingFeatures::opGlobalDSOLocal()); |
| const bool isDsoLocal = true; |
| assert(!cir::MissingFeatures::opGlobalThreadLocal()); |
| const bool isThreadLocal = false; |
| assert(!cir::MissingFeatures::opGlobalAlignment()); |
| const uint64_t alignment = 0; |
| const mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage()); |
| const StringRef symbol = op.getSymName(); |
| |
| SmallVector<mlir::NamedAttribute> attributes; |
| mlir::LLVM::GlobalOp newGlobalOp = |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>( |
| op, llvmType, isConst, linkage, symbol, nullptr, alignment, addrSpace, |
| isDsoLocal, isThreadLocal, |
| /*comdat=*/mlir::SymbolRefAttr(), attributes); |
| newGlobalOp.getRegion().emplaceBlock(); |
| rewriter.setInsertionPointToEnd(newGlobalOp.getInitializerBlock()); |
| } |
| |
| mlir::LogicalResult |
| CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal( |
| cir::GlobalOp op, mlir::Attribute init, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| // TODO: Generalize this handling when more types are needed here. |
| assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr, |
| cir::ZeroAttr>(init))); |
| |
| // TODO(cir): once LLVM's dialect has proper equivalent attributes this |
| // should be updated. For now, we use a custom op to initialize globals |
| // to the appropriate value. |
| const mlir::Location loc = op.getLoc(); |
| setupRegionInitializedLLVMGlobalOp(op, rewriter); |
| CIRAttrToValue valueConverter(op, rewriter, typeConverter); |
| mlir::Value value = valueConverter.visit(init); |
| rewriter.create<mlir::LLVM::ReturnOp>(loc, value); |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite( |
| cir::GlobalOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| |
| std::optional<mlir::Attribute> init = op.getInitialValue(); |
| |
| // Fetch required values to create LLVM op. |
| const mlir::Type cirSymType = op.getSymType(); |
| |
| // This is the LLVM dialect type. |
| const mlir::Type llvmType = |
| convertTypeForMemory(*getTypeConverter(), dataLayout, cirSymType); |
| // FIXME: These default values are placeholders until the equivalent |
| // attributes are available on cir.global ops. |
| assert(!cir::MissingFeatures::opGlobalConstant()); |
| const bool isConst = false; |
| assert(!cir::MissingFeatures::addressSpace()); |
| const unsigned addrSpace = 0; |
| assert(!cir::MissingFeatures::opGlobalDSOLocal()); |
| const bool isDsoLocal = true; |
| assert(!cir::MissingFeatures::opGlobalThreadLocal()); |
| const bool isThreadLocal = false; |
| assert(!cir::MissingFeatures::opGlobalAlignment()); |
| const uint64_t alignment = 0; |
| const mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage()); |
| const StringRef symbol = op.getSymName(); |
| SmallVector<mlir::NamedAttribute> attributes; |
| |
| if (init.has_value()) { |
| if (mlir::isa<cir::FPAttr, cir::IntAttr, cir::BoolAttr>(init.value())) { |
| GlobalInitAttrRewriter initRewriter(llvmType, rewriter); |
| init = initRewriter.visit(init.value()); |
| // If initRewriter returned a null attribute, init will have a value but |
| // the value will be null. If that happens, initRewriter didn't handle the |
| // attribute type. It probably needs to be added to |
| // GlobalInitAttrRewriter. |
| if (!init.value()) { |
| op.emitError() << "unsupported initializer '" << init.value() << "'"; |
| return mlir::failure(); |
| } |
| } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr, |
| cir::ConstPtrAttr, cir::ZeroAttr>(init.value())) { |
| // TODO(cir): once LLVM's dialect has proper equivalent attributes this |
| // should be updated. For now, we use a custom op to initialize globals |
| // to the appropriate value. |
| return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter); |
| } else { |
| // We will only get here if new initializer types are added and this |
| // code is not updated to handle them. |
| op.emitError() << "unsupported initializer '" << init.value() << "'"; |
| return mlir::failure(); |
| } |
| } |
| |
| // Rewrite op. |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>( |
| op, llvmType, isConst, linkage, symbol, init.value_or(mlir::Attribute()), |
| alignment, addrSpace, isDsoLocal, isThreadLocal, |
| /*comdat=*/mlir::SymbolRefAttr(), attributes); |
| |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite( |
| cir::UnaryOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| assert(op.getType() == op.getInput().getType() && |
| "Unary operation's operand type and result type are different"); |
| mlir::Type type = op.getType(); |
| mlir::Type elementType = type; |
| bool isVector = false; |
| assert(!cir::MissingFeatures::vectorType()); |
| mlir::Type llvmType = getTypeConverter()->convertType(type); |
| mlir::Location loc = op.getLoc(); |
| |
| // Integer unary operations: + - ~ ++ -- |
| if (mlir::isa<cir::IntType>(elementType)) { |
| mlir::LLVM::IntegerOverflowFlags maybeNSW = |
| op.getNoSignedWrap() ? mlir::LLVM::IntegerOverflowFlags::nsw |
| : mlir::LLVM::IntegerOverflowFlags::none; |
| switch (op.getKind()) { |
| case cir::UnaryOpKind::Inc: { |
| assert(!isVector && "++ not allowed on vector types"); |
| mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmType, mlir::IntegerAttr::get(llvmType, 1)); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>( |
| op, llvmType, adaptor.getInput(), one, maybeNSW); |
| return mlir::success(); |
| } |
| case cir::UnaryOpKind::Dec: { |
| assert(!isVector && "-- not allowed on vector types"); |
| mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmType, mlir::IntegerAttr::get(llvmType, 1)); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>( |
| op, llvmType, adaptor.getInput(), one, maybeNSW); |
| return mlir::success(); |
| } |
| case cir::UnaryOpKind::Plus: |
| rewriter.replaceOp(op, adaptor.getInput()); |
| return mlir::success(); |
| case cir::UnaryOpKind::Minus: { |
| assert(!isVector && |
| "Add vector handling when vector types are supported"); |
| mlir::LLVM::ConstantOp zero = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmType, mlir::IntegerAttr::get(llvmType, 0)); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>( |
| op, llvmType, zero, adaptor.getInput(), maybeNSW); |
| return mlir::success(); |
| } |
| case cir::UnaryOpKind::Not: { |
| // Bit-wise complement operator, implemented as an XOR with -1. |
| assert(!isVector && |
| "Add vector handling when vector types are supported"); |
| mlir::LLVM::ConstantOp minusOne = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmType, mlir::IntegerAttr::get(llvmType, -1)); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>( |
| op, llvmType, adaptor.getInput(), minusOne); |
| return mlir::success(); |
| } |
| } |
| llvm_unreachable("Unexpected unary op for int"); |
| } |
| |
| // Floating point unary operations: + - ++ -- |
| if (mlir::isa<cir::CIRFPTypeInterface>(elementType)) { |
| switch (op.getKind()) { |
| case cir::UnaryOpKind::Inc: { |
| assert(!isVector && "++ not allowed on vector types"); |
| mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0)); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, one, |
| adaptor.getInput()); |
| return mlir::success(); |
| } |
| case cir::UnaryOpKind::Dec: { |
| assert(!isVector && "-- not allowed on vector types"); |
| mlir::LLVM::ConstantOp minusOne = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmType, rewriter.getFloatAttr(llvmType, -1.0)); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, minusOne, |
| adaptor.getInput()); |
| return mlir::success(); |
| } |
| case cir::UnaryOpKind::Plus: |
| rewriter.replaceOp(op, adaptor.getInput()); |
| return mlir::success(); |
| case cir::UnaryOpKind::Minus: |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FNegOp>(op, llvmType, |
| adaptor.getInput()); |
| return mlir::success(); |
| case cir::UnaryOpKind::Not: |
| return op.emitError() << "Unary not is invalid for floating-point types"; |
| } |
| llvm_unreachable("Unexpected unary op for float"); |
| } |
| |
| // Boolean unary operations: ! only. (For all others, the operand has |
| // already been promoted to int.) |
| if (mlir::isa<cir::BoolType>(elementType)) { |
| switch (op.getKind()) { |
| case cir::UnaryOpKind::Inc: |
| case cir::UnaryOpKind::Dec: |
| case cir::UnaryOpKind::Plus: |
| case cir::UnaryOpKind::Minus: |
| // Some of these are allowed in source code, but we shouldn't get here |
| // with a boolean type. |
| return op.emitError() << "Unsupported unary operation on boolean type"; |
| case cir::UnaryOpKind::Not: { |
| assert(!isVector && "NYI: op! on vector mask"); |
| mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmType, rewriter.getIntegerAttr(llvmType, 1)); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, llvmType, |
| adaptor.getInput(), one); |
| return mlir::success(); |
| } |
| } |
| llvm_unreachable("Unexpected unary op for bool"); |
| } |
| |
| // Pointer unary operations: + only. (++ and -- of pointers are implemented |
| // with cir.ptr_stride, not cir.unary.) |
| if (mlir::isa<cir::PointerType>(elementType)) { |
| return op.emitError() |
| << "Unary operation on pointer types is not yet implemented"; |
| } |
| |
| return op.emitError() << "Unary operation has unsupported type: " |
| << elementType; |
| } |
| |
| mlir::LLVM::IntegerOverflowFlags |
| CIRToLLVMBinOpLowering::getIntOverflowFlag(cir::BinOp op) const { |
| if (op.getNoUnsignedWrap()) |
| return mlir::LLVM::IntegerOverflowFlags::nuw; |
| |
| if (op.getNoSignedWrap()) |
| return mlir::LLVM::IntegerOverflowFlags::nsw; |
| |
| return mlir::LLVM::IntegerOverflowFlags::none; |
| } |
| |
| static bool isIntTypeUnsigned(mlir::Type type) { |
| // TODO: Ideally, we should only need to check cir::IntType here. |
| return mlir::isa<cir::IntType>(type) |
| ? mlir::cast<cir::IntType>(type).isUnsigned() |
| : mlir::cast<mlir::IntegerType>(type).isUnsigned(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite( |
| cir::BinOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| if (adaptor.getLhs().getType() != adaptor.getRhs().getType()) |
| return op.emitError() << "inconsistent operands' types not supported yet"; |
| |
| mlir::Type type = op.getRhs().getType(); |
| assert(!cir::MissingFeatures::vectorType()); |
| if (!mlir::isa<cir::IntType, cir::BoolType, cir::CIRFPTypeInterface, |
| mlir::IntegerType>(type)) |
| return op.emitError() << "operand type not supported yet"; |
| |
| auto llvmTy = getTypeConverter()->convertType(op.getType()); |
| mlir::Type llvmEltTy = |
| mlir::isa<mlir::VectorType>(llvmTy) |
| ? mlir::cast<mlir::VectorType>(llvmTy).getElementType() |
| : llvmTy; |
| auto rhs = adaptor.getRhs(); |
| auto lhs = adaptor.getLhs(); |
| |
| type = elementTypeIfVector(type); |
| |
| switch (op.getKind()) { |
| case cir::BinOpKind::Add: |
| if (mlir::isa<mlir::IntegerType>(llvmEltTy)) { |
| if (op.getSaturated()) { |
| if (isIntTypeUnsigned(type)) { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::UAddSat>(op, lhs, rhs); |
| break; |
| } |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SAddSat>(op, lhs, rhs); |
| break; |
| } |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmTy, lhs, rhs, |
| getIntOverflowFlag(op)); |
| } else { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, lhs, rhs); |
| } |
| break; |
| case cir::BinOpKind::Sub: |
| if (mlir::isa<mlir::IntegerType>(llvmEltTy)) { |
| if (op.getSaturated()) { |
| if (isIntTypeUnsigned(type)) { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::USubSat>(op, lhs, rhs); |
| break; |
| } |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SSubSat>(op, lhs, rhs); |
| break; |
| } |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmTy, lhs, rhs, |
| getIntOverflowFlag(op)); |
| } else { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, lhs, rhs); |
| } |
| break; |
| case cir::BinOpKind::Mul: |
| if (mlir::isa<mlir::IntegerType>(llvmEltTy)) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::MulOp>(op, llvmTy, lhs, rhs, |
| getIntOverflowFlag(op)); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, lhs, rhs); |
| break; |
| case cir::BinOpKind::Div: |
| if (mlir::isa<mlir::IntegerType>(llvmEltTy)) { |
| auto isUnsigned = isIntTypeUnsigned(type); |
| if (isUnsigned) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, lhs, rhs); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, lhs, rhs); |
| } else { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, lhs, rhs); |
| } |
| break; |
| case cir::BinOpKind::Rem: |
| if (mlir::isa<mlir::IntegerType>(llvmEltTy)) { |
| auto isUnsigned = isIntTypeUnsigned(type); |
| if (isUnsigned) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, lhs, rhs); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, lhs, rhs); |
| } else { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, lhs, rhs); |
| } |
| break; |
| case cir::BinOpKind::And: |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, lhs, rhs); |
| break; |
| case cir::BinOpKind::Or: |
| rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, lhs, rhs); |
| break; |
| case cir::BinOpKind::Xor: |
| rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, lhs, rhs); |
| break; |
| case cir::BinOpKind::Max: |
| if (mlir::isa<mlir::IntegerType>(llvmEltTy)) { |
| auto isUnsigned = isIntTypeUnsigned(type); |
| if (isUnsigned) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::UMaxOp>(op, llvmTy, lhs, rhs); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SMaxOp>(op, llvmTy, lhs, rhs); |
| } |
| break; |
| } |
| |
| return mlir::LogicalResult::success(); |
| } |
| |
| /// Convert from a CIR comparison kind to an LLVM IR integral comparison kind. |
| static mlir::LLVM::ICmpPredicate |
| convertCmpKindToICmpPredicate(cir::CmpOpKind kind, bool isSigned) { |
| using CIR = cir::CmpOpKind; |
| using LLVMICmp = mlir::LLVM::ICmpPredicate; |
| switch (kind) { |
| case CIR::eq: |
| return LLVMICmp::eq; |
| case CIR::ne: |
| return LLVMICmp::ne; |
| case CIR::lt: |
| return (isSigned ? LLVMICmp::slt : LLVMICmp::ult); |
| case CIR::le: |
| return (isSigned ? LLVMICmp::sle : LLVMICmp::ule); |
| case CIR::gt: |
| return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt); |
| case CIR::ge: |
| return (isSigned ? LLVMICmp::sge : LLVMICmp::uge); |
| } |
| llvm_unreachable("Unknown CmpOpKind"); |
| } |
| |
| /// Convert from a CIR comparison kind to an LLVM IR floating-point comparison |
| /// kind. |
| static mlir::LLVM::FCmpPredicate |
| convertCmpKindToFCmpPredicate(cir::CmpOpKind kind) { |
| using CIR = cir::CmpOpKind; |
| using LLVMFCmp = mlir::LLVM::FCmpPredicate; |
| switch (kind) { |
| case CIR::eq: |
| return LLVMFCmp::oeq; |
| case CIR::ne: |
| return LLVMFCmp::une; |
| case CIR::lt: |
| return LLVMFCmp::olt; |
| case CIR::le: |
| return LLVMFCmp::ole; |
| case CIR::gt: |
| return LLVMFCmp::ogt; |
| case CIR::ge: |
| return LLVMFCmp::oge; |
| } |
| llvm_unreachable("Unknown CmpOpKind"); |
| } |
| |
| mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( |
| cir::CmpOp cmpOp, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| mlir::Type type = cmpOp.getLhs().getType(); |
| |
| assert(!cir::MissingFeatures::dataMemberType()); |
| assert(!cir::MissingFeatures::methodType()); |
| |
| // Lower to LLVM comparison op. |
| if (mlir::isa<cir::IntType, mlir::IntegerType>(type)) { |
| bool isSigned = mlir::isa<cir::IntType>(type) |
| ? mlir::cast<cir::IntType>(type).isSigned() |
| : mlir::cast<mlir::IntegerType>(type).isSigned(); |
| mlir::LLVM::ICmpPredicate kind = |
| convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( |
| cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); |
| } else if (auto ptrTy = mlir::dyn_cast<cir::PointerType>(type)) { |
| mlir::LLVM::ICmpPredicate kind = |
| convertCmpKindToICmpPredicate(cmpOp.getKind(), |
| /* isSigned=*/false); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>( |
| cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); |
| } else if (mlir::isa<cir::CIRFPTypeInterface>(type)) { |
| mlir::LLVM::FCmpPredicate kind = |
| convertCmpKindToFCmpPredicate(cmpOp.getKind()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>( |
| cmpOp, kind, adaptor.getLhs(), adaptor.getRhs()); |
| } else { |
| return cmpOp.emitError() << "unsupported type for CmpOp: " << type; |
| } |
| |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite( |
| cir::ShiftOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType()); |
| auto cirValTy = mlir::dyn_cast<cir::IntType>(op.getValue().getType()); |
| |
| // Operands could also be vector type |
| assert(!cir::MissingFeatures::vectorType()); |
| mlir::Type llvmTy = getTypeConverter()->convertType(op.getType()); |
| mlir::Value amt = adaptor.getAmount(); |
| mlir::Value val = adaptor.getValue(); |
| |
| // TODO(cir): Assert for vector types |
| assert((cirValTy && cirAmtTy) && |
| "shift input type must be integer or vector type, otherwise NYI"); |
| |
| assert((cirValTy == op.getType()) && "inconsistent operands' types NYI"); |
| |
| // Ensure shift amount is the same type as the value. Some undefined |
| // behavior might occur in the casts below as per [C99 6.5.7.3]. |
| // Vector type shift amount needs no cast as type consistency is expected to |
| // already be enforced at CIRGen. |
| if (cirAmtTy) |
| amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy), |
| true, cirAmtTy.getWidth(), cirValTy.getWidth()); |
| |
| // Lower to the proper LLVM shift operation. |
| if (op.getIsShiftleft()) { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt); |
| } else { |
| assert(!cir::MissingFeatures::vectorType()); |
| bool isUnsigned = !cirValTy.isSigned(); |
| if (isUnsigned) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt); |
| } |
| |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite( |
| cir::SelectOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| auto getConstantBool = [](mlir::Value value) -> cir::BoolAttr { |
| auto definingOp = |
| mlir::dyn_cast_if_present<cir::ConstantOp>(value.getDefiningOp()); |
| if (!definingOp) |
| return {}; |
| |
| auto constValue = mlir::dyn_cast<cir::BoolAttr>(definingOp.getValue()); |
| if (!constValue) |
| return {}; |
| |
| return constValue; |
| }; |
| |
| // Two special cases in the LLVMIR codegen of select op: |
| // - select %0, %1, false => and %0, %1 |
| // - select %0, true, %1 => or %0, %1 |
| if (mlir::isa<cir::BoolType>(op.getTrueValue().getType())) { |
| cir::BoolAttr trueValue = getConstantBool(op.getTrueValue()); |
| cir::BoolAttr falseValue = getConstantBool(op.getFalseValue()); |
| if (falseValue && !falseValue.getValue()) { |
| // select %0, %1, false => and %0, %1 |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, adaptor.getCondition(), |
| adaptor.getTrueValue()); |
| return mlir::success(); |
| } |
| if (trueValue && trueValue.getValue()) { |
| // select %0, true, %1 => or %0, %1 |
| rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, adaptor.getCondition(), |
| adaptor.getFalseValue()); |
| return mlir::success(); |
| } |
| } |
| |
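| // No folding applies: lower to a plain llvm.select on the condition. |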
| mlir::Value llvmCondition = adaptor.getCondition(); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>( |
| op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue()); |
| |
| return mlir::success(); |
| } |
| |
| static void prepareTypeConverter(mlir::LLVMTypeConverter &converter, |
| mlir::DataLayout &dataLayout) { |
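| // Register conversions from each CIR type to its LLVM dialect equivalent. |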
| converter.addConversion([&](cir::PointerType type) -> mlir::Type { |
| // Drop pointee type since LLVM dialect only allows opaque pointers. |
| assert(!cir::MissingFeatures::addressSpace()); |
| unsigned targetAS = 0; |
| |
| return mlir::LLVM::LLVMPointerType::get(type.getContext(), targetAS); |
| }); |
| converter.addConversion([&](cir::ArrayType type) -> mlir::Type { |
| mlir::Type ty = |
| convertTypeForMemory(converter, dataLayout, type.getElementType()); |
| return mlir::LLVM::LLVMArrayType::get(ty, type.getSize()); |
| }); |
| converter.addConversion([&](cir::VectorType type) -> mlir::Type { |
| const mlir::Type ty = converter.convertType(type.getElementType()); |
| return mlir::VectorType::get(type.getSize(), ty); |
| }); |
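| // cir.bool values lower to i1; in-memory accesses instead go through |
| // convertTypeForMemory (see above), which widens them to their storage size. |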
| converter.addConversion([&](cir::BoolType type) -> mlir::Type { |
| return mlir::IntegerType::get(type.getContext(), 1, |
| mlir::IntegerType::Signless); |
| }); |
| converter.addConversion([&](cir::IntType type) -> mlir::Type { |
| // LLVM integer types carry no signedness, so we drop the CIR sign |
| // information here. |
| return mlir::IntegerType::get(type.getContext(), type.getWidth()); |
| }); |
| converter.addConversion([&](cir::SingleType type) -> mlir::Type { |
| return mlir::Float32Type::get(type.getContext()); |
| }); |
| converter.addConversion([&](cir::DoubleType type) -> mlir::Type { |
| return mlir::Float64Type::get(type.getContext()); |
| }); |
| converter.addConversion([&](cir::FP80Type type) -> mlir::Type { |
| return mlir::Float80Type::get(type.getContext()); |
| }); |
| converter.addConversion([&](cir::FP128Type type) -> mlir::Type { |
| return mlir::Float128Type::get(type.getContext()); |
| }); |
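| // cir.long_double carries its actual floating-point format as an |
| // underlying type; lower to whatever that underlying type converts to. |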
| converter.addConversion([&](cir::LongDoubleType type) -> mlir::Type { |
| return converter.convertType(type.getUnderlying()); |
| }); |
| converter.addConversion([&](cir::FP16Type type) -> mlir::Type { |
| return mlir::Float16Type::get(type.getContext()); |
| }); |
| converter.addConversion([&](cir::BF16Type type) -> mlir::Type { |
| return mlir::BFloat16Type::get(type.getContext()); |
| }); |
| converter.addConversion([&](cir::RecordType type) -> mlir::Type { |
| // Convert struct members. |
| llvm::SmallVector<mlir::Type> llvmMembers; |
| switch (type.getKind()) { |
| case cir::RecordType::Struct: |
| for (mlir::Type ty : type.getMembers()) |
| llvmMembers.push_back(convertTypeForMemory(converter, dataLayout, ty)); |
| break; |
| // Unions are lowered as only the largest member. |
| case cir::RecordType::Union: |
| if (auto largestMember = type.getLargestMember(dataLayout)) |
| llvmMembers.push_back( |
| convertTypeForMemory(converter, dataLayout, largestMember)); |
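| // A padded union also keeps its trailing padding member so that the |
| // lowered struct has the full size of the union. |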
| if (type.getPadded()) { |
| auto last = *type.getMembers().rbegin(); |
| llvmMembers.push_back( |
| convertTypeForMemory(converter, dataLayout, last)); |
| } |
| break; |
| } |
| |
| // Record has a name: lower as an identified record. |
| mlir::LLVM::LLVMStructType llvmStruct; |
| if (type.getName()) { |
| llvmStruct = mlir::LLVM::LLVMStructType::getIdentified( |
| type.getContext(), type.getPrefixedName()); |
| if (llvmStruct.setBody(llvmMembers, type.getPacked()).failed()) |
| llvm_unreachable("Failed to set body of record"); |
| } else { // Record has no name: lower as literal record. |
| llvmStruct = mlir::LLVM::LLVMStructType::getLiteral( |
| type.getContext(), llvmMembers, type.getPacked()); |
| } |
| |
| return llvmStruct; |
| }); |
| } |
| |
| // The applyPartialConversion function traverses blocks in dominance order, |
| // so it does not lower operations that are not reachable from the |
| // operations passed in as arguments. Since we do need to lower such code in |
| // order to avoid verification errors, we cannot just pass the module op to |
| // applyPartialConversion. We must build a set of unreachable ops and |
| // explicitly add them, along with the module, to the vector we pass to |
| // applyPartialConversion. |
| // |
| // For instance, this CIR code: |
| // |
| // cir.func @foo(%arg0: !s32i) -> !s32i { |
| // %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool |
| // cir.if %4 { |
| // %5 = cir.const #cir.int<1> : !s32i |
| // cir.return %5 : !s32i |
| // } else { |
| // %5 = cir.const #cir.int<0> : !s32i |
| // cir.return %5 : !s32i |
| // } |
| // cir.return %arg0 : !s32i |
| // } |
| // |
| // contains an unreachable return operation (the last one). After the |
| // flattening pass it will be placed into an unreachable block. The possible |
| // error after the lowering pass is: error: 'cir.return' op expects parent op |
| // to be one of 'cir.func, cir.scope, cir.if ... The reason is that this |
| // operation was not lowered, while its new parent is llvm.func. |
| // |
| // In the future we may want to get rid of this function and use a DCE pass |
| // or something similar. But for now we need to guarantee the absence of |
| // dialect verification errors. |
| static void collectUnreachable(mlir::Operation *parent, |
| llvm::SmallVector<mlir::Operation *> &ops) { |
| |
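| // Blocks that have no predecessors and are not entry blocks are |
| // unreachable after the flattening pass. |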
| llvm::SmallVector<mlir::Block *> unreachableBlocks; |
| parent->walk([&](mlir::Block *blk) { |
| if (blk->hasNoPredecessors() && !blk->isEntryBlock()) |
| unreachableBlocks.push_back(blk); |
| }); |
| |
| std::set<mlir::Block *> visited; |
| for (mlir::Block *root : unreachableBlocks) { |
| // Create a work list for each unreachable block and walk every block |
| // reachable from it, collecting all of their operations. |
| std::deque<mlir::Block *> workList; |
| workList.push_back(root); |
| |
| while (!workList.empty()) { |
| mlir::Block *blk = workList.back(); |
| workList.pop_back(); |
| if (visited.count(blk)) |
| continue; |
| visited.emplace(blk); |
| |
| for (mlir::Operation &op : *blk) |
| ops.push_back(&op); |
| |
| for (mlir::Block *succ : blk->getSuccessors()) |
| workList.push_back(succ); |
| } |
| } |
| } |
| |
| void ConvertCIRToLLVMPass::processCIRAttrs(mlir::ModuleOp module) { |
| // Lower the module attributes to LLVM equivalents. |
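| // Currently only the target triple is translated. |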
| if (mlir::Attribute tripleAttr = |
| module->getAttr(cir::CIRDialect::getTripleAttrName())) |
| module->setAttr(mlir::LLVM::LLVMDialect::getTargetTripleAttrName(), |
| tripleAttr); |
| } |
| |
| void ConvertCIRToLLVMPass::runOnOperation() { |
| llvm::TimeTraceScope scope("Convert CIR to LLVM Pass"); |
| |
| mlir::ModuleOp module = getOperation(); |
| mlir::DataLayout dl(module); |
| mlir::LLVMTypeConverter converter(&getContext()); |
| prepareTypeConverter(converter, dl); |
| |
| mlir::RewritePatternSet patterns(&getContext()); |
| |
| patterns.add<CIRToLLVMReturnOpLowering>(patterns.getContext()); |
| // This could currently be merged with the group below, but it will get more |
| // arguments later, so we'll keep it separate for now. |
| patterns.add<CIRToLLVMAllocaOpLowering>(converter, patterns.getContext(), dl); |
| patterns.add<CIRToLLVMLoadOpLowering>(converter, patterns.getContext(), dl); |
| patterns.add<CIRToLLVMStoreOpLowering>(converter, patterns.getContext(), dl); |
| patterns.add<CIRToLLVMGlobalOpLowering>(converter, patterns.getContext(), dl); |
| patterns.add<CIRToLLVMCastOpLowering>(converter, patterns.getContext(), dl); |
| patterns.add<CIRToLLVMPtrStrideOpLowering>(converter, patterns.getContext(), |
| dl); |
| patterns.add< |
| // clang-format off |
| CIRToLLVMBinOpLowering, |
| CIRToLLVMBrCondOpLowering, |
| CIRToLLVMBrOpLowering, |
| CIRToLLVMCallOpLowering, |
| CIRToLLVMCmpOpLowering, |
| CIRToLLVMConstantOpLowering, |
| CIRToLLVMFuncOpLowering, |
| CIRToLLVMGetGlobalOpLowering, |
| CIRToLLVMGetMemberOpLowering, |
| CIRToLLVMSelectOpLowering, |
| CIRToLLVMShiftOpLowering, |
| CIRToLLVMStackSaveOpLowering, |
| CIRToLLVMStackRestoreOpLowering, |
| CIRToLLVMTrapOpLowering, |
| CIRToLLVMUnaryOpLowering, |
| CIRToLLVMVecCreateOpLowering, |
| CIRToLLVMVecExtractOpLowering |
| // clang-format on |
| >(converter, patterns.getContext()); |
| |
| processCIRAttrs(module); |
| |
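| // Only the LLVM dialect (plus the module op itself) is legal after this |
| // pass; any remaining CIR, builtin, or func ops cause the conversion to fail. |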
| mlir::ConversionTarget target(getContext()); |
| target.addLegalOp<mlir::ModuleOp>(); |
| target.addLegalDialect<mlir::LLVM::LLVMDialect>(); |
| target.addIllegalDialect<mlir::BuiltinDialect, cir::CIRDialect, |
| mlir::func::FuncDialect>(); |
| |
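| // Lower the module itself plus any operations in unreachable blocks, which |
| // applyPartialConversion would otherwise skip (see collectUnreachable). |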
| llvm::SmallVector<mlir::Operation *> ops; |
| ops.push_back(module); |
| collectUnreachable(module, ops); |
| |
| if (failed(applyPartialConversion(ops, target, std::move(patterns)))) |
| signalPassFailure(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMBrOpLowering::matchAndRewrite( |
| cir::BrOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(op, adaptor.getOperands(), |
| op.getDest()); |
| return mlir::LogicalResult::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMGetMemberOpLowering::matchAndRewrite( |
| cir::GetMemberOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| mlir::Type llResTy = getTypeConverter()->convertType(op.getType()); |
| const auto recordTy = |
| mlir::cast<cir::RecordType>(op.getAddrTy().getPointee()); |
| assert(recordTy && "expected record type"); |
| |
| switch (recordTy.getKind()) { |
| case cir::RecordType::Struct: { |
| // Since the base address is a pointer to an aggregate, the first offset |
| // is always zero. The second offset tells us which member will be accessed. |
| llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, op.getIndex()}; |
| const mlir::Type elementTy = getTypeConverter()->convertType(recordTy); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, llResTy, elementTy, |
| adaptor.getAddr(), offset); |
| return mlir::success(); |
| } |
| case cir::RecordType::Union: |
| // All union members start at the same address, so we just need a bitcast |
| // to conform to type-checking. |
| rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy, |
| adaptor.getAddr()); |
| return mlir::success(); |
| } |
| } |
| |
| mlir::LogicalResult CIRToLLVMTrapOpLowering::matchAndRewrite( |
| cir::TrapOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| mlir::Location loc = op->getLoc(); |
| rewriter.eraseOp(op); |
| |
| rewriter.create<mlir::LLVM::Trap>(loc); |
| |
| // Note that the call to llvm.trap is not a terminator in the LLVM dialect, |
| // so we must emit an additional llvm.unreachable to terminate the current |
| // block. |
| rewriter.create<mlir::LLVM::UnreachableOp>(loc); |
| |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite( |
| cir::StackSaveOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| const mlir::Type ptrTy = getTypeConverter()->convertType(op.getType()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::StackSaveOp>(op, ptrTy); |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMStackRestoreOpLowering::matchAndRewrite( |
| cir::StackRestoreOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::StackRestoreOp>(op, adaptor.getPtr()); |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite( |
| cir::VecCreateOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| // Start with a 'poison' value for the vector, then 'insertelement' for |
| // each of the vector elements. |
| const auto vecTy = mlir::cast<cir::VectorType>(op.getType()); |
| const mlir::Type llvmTy = typeConverter->convertType(vecTy); |
| const mlir::Location loc = op.getLoc(); |
| mlir::Value result = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy); |
| assert(vecTy.getSize() == op.getElements().size() && |
| "cir.vec.create op count doesn't match vector type elements count"); |
| |
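| // Insert each operand into the vector at its corresponding index. |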
| for (uint64_t i = 0; i < vecTy.getSize(); ++i) { |
| const mlir::Value indexValue = |
| rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i); |
| result = rewriter.create<mlir::LLVM::InsertElementOp>( |
| loc, result, adaptor.getElements()[i], indexValue); |
| } |
| |
| rewriter.replaceOp(op, result); |
| return mlir::success(); |
| } |
| |
| mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite( |
| cir::VecExtractOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>( |
| op, adaptor.getVec(), adaptor.getIndex()); |
| return mlir::success(); |
| } |
| |
| std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() { |
| return std::make_unique<ConvertCIRToLLVMPass>(); |
| } |
| |
| void populateCIRToLLVMPasses(mlir::OpPassManager &pm) { |
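| // The pre-lowering passes (e.g. CFG flattening) must run before the direct |
| // CIR-to-LLVM conversion. |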
| mlir::populateCIRPreLoweringPasses(pm); |
| pm.addPass(createConvertCIRToLLVMPass()); |
| } |
| |
| std::unique_ptr<llvm::Module> |
| lowerDirectlyFromCIRToLLVMIR(mlir::ModuleOp mlirModule, LLVMContext &llvmCtx) { |
| llvm::TimeTraceScope scope("lower from CIR to LLVM directly"); |
| |
| mlir::MLIRContext *mlirCtx = mlirModule.getContext(); |
| |
| mlir::PassManager pm(mlirCtx); |
| populateCIRToLLVMPasses(pm); |
| |
| (void)mlir::applyPassManagerCLOptions(pm); |
| |
| if (mlir::failed(pm.run(mlirModule))) { |
| // FIXME: Handle any errors where they occur and return a nullptr here. |
| report_fatal_error( |
| "The pass manager failed to lower CIR to LLVMIR dialect!"); |
| } |
| |
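| // Register the dialect translation interfaces required by |
| // translateModuleToLLVMIR below. |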
| mlir::registerBuiltinDialectTranslation(*mlirCtx); |
| mlir::registerLLVMDialectTranslation(*mlirCtx); |
| mlir::registerCIRDialectTranslation(*mlirCtx); |
| |
| llvm::TimeTraceScope translateScope("translateModuleToLLVMIR"); |
| |
| StringRef moduleName = mlirModule.getName().value_or("CIRToLLVMModule"); |
| std::unique_ptr<llvm::Module> llvmModule = |
| mlir::translateModuleToLLVMIR(mlirModule, llvmCtx, moduleName); |
| |
| if (!llvmModule) { |
| // FIXME: Handle any errors where they occur and return a nullptr here. |
| report_fatal_error("Lowering from LLVMIR dialect to llvm IR failed!"); |
| } |
| |
| return llvmModule; |
| } |
| } // namespace direct |
| } // namespace cir |