| //===- LowerToLLVM.cpp - Lowering from CIR to LLVMIR ---------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements lowering of CIR operations to LLVMIR. |
| // |
| //===----------------------------------------------------------------------===// |
| #include "LoweringHelpers.h" |
| #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" |
| #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
| #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
| #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" |
| #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
| #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" |
| #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" |
| #include "mlir/Dialect/DLTI/DLTI.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinAttributeInterfaces.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinDialect.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/OpDefinition.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/IR/Types.h" |
| #include "mlir/IR/Value.h" |
| #include "mlir/IR/ValueRange.h" |
| #include "mlir/IR/Visitors.h" |
| #include "mlir/Interfaces/DataLayoutInterfaces.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Pass/PassManager.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" |
| #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" |
| #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" |
| #include "mlir/Target/LLVMIR/Export.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| #include "clang/CIR/Dialect/IR/CIRAttrs.h" |
| #include "clang/CIR/Dialect/IR/CIRDialect.h" |
| #include "clang/CIR/Dialect/IR/CIROpsEnums.h" |
| #include "clang/CIR/Dialect/IR/CIRTypes.h" |
| #include "clang/CIR/Dialect/Passes.h" |
| #include "clang/CIR/LoweringHelpers.h" |
| #include "clang/CIR/MissingFeatures.h" |
| #include "clang/CIR/Passes.h" |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/ADT/Twine.h" |
| #include "llvm/IR/DataLayout.h" |
| #include "llvm/IR/DerivedTypes.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/ErrorHandling.h" |
| #include <cstdint> |
| #include <deque> |
| #include <optional> |
| #include <set> |
| |
| #include "LowerModule.h" |
| |
| using namespace cir; |
| using namespace llvm; |
| |
| namespace cir { |
| namespace direct { |
| |
| //===----------------------------------------------------------------------===// |
| // Helper Methods |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| /// Walks a region while skipping operations of type `Ops`. This ensures the |
| /// callback is not applied to those operations or their children. |
| template <typename... Ops> |
| void walkRegionSkipping(mlir::Region ®ion, |
| mlir::function_ref<void(mlir::Operation *)> callback) { |
| region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) { |
| if (isa<Ops...>(op)) |
| return mlir::WalkResult::skip(); |
| callback(op); |
| return mlir::WalkResult::advance(); |
| }); |
| } |
| |
| /// Convert from a CIR comparison kind to an LLVM IR integer comparison |
| /// predicate. |
| mlir::LLVM::ICmpPredicate |
| convertCmpKindToICmpPredicate(mlir::cir::CmpOpKind kind, bool isSigned) { |
| using CIR = mlir::cir::CmpOpKind; |
| using LLVMICmp = mlir::LLVM::ICmpPredicate; |
| switch (kind) { |
| case CIR::eq: |
| return LLVMICmp::eq; |
| case CIR::ne: |
| return LLVMICmp::ne; |
| case CIR::lt: |
| return (isSigned ? LLVMICmp::slt : LLVMICmp::ult); |
| case CIR::le: |
| return (isSigned ? LLVMICmp::sle : LLVMICmp::ule); |
| case CIR::gt: |
| return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt); |
| case CIR::ge: |
| return (isSigned ? LLVMICmp::sge : LLVMICmp::uge); |
| } |
| llvm_unreachable("Unknown CmpOpKind"); |
| } |
| |
| /// Convert from a CIR comparison kind to an LLVM IR floating-point |
| /// comparison predicate. |
| mlir::LLVM::FCmpPredicate |
| convertCmpKindToFCmpPredicate(mlir::cir::CmpOpKind kind) { |
| using CIR = mlir::cir::CmpOpKind; |
| using LLVMFCmp = mlir::LLVM::FCmpPredicate; |
| switch (kind) { |
| case CIR::eq: |
| return LLVMFCmp::oeq; |
| case CIR::ne: |
| return LLVMFCmp::une; |
| case CIR::lt: |
| return LLVMFCmp::olt; |
| case CIR::le: |
| return LLVMFCmp::ole; |
| case CIR::gt: |
| return LLVMFCmp::ogt; |
| case CIR::ge: |
| return LLVMFCmp::oge; |
| } |
| llvm_unreachable("Unknown CmpOpKind"); |
| } |
| |
| /// If the given type is a vector type, return the vector's element type. |
| /// Otherwise return the given type unchanged. |
| mlir::Type elementTypeIfVector(mlir::Type type) { |
| if (auto VecType = mlir::dyn_cast<mlir::cir::VectorType>(type)) { |
| return VecType.getEltType(); |
| } |
| return type; |
| } |
| |
| mlir::LLVM::Visibility |
| lowerCIRVisibilityToLLVMVisibility(mlir::cir::VisibilityKind visibilityKind) { |
| switch (visibilityKind) { |
| case mlir::cir::VisibilityKind::Default: |
| return ::mlir::LLVM::Visibility::Default; |
| case mlir::cir::VisibilityKind::Hidden: |
| return ::mlir::LLVM::Visibility::Hidden; |
| case mlir::cir::VisibilityKind::Protected: |
| return ::mlir::LLVM::Visibility::Protected; |
|   } |
|   llvm_unreachable("Unknown VisibilityKind"); |
| } |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // Visitors for Lowering CIR Const Attributes |
| //===----------------------------------------------------------------------===// |
| |
| /// Switches on the type of attribute and calls the appropriate conversion. |
| inline mlir::Value |
| lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter); |
| |
| /// IntAttr visitor. |
| inline mlir::Value |
| lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::cir::IntAttr intAttr, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| auto loc = parentOp->getLoc(); |
| return rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, converter->convertType(intAttr.getType()), intAttr.getValue()); |
| } |
| |
| /// BoolAttr visitor. |
| inline mlir::Value |
| lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::cir::BoolAttr boolAttr, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| auto loc = parentOp->getLoc(); |
| return rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, converter->convertType(boolAttr.getType()), boolAttr.getValue()); |
| } |
| |
| /// ConstPtrAttr visitor. |
| inline mlir::Value |
| lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::cir::ConstPtrAttr ptrAttr, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| auto loc = parentOp->getLoc(); |
| if (ptrAttr.isNullValue()) { |
| return rewriter.create<mlir::LLVM::ZeroOp>( |
| loc, converter->convertType(ptrAttr.getType())); |
| } |
| mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>()); |
| mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())), |
| ptrAttr.getValue().getInt()); |
| return rewriter.create<mlir::LLVM::IntToPtrOp>( |
| loc, converter->convertType(ptrAttr.getType()), ptrVal); |
| } |
| |
| /// FPAttr visitor. |
| inline mlir::Value |
| lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::cir::FPAttr fltAttr, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| auto loc = parentOp->getLoc(); |
| return rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, converter->convertType(fltAttr.getType()), fltAttr.getValue()); |
| } |
| |
| /// ZeroAttr visitor. |
| inline mlir::Value |
| lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::cir::ZeroAttr zeroAttr, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| auto loc = parentOp->getLoc(); |
| return rewriter.create<mlir::LLVM::ZeroOp>( |
| loc, converter->convertType(zeroAttr.getType())); |
| } |
| |
| /// ConstStruct visitor. |
| mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, |
| mlir::cir::ConstStructAttr constStruct, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| auto llvmTy = converter->convertType(constStruct.getType()); |
| auto loc = parentOp->getLoc(); |
| mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy); |
| |
| // Iteratively lower each constant element of the struct. |
| for (auto [idx, elt] : llvm::enumerate(constStruct.getMembers())) { |
| mlir::Value init = lowerCirAttrAsValue(parentOp, elt, rewriter, converter); |
| result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx); |
| } |
| |
| return result; |
| } |
| |
| /// VTableAttr visitor. |
| mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, |
| mlir::cir::VTableAttr vtableArr, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| auto llvmTy = converter->convertType(vtableArr.getType()); |
| auto loc = parentOp->getLoc(); |
| mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy); |
| |
| for (auto [idx, elt] : llvm::enumerate(vtableArr.getVtableData())) { |
| mlir::Value init = lowerCirAttrAsValue(parentOp, elt, rewriter, converter); |
| result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx); |
| } |
| |
| return result; |
| } |
| |
| /// TypeInfoAttr visitor. |
| mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, |
| mlir::cir::TypeInfoAttr typeinfoArr, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| auto llvmTy = converter->convertType(typeinfoArr.getType()); |
| auto loc = parentOp->getLoc(); |
| mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy); |
| |
| for (auto [idx, elt] : llvm::enumerate(typeinfoArr.getData())) { |
| mlir::Value init = lowerCirAttrAsValue(parentOp, elt, rewriter, converter); |
| result = rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx); |
| } |
| |
| return result; |
| } |
| |
| /// ConstArrayAttr visitor. |
| mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, |
| mlir::cir::ConstArrayAttr constArr, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| auto llvmTy = converter->convertType(constArr.getType()); |
| auto loc = parentOp->getLoc(); |
| mlir::Value result; |
| |
|   // If the array has trailing zeros, start from a zero-initialized value so |
|   // only the explicitly-initialized leading elements need to be inserted. |
|   if (constArr.getTrailingZerosNum()) |
|     result = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmTy); |
|   else |
|     result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy); |
| |
| // Iteratively lower each constant element of the array. |
| if (auto arrayAttr = mlir::dyn_cast<mlir::ArrayAttr>(constArr.getElts())) { |
| for (auto [idx, elt] : llvm::enumerate(arrayAttr)) { |
| mlir::Value init = |
| lowerCirAttrAsValue(parentOp, elt, rewriter, converter); |
| result = |
| rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx); |
| } |
| } |
| // TODO(cir): this diverges from traditional lowering. Normally the string |
| // would be a global constant that is memcopied. |
| else if (auto strAttr = |
| mlir::dyn_cast<mlir::StringAttr>(constArr.getElts())) { |
| auto arrayTy = mlir::dyn_cast<mlir::cir::ArrayType>(strAttr.getType()); |
| assert(arrayTy && "String attribute must have an array type"); |
| auto eltTy = arrayTy.getEltType(); |
| for (auto [idx, elt] : llvm::enumerate(strAttr)) { |
| auto init = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, converter->convertType(eltTy), elt); |
| result = |
| rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx); |
| } |
| } else { |
| llvm_unreachable("unexpected ConstArrayAttr elements"); |
| } |
| |
| return result; |
| } |
| |
| /// ConstVectorAttr visitor. |
| mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, |
| mlir::cir::ConstVectorAttr constVec, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| auto llvmTy = converter->convertType(constVec.getType()); |
| auto loc = parentOp->getLoc(); |
| SmallVector<mlir::Attribute> mlirValues; |
| for (auto elementAttr : constVec.getElts()) { |
| mlir::Attribute mlirAttr; |
| if (auto intAttr = mlir::dyn_cast<mlir::cir::IntAttr>(elementAttr)) { |
| mlirAttr = rewriter.getIntegerAttr( |
| converter->convertType(intAttr.getType()), intAttr.getValue()); |
| } else if (auto floatAttr = |
| mlir::dyn_cast<mlir::cir::FPAttr>(elementAttr)) { |
| mlirAttr = rewriter.getFloatAttr( |
| converter->convertType(floatAttr.getType()), floatAttr.getValue()); |
| } else { |
| llvm_unreachable( |
| "vector constant with an element that is neither an int nor a float"); |
| } |
| mlirValues.push_back(mlirAttr); |
| } |
| return rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmTy, |
| mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy), |
| mlirValues)); |
| } |
| |
| /// GlobalViewAttr visitor. |
| mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, |
| mlir::cir::GlobalViewAttr globalAttr, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| auto module = parentOp->getParentOfType<mlir::ModuleOp>(); |
| mlir::Type sourceType; |
| llvm::StringRef symName; |
| auto *sourceSymbol = |
| mlir::SymbolTable::lookupSymbolIn(module, globalAttr.getSymbol()); |
| if (auto llvmSymbol = dyn_cast<mlir::LLVM::GlobalOp>(sourceSymbol)) { |
| sourceType = llvmSymbol.getType(); |
| symName = llvmSymbol.getSymName(); |
| } else if (auto cirSymbol = dyn_cast<mlir::cir::GlobalOp>(sourceSymbol)) { |
| sourceType = converter->convertType(cirSymbol.getSymType()); |
| symName = cirSymbol.getSymName(); |
| } else if (auto llvmFun = dyn_cast<mlir::LLVM::LLVMFuncOp>(sourceSymbol)) { |
| sourceType = llvmFun.getFunctionType(); |
| symName = llvmFun.getSymName(); |
| } else if (auto fun = dyn_cast<mlir::cir::FuncOp>(sourceSymbol)) { |
| sourceType = converter->convertType(fun.getFunctionType()); |
| symName = fun.getSymName(); |
| } else { |
| llvm_unreachable("Unexpected GlobalOp type"); |
| } |
| |
| auto loc = parentOp->getLoc(); |
| mlir::Value addrOp = rewriter.create<mlir::LLVM::AddressOfOp>( |
| loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), symName); |
| |
| if (globalAttr.getIndices()) { |
| llvm::SmallVector<mlir::LLVM::GEPArg> indices; |
| for (auto idx : globalAttr.getIndices()) { |
| auto intAttr = dyn_cast<mlir::IntegerAttr>(idx); |
|       assert(intAttr && "index must be an integer"); |
| indices.push_back(intAttr.getValue().getSExtValue()); |
| } |
| auto resTy = addrOp.getType(); |
| auto eltTy = converter->convertType(sourceType); |
| addrOp = rewriter.create<mlir::LLVM::GEPOp>(loc, resTy, eltTy, addrOp, |
| indices, true); |
| } |
| |
| auto ptrTy = mlir::dyn_cast<mlir::cir::PointerType>(globalAttr.getType()); |
| assert(ptrTy && "Expecting pointer type in GlobalViewAttr"); |
| auto llvmEltTy = converter->convertType(ptrTy.getPointee()); |
| |
| if (llvmEltTy == sourceType) |
| return addrOp; |
| |
| auto llvmDstTy = converter->convertType(globalAttr.getType()); |
| return rewriter.create<mlir::LLVM::BitcastOp>(parentOp->getLoc(), llvmDstTy, |
| addrOp); |
| } |
| |
| /// Switches on the type of attribute and calls the appropriate conversion. |
| inline mlir::Value |
| lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter) { |
| if (const auto intAttr = mlir::dyn_cast<mlir::cir::IntAttr>(attr)) |
| return lowerCirAttrAsValue(parentOp, intAttr, rewriter, converter); |
| if (const auto fltAttr = mlir::dyn_cast<mlir::cir::FPAttr>(attr)) |
| return lowerCirAttrAsValue(parentOp, fltAttr, rewriter, converter); |
| if (const auto ptrAttr = mlir::dyn_cast<mlir::cir::ConstPtrAttr>(attr)) |
| return lowerCirAttrAsValue(parentOp, ptrAttr, rewriter, converter); |
| if (const auto constStruct = mlir::dyn_cast<mlir::cir::ConstStructAttr>(attr)) |
| return lowerCirAttrAsValue(parentOp, constStruct, rewriter, converter); |
| if (const auto constArr = mlir::dyn_cast<mlir::cir::ConstArrayAttr>(attr)) |
| return lowerCirAttrAsValue(parentOp, constArr, rewriter, converter); |
| if (const auto constVec = mlir::dyn_cast<mlir::cir::ConstVectorAttr>(attr)) |
| return lowerCirAttrAsValue(parentOp, constVec, rewriter, converter); |
| if (const auto boolAttr = mlir::dyn_cast<mlir::cir::BoolAttr>(attr)) |
| return lowerCirAttrAsValue(parentOp, boolAttr, rewriter, converter); |
| if (const auto zeroAttr = mlir::dyn_cast<mlir::cir::ZeroAttr>(attr)) |
| return lowerCirAttrAsValue(parentOp, zeroAttr, rewriter, converter); |
| if (const auto globalAttr = mlir::dyn_cast<mlir::cir::GlobalViewAttr>(attr)) |
| return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter); |
| if (const auto vtableAttr = mlir::dyn_cast<mlir::cir::VTableAttr>(attr)) |
| return lowerCirAttrAsValue(parentOp, vtableAttr, rewriter, converter); |
| if (const auto typeinfoAttr = mlir::dyn_cast<mlir::cir::TypeInfoAttr>(attr)) |
| return lowerCirAttrAsValue(parentOp, typeinfoAttr, rewriter, converter); |
| |
| llvm_unreachable("unhandled attribute type"); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| |
| mlir::LLVM::Linkage convertLinkage(mlir::cir::GlobalLinkageKind linkage) { |
| using CIR = mlir::cir::GlobalLinkageKind; |
| using LLVM = mlir::LLVM::Linkage; |
| |
| switch (linkage) { |
| case CIR::AvailableExternallyLinkage: |
| return LLVM::AvailableExternally; |
| case CIR::CommonLinkage: |
| return LLVM::Common; |
| case CIR::ExternalLinkage: |
| return LLVM::External; |
| case CIR::ExternalWeakLinkage: |
| return LLVM::ExternWeak; |
| case CIR::InternalLinkage: |
| return LLVM::Internal; |
| case CIR::LinkOnceAnyLinkage: |
| return LLVM::Linkonce; |
| case CIR::LinkOnceODRLinkage: |
| return LLVM::LinkonceODR; |
| case CIR::PrivateLinkage: |
| return LLVM::Private; |
| case CIR::WeakAnyLinkage: |
| return LLVM::Weak; |
| case CIR::WeakODRLinkage: |
| return LLVM::WeakODR; |
|   } |
|   llvm_unreachable("Unknown GlobalLinkageKind"); |
| } |
| |
| mlir::LLVM::CConv convertCallingConv(mlir::cir::CallingConv callinvConv) { |
| using CIR = mlir::cir::CallingConv; |
| using LLVM = mlir::LLVM::CConv; |
| |
| switch (callinvConv) { |
| case CIR::C: |
| return LLVM::C; |
| case CIR::SpirKernel: |
| return LLVM::SPIR_KERNEL; |
| case CIR::SpirFunction: |
| return LLVM::SPIR_FUNC; |
| } |
| llvm_unreachable("Unknown calling convention"); |
| } |
| |
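| /// Lowers cir::CopyOp to an LLVM memcpy intrinsic with a constant length. |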
| class CIRCopyOpLowering : public mlir::OpConversionPattern<mlir::cir::CopyOp> { |
| public: |
| using mlir::OpConversionPattern<mlir::cir::CopyOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::CopyOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| const mlir::Value length = rewriter.create<mlir::LLVM::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), op.getLength()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::MemcpyOp>( |
| op, adaptor.getDst(), adaptor.getSrc(), length, op.getIsVolatile()); |
| return mlir::success(); |
| } |
| }; |
| |
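| /// Lowers cir::MemCpyOp to the LLVM memcpy intrinsic. |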
| class CIRMemCpyOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::MemCpyOp> { |
| public: |
| using mlir::OpConversionPattern<mlir::cir::MemCpyOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::MemCpyOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::MemcpyOp>( |
| op, adaptor.getDst(), adaptor.getSrc(), adaptor.getLen(), |
| /*isVolatile=*/false); |
| return mlir::success(); |
| } |
| }; |
| |
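| /// Emits the integer cast needed to convert `llvmSrc` to `llvmDstIntTy`: a |
| /// zext or sext (depending on signedness) when widening, a trunc when |
| /// narrowing, or the value unchanged when the widths already match. |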
| static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter, |
| mlir::Value llvmSrc, |
| mlir::IntegerType llvmDstIntTy, |
| bool isUnsigned, uint64_t cirDstIntWidth) { |
| auto cirSrcWidth = |
| mlir::cast<mlir::IntegerType>(llvmSrc.getType()).getWidth(); |
| if (cirSrcWidth == cirDstIntWidth) |
| return llvmSrc; |
| |
| auto loc = llvmSrc.getLoc(); |
| if (cirSrcWidth < cirDstIntWidth) { |
| if (isUnsigned) |
| return rewriter.create<mlir::LLVM::ZExtOp>(loc, llvmDstIntTy, llvmSrc); |
| return rewriter.create<mlir::LLVM::SExtOp>(loc, llvmDstIntTy, llvmSrc); |
| } |
| |
| // Otherwise truncate |
| return rewriter.create<mlir::LLVM::TruncOp>(loc, llvmDstIntTy, llvmSrc); |
| } |
| |
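| /// Lowers cir::PtrStrideOp to llvm.getelementptr, casting the stride to the |
| /// target's pointer index width when it differs. |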
| class CIRPtrStrideOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::PtrStrideOp> { |
| public: |
| using mlir::OpConversionPattern<mlir::cir::PtrStrideOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::PtrStrideOp ptrStrideOp, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto *tc = getTypeConverter(); |
| const auto resultTy = tc->convertType(ptrStrideOp.getType()); |
| auto elementTy = tc->convertType(ptrStrideOp.getElementTy()); |
| auto *ctx = elementTy.getContext(); |
| |
|     // void and function types don't really have a layout to use in GEPs, so |
|     // use i8 instead. |
| if (mlir::isa<mlir::LLVM::LLVMVoidType>(elementTy) || |
| mlir::isa<mlir::LLVM::LLVMFunctionType>(elementTy)) |
| elementTy = mlir::IntegerType::get(elementTy.getContext(), 8, |
| mlir::IntegerType::Signless); |
| |
|     // Zero-extend, sign-extend or truncate the stride to the pointer's |
|     // index width. |
| auto index = adaptor.getStride(); |
| auto width = mlir::cast<mlir::IntegerType>(index.getType()).getWidth(); |
| mlir::DataLayout LLVMLayout(ptrStrideOp->getParentOfType<mlir::ModuleOp>()); |
| auto layoutWidth = |
| LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType()); |
| auto indexOp = index.getDefiningOp(); |
| if (indexOp && layoutWidth && width != *layoutWidth) { |
| // If the index comes from a subtraction, make sure the extension happens |
| // before it. To achieve that, look at unary minus, which already got |
| // lowered to "sub 0, x". |
| auto sub = dyn_cast<mlir::LLVM::SubOp>(indexOp); |
| auto unary = |
| dyn_cast<mlir::cir::UnaryOp>(ptrStrideOp.getStride().getDefiningOp()); |
| bool rewriteSub = |
| unary && unary.getKind() == mlir::cir::UnaryOpKind::Minus && sub; |
| if (rewriteSub) |
| index = indexOp->getOperand(1); |
| |
| // Handle the cast |
| auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth); |
| index = getLLVMIntCast(rewriter, index, llvmDstType, |
| ptrStrideOp.getStride().getType().isUnsigned(), |
| *layoutWidth); |
| |
| // Rewrite the sub in front of extensions/trunc |
| if (rewriteSub) { |
| index = rewriter.create<mlir::LLVM::SubOp>( |
| index.getLoc(), index.getType(), |
| rewriter.create<mlir::LLVM::ConstantOp>( |
| index.getLoc(), index.getType(), |
| mlir::IntegerAttr::get(index.getType(), 0)), |
| index); |
| sub->erase(); |
| } |
| } |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( |
| ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index); |
| return mlir::success(); |
| } |
| }; |
| |
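| /// Lowers cir::BrCondOp to llvm.cond_br, reusing an existing i1 value when |
| /// the condition is simply a zero-extension of one. |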
| class CIRBrCondOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::BrCondOp> { |
| public: |
| using mlir::OpConversionPattern<mlir::cir::BrCondOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::BrCondOp brOp, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| mlir::Value i1Condition; |
| |
| auto hasOneUse = false; |
| |
| if (auto defOp = brOp.getCond().getDefiningOp()) |
| hasOneUse = defOp->getResult(0).hasOneUse(); |
| |
| if (auto defOp = adaptor.getCond().getDefiningOp()) { |
| if (auto zext = dyn_cast<mlir::LLVM::ZExtOp>(defOp)) { |
| if (zext->use_empty() && |
| zext->getOperand(0).getType() == rewriter.getI1Type()) { |
| i1Condition = zext->getOperand(0); |
| if (hasOneUse) |
| rewriter.eraseOp(zext); |
| } |
| } |
| } |
| |
| if (!i1Condition) |
| i1Condition = rewriter.create<mlir::LLVM::TruncOp>( |
| brOp.getLoc(), rewriter.getI1Type(), adaptor.getCond()); |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::CondBrOp>( |
| brOp, i1Condition, brOp.getDestTrue(), adaptor.getDestOperandsTrue(), |
| brOp.getDestFalse(), adaptor.getDestOperandsFalse()); |
| |
| return mlir::success(); |
| } |
| }; |
| |
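| /// Lowers cir::CastOp by dispatching on the cast kind to the corresponding |
| /// LLVM conversion operation (gep, int/fp extensions and truncations, |
| /// bitcast, address space cast, etc.). |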
| class CIRCastOpLowering : public mlir::OpConversionPattern<mlir::cir::CastOp> { |
| public: |
| using mlir::OpConversionPattern<mlir::cir::CastOp>::OpConversionPattern; |
| |
| inline mlir::Type convertTy(mlir::Type ty) const { |
| return getTypeConverter()->convertType(ty); |
| } |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::CastOp castOp, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // For arithmetic conversions, LLVM IR uses the same instruction to convert |
| // both individual scalars and entire vectors. This lowering pass handles |
| // both situations. |
| |
| auto src = adaptor.getSrc(); |
| |
| switch (castOp.getKind()) { |
| case mlir::cir::CastKind::array_to_ptrdecay: { |
| const auto ptrTy = mlir::cast<mlir::cir::PointerType>(castOp.getType()); |
| auto sourceValue = adaptor.getOperands().front(); |
| auto targetType = convertTy(ptrTy); |
| auto elementTy = convertTy(ptrTy.getPointee()); |
| auto offset = llvm::SmallVector<mlir::LLVM::GEPArg>{0}; |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( |
| castOp, targetType, elementTy, sourceValue, offset); |
| break; |
| } |
| case mlir::cir::CastKind::int_to_bool: { |
| auto zero = rewriter.create<mlir::cir::ConstantOp>( |
| src.getLoc(), castOp.getSrc().getType(), |
| mlir::cir::IntAttr::get(castOp.getSrc().getType(), 0)); |
| rewriter.replaceOpWithNewOp<mlir::cir::CmpOp>( |
| castOp, mlir::cir::BoolType::get(getContext()), |
| mlir::cir::CmpOpKind::ne, castOp.getSrc(), zero); |
| break; |
| } |
| case mlir::cir::CastKind::integral: { |
| auto srcType = castOp.getSrc().getType(); |
| auto dstType = castOp.getResult().getType(); |
| auto llvmSrcVal = adaptor.getOperands().front(); |
| auto llvmDstType = getTypeConverter()->convertType(dstType); |
| mlir::cir::IntType srcIntType = |
| mlir::cast<mlir::cir::IntType>(elementTypeIfVector(srcType)); |
| mlir::cir::IntType dstIntType = |
| mlir::cast<mlir::cir::IntType>(elementTypeIfVector(dstType)); |
| rewriter.replaceOp( |
| castOp, |
| getLLVMIntCast(rewriter, llvmSrcVal, |
| mlir::cast<mlir::IntegerType>(llvmDstType), |
| srcIntType.isUnsigned(), dstIntType.getWidth())); |
| break; |
| } |
| case mlir::cir::CastKind::floating: { |
| auto llvmSrcVal = adaptor.getOperands().front(); |
| auto llvmDstTy = |
| getTypeConverter()->convertType(castOp.getResult().getType()); |
| |
| auto srcTy = elementTypeIfVector(castOp.getSrc().getType()); |
| auto dstTy = elementTypeIfVector(castOp.getResult().getType()); |
| |
| if (!mlir::isa<mlir::cir::CIRFPTypeInterface>(dstTy) || |
| !mlir::isa<mlir::cir::CIRFPTypeInterface>(srcTy)) |
| return castOp.emitError() |
| << "NYI cast from " << srcTy << " to " << dstTy; |
| |
| auto getFloatWidth = [](mlir::Type ty) -> unsigned { |
| return mlir::cast<mlir::cir::CIRFPTypeInterface>(ty).getWidth(); |
| }; |
| |
| if (getFloatWidth(srcTy) > getFloatWidth(dstTy)) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FPTruncOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FPExtOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case mlir::cir::CastKind::int_to_ptr: { |
| auto dstTy = mlir::cast<mlir::cir::PointerType>(castOp.getType()); |
| auto llvmSrcVal = adaptor.getOperands().front(); |
| auto llvmDstTy = getTypeConverter()->convertType(dstTy); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case mlir::cir::CastKind::ptr_to_int: { |
| auto dstTy = mlir::cast<mlir::cir::IntType>(castOp.getType()); |
| auto llvmSrcVal = adaptor.getOperands().front(); |
| auto llvmDstTy = getTypeConverter()->convertType(dstTy); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case mlir::cir::CastKind::float_to_bool: { |
| auto dstTy = mlir::cast<mlir::cir::BoolType>(castOp.getType()); |
| auto llvmSrcVal = adaptor.getOperands().front(); |
| auto llvmDstTy = getTypeConverter()->convertType(dstTy); |
| auto kind = mlir::LLVM::FCmpPredicate::une; |
| |
| // Check if float is not equal to zero. |
| auto zeroFloat = rewriter.create<mlir::LLVM::ConstantOp>( |
| castOp.getLoc(), llvmSrcVal.getType(), |
| mlir::FloatAttr::get(llvmSrcVal.getType(), 0.0)); |
| |
| // Extend comparison result to either bool (C++) or int (C). |
| mlir::Value cmpResult = rewriter.create<mlir::LLVM::FCmpOp>( |
| castOp.getLoc(), kind, llvmSrcVal, zeroFloat); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(castOp, llvmDstTy, |
| cmpResult); |
| return mlir::success(); |
| } |
| case mlir::cir::CastKind::bool_to_int: { |
| auto dstTy = mlir::cast<mlir::cir::IntType>(castOp.getType()); |
| auto llvmSrcVal = adaptor.getOperands().front(); |
| auto llvmSrcTy = mlir::cast<mlir::IntegerType>(llvmSrcVal.getType()); |
| auto llvmDstTy = |
| mlir::cast<mlir::IntegerType>(getTypeConverter()->convertType(dstTy)); |
| if (llvmSrcTy.getWidth() == llvmDstTy.getWidth()) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case mlir::cir::CastKind::bool_to_float: { |
| auto dstTy = castOp.getType(); |
| auto llvmSrcVal = adaptor.getOperands().front(); |
| auto llvmDstTy = getTypeConverter()->convertType(dstTy); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case mlir::cir::CastKind::int_to_float: { |
| auto dstTy = castOp.getType(); |
| auto llvmSrcVal = adaptor.getOperands().front(); |
| auto llvmDstTy = getTypeConverter()->convertType(dstTy); |
| if (mlir::cast<mlir::cir::IntType>( |
| elementTypeIfVector(castOp.getSrc().getType())) |
| .isSigned()) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case mlir::cir::CastKind::float_to_int: { |
| auto dstTy = castOp.getType(); |
| auto llvmSrcVal = adaptor.getOperands().front(); |
| auto llvmDstTy = getTypeConverter()->convertType(dstTy); |
| if (mlir::cast<mlir::cir::IntType>( |
| elementTypeIfVector(castOp.getResult().getType())) |
| .isSigned()) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FPToUIOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case mlir::cir::CastKind::bitcast: { |
| auto dstTy = castOp.getType(); |
| auto llvmSrcVal = adaptor.getOperands().front(); |
| auto llvmDstTy = getTypeConverter()->convertType(dstTy); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy, |
| llvmSrcVal); |
| return mlir::success(); |
| } |
| case mlir::cir::CastKind::ptr_to_bool: { |
| auto zero = |
| mlir::IntegerAttr::get(mlir::IntegerType::get(getContext(), 64), 0); |
| auto null = rewriter.create<mlir::cir::ConstantOp>( |
| src.getLoc(), castOp.getSrc().getType(), |
| mlir::cir::ConstPtrAttr::get(getContext(), castOp.getSrc().getType(), |
| zero)); |
| rewriter.replaceOpWithNewOp<mlir::cir::CmpOp>( |
| castOp, mlir::cir::BoolType::get(getContext()), |
| mlir::cir::CmpOpKind::ne, castOp.getSrc(), null); |
| break; |
| } |
| case mlir::cir::CastKind::address_space: { |
| auto dstTy = castOp.getType(); |
| auto llvmSrcVal = adaptor.getOperands().front(); |
| auto llvmDstTy = getTypeConverter()->convertType(dstTy); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>( |
| castOp, llvmDstTy, llvmSrcVal); |
| break; |
| } |
| default: { |
| return castOp.emitError("Unhandled cast kind: ") |
| << castOp.getKindAttrName(); |
| } |
| } |
| |
| return mlir::success(); |
| } |
| }; |
| |
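| /// Lowers cir::ReturnOp to func::ReturnOp. |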
| class CIRReturnLowering |
| : public mlir::OpConversionPattern<mlir::cir::ReturnOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ReturnOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ReturnOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, |
| adaptor.getOperands()); |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
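| /// Pass that drives the conversion of (flattened) CIR to the LLVM dialect. |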
| struct ConvertCIRToLLVMPass |
| : public mlir::PassWrapper<ConvertCIRToLLVMPass, |
| mlir::OperationPass<mlir::ModuleOp>> { |
| void getDependentDialects(mlir::DialectRegistry ®istry) const override { |
| registry.insert<mlir::BuiltinDialect, mlir::DLTIDialect, |
| mlir::LLVM::LLVMDialect, mlir::func::FuncDialect>(); |
| } |
| void runOnOperation() final; |
| |
| virtual StringRef getArgument() const override { return "cir-flat-to-llvm"; } |
| }; |
| |
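| /// Rewrites a direct or indirect CIR call. When a landing pad block is given |
| /// (i.e. for cir::TryCallOp), an llvm.invoke is emitted; otherwise a plain |
| /// llvm.call is used. |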
| mlir::LogicalResult |
| rewriteToCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands, |
| mlir::ConversionPatternRewriter &rewriter, |
| const mlir::TypeConverter *converter, |
| mlir::FlatSymbolRefAttr calleeAttr, |
| mlir::Block *continueBlock = nullptr, |
| mlir::Block *landingPadBlock = nullptr) { |
| llvm::SmallVector<mlir::Type, 8> llvmResults; |
| auto cirResults = op->getResultTypes(); |
| |
| if (converter->convertTypes(cirResults, llvmResults).failed()) |
| return mlir::failure(); |
| |
| if (calleeAttr) { // direct call |
| if (landingPadBlock) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>( |
| op, llvmResults, calleeAttr, callOperands, continueBlock, |
| mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{}); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(op, llvmResults, |
| calleeAttr, callOperands); |
| } else { // indirect call |
|     assert(op->getOperands().size() && |
|            "operand list must not be empty for an indirect call"); |
| auto typ = op->getOperands().front().getType(); |
| assert(isa<mlir::cir::PointerType>(typ) && "expected pointer type"); |
| auto ptyp = dyn_cast<mlir::cir::PointerType>(typ); |
| auto ftyp = dyn_cast<mlir::cir::FuncType>(ptyp.getPointee()); |
| assert(ftyp && "expected a pointer to a function as the first operand"); |
| |
| if (landingPadBlock) { |
| auto llvmFnTy = |
| dyn_cast<mlir::LLVM::LLVMFunctionType>(converter->convertType(ftyp)); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>( |
| op, llvmFnTy, mlir::FlatSymbolRefAttr{}, callOperands, continueBlock, |
| mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{}); |
| } else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>( |
| op, |
| dyn_cast<mlir::LLVM::LLVMFunctionType>(converter->convertType(ftyp)), |
| callOperands); |
| } |
| return mlir::success(); |
| } |
| |
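| /// Lowers cir::CallOp by delegating to rewriteToCallOrInvoke without a |
| /// landing pad, producing an llvm.call. |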
| class CIRCallLowering : public mlir::OpConversionPattern<mlir::cir::CallOp> { |
| public: |
| using OpConversionPattern<mlir::cir::CallOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::CallOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| return rewriteToCallOrInvoke(op.getOperation(), adaptor.getOperands(), |
| rewriter, getTypeConverter(), |
| op.getCalleeAttr()); |
| } |
| }; |
| |
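| /// Lowers cir::TryCallOp by delegating to rewriteToCallOrInvoke with its |
| /// continuation and landing pad blocks, producing an llvm.invoke. |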
| class CIRTryCallLowering |
| : public mlir::OpConversionPattern<mlir::cir::TryCallOp> { |
| public: |
| using OpConversionPattern<mlir::cir::TryCallOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::TryCallOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| return rewriteToCallOrInvoke( |
| op.getOperation(), adaptor.getOperands(), rewriter, getTypeConverter(), |
| op.getCalleeAttr(), op.getCont(), op.getLandingPad()); |
| } |
| }; |
| |
| static mlir::LLVM::LLVMStructType |
| getLLVMLandingPadStructTy(mlir::ConversionPatternRewriter &rewriter) { |
| // Create the landing pad type: struct { ptr, i32 } |
| mlir::MLIRContext *ctx = rewriter.getContext(); |
| auto llvmPtr = mlir::LLVM::LLVMPointerType::get(ctx); |
| llvm::SmallVector<mlir::Type> structFields; |
| structFields.push_back(llvmPtr); |
| structFields.push_back(rewriter.getI32Type()); |
| |
| return mlir::LLVM::LLVMStructType::getLiteral(ctx, structFields); |
| } |
| |
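| /// Lowers cir::EhInflightOp to an llvm.landingpad plus extraction of the |
| /// exception slot and type selector values. |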
| class CIREhInflightOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::EhInflightOp> { |
| public: |
| using OpConversionPattern<mlir::cir::EhInflightOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::EhInflightOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| mlir::Location loc = op.getLoc(); |
| auto llvmLandingPadStructTy = getLLVMLandingPadStructTy(rewriter); |
| mlir::ArrayAttr symListAttr = op.getSymTypeListAttr(); |
| mlir::SmallVector<mlir::Value, 4> symAddrs; |
| |
| auto llvmFn = op->getParentOfType<mlir::LLVM::LLVMFuncOp>(); |
| assert(llvmFn && "expected LLVM function parent"); |
| mlir::Block *entryBlock = &llvmFn.getRegion().front(); |
| assert(entryBlock->isEntryBlock()); |
| |
|     // %x = landingpad { ptr, i32 } |
|     // Note that since llvm.landingpad has to be the first operation in the |
|     // block, any values needed for its operands have to be defined elsewhere. |
| if (symListAttr) { |
| // catch ptr @_ZTIi |
| // catch ptr @_ZTIPKc |
| for (mlir::Attribute attr : op.getSymTypeListAttr()) { |
| auto symAttr = cast<mlir::FlatSymbolRefAttr>(attr); |
| // Generate `llvm.mlir.addressof` for each symbol, and place those |
| // operations in the LLVM function entry basic block. |
| mlir::OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPointToStart(entryBlock); |
| mlir::Value addrOp = rewriter.create<mlir::LLVM::AddressOfOp>( |
| loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), |
| symAttr.getValue()); |
| symAddrs.push_back(addrOp); |
| } |
| } else { |
| if (!op.getCleanup()) { |
| // catch ptr null |
| mlir::OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPointToStart(entryBlock); |
| mlir::Value nullOp = rewriter.create<mlir::LLVM::ZeroOp>( |
| loc, mlir::LLVM::LLVMPointerType::get(rewriter.getContext())); |
| symAddrs.push_back(nullOp); |
| } |
| } |
| |
| // %slot = extractvalue { ptr, i32 } %x, 0 |
| // %selector = extractvalue { ptr, i32 } %x, 1 |
| auto padOp = rewriter.create<mlir::LLVM::LandingpadOp>( |
| loc, llvmLandingPadStructTy, symAddrs); |
| SmallVector<int64_t> slotIdx = {0}; |
| SmallVector<int64_t> selectorIdx = {1}; |
| |
| if (op.getCleanup()) |
| padOp.setCleanup(true); |
| |
| mlir::Value slot = |
| rewriter.create<mlir::LLVM::ExtractValueOp>(loc, padOp, slotIdx); |
| mlir::Value selector = |
| rewriter.create<mlir::LLVM::ExtractValueOp>(loc, padOp, selectorIdx); |
| |
| rewriter.replaceOp(op, mlir::ValueRange{slot, selector}); |
| |
|     // Landing pads are required to be in LLVM functions with a personality |
|     // attribute. FIXME: for now, hardcode personality creation in order to |
|     // start adding exception tests; once CIR carries this information, move |
|     // it to FuncOp lowering instead. |
| { |
| mlir::OpBuilder::InsertionGuard guard(rewriter); |
| // Insert personality decl before the current function. |
| rewriter.setInsertionPoint(llvmFn); |
| auto personalityFnTy = |
| mlir::LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {}, |
| /*isVarArg=*/true); |
| auto personalityFn = rewriter.create<mlir::LLVM::LLVMFuncOp>( |
| loc, "__gxx_personality_v0", personalityFnTy); |
| llvmFn.setPersonality(personalityFn.getName()); |
| } |
| return mlir::success(); |
| } |
| }; |
| |
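| /// Lowers cir::AllocaOp to llvm.alloca, verifying that the result address |
| /// space matches the alloca address space from the data layout. |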
| class CIRAllocaLowering |
| : public mlir::OpConversionPattern<mlir::cir::AllocaOp> { |
| mlir::DataLayout const &dataLayout; |
| |
| public: |
| CIRAllocaLowering(mlir::TypeConverter const &typeConverter, |
| mlir::DataLayout const &dataLayout, |
| mlir::MLIRContext *context) |
| : OpConversionPattern<mlir::cir::AllocaOp>(typeConverter, context), |
| dataLayout(dataLayout) {} |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::AllocaOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| mlir::Value size = |
| op.isDynamic() |
| ? adaptor.getDynAllocSize() |
| : rewriter.create<mlir::LLVM::ConstantOp>( |
| op.getLoc(), |
| typeConverter->convertType(rewriter.getIndexType()), |
| rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); |
| auto elementTy = getTypeConverter()->convertType(op.getAllocaType()); |
| auto resultTy = getTypeConverter()->convertType(op.getResult().getType()); |
|     // Verify that the CIR alloca address space matches the data layout's. |
| { |
| auto resPtrTy = mlir::cast<mlir::LLVM::LLVMPointerType>(resultTy); |
| auto dlAllocaASAttr = mlir::cast_if_present<mlir::IntegerAttr>( |
| dataLayout.getAllocaMemorySpace()); |
|       // Absence of the attribute means address space 0. |
|       // TODO: The query for the alloca AS should be done through |
|       // CIRDataLayout instead, to reuse the logic of interpreting a null |
|       // attribute as 0. |
| auto dlAllocaAS = dlAllocaASAttr ? dlAllocaASAttr.getInt() : 0; |
| if (dlAllocaAS != resPtrTy.getAddressSpace()) { |
| return op.emitError() << "alloca address space doesn't match the one " |
| "from the target data layout: " |
| << dlAllocaAS; |
| } |
| } |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AllocaOp>( |
| op, resultTy, elementTy, size, op.getAlignmentAttr().getInt()); |
| return mlir::success(); |
| } |
| }; |
| |
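| /// Maps a CIR memory order to the corresponding LLVM atomic ordering; an |
| /// absent order lowers to not_atomic. |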
| static mlir::LLVM::AtomicOrdering |
| getLLVMMemOrder(std::optional<mlir::cir::MemOrder> &memorder) { |
| if (!memorder) |
| return mlir::LLVM::AtomicOrdering::not_atomic; |
| switch (*memorder) { |
| case mlir::cir::MemOrder::Relaxed: |
| return mlir::LLVM::AtomicOrdering::monotonic; |
| case mlir::cir::MemOrder::Consume: |
| case mlir::cir::MemOrder::Acquire: |
| return mlir::LLVM::AtomicOrdering::acquire; |
| case mlir::cir::MemOrder::Release: |
| return mlir::LLVM::AtomicOrdering::release; |
| case mlir::cir::MemOrder::AcquireRelease: |
| return mlir::LLVM::AtomicOrdering::acq_rel; |
| case mlir::cir::MemOrder::SequentiallyConsistent: |
| return mlir::LLVM::AtomicOrdering::seq_cst; |
| } |
| llvm_unreachable("unknown memory order"); |
| } |
| |
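| /// Lowers cir::LoadOp to llvm.load, deriving the alignment from the data |
| /// layout when the operation does not specify one. |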
| class CIRLoadLowering : public mlir::OpConversionPattern<mlir::cir::LoadOp> { |
| public: |
| using OpConversionPattern<mlir::cir::LoadOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::LoadOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| const auto llvmTy = |
| getTypeConverter()->convertType(op.getResult().getType()); |
| auto memorder = op.getMemOrder(); |
| auto ordering = getLLVMMemOrder(memorder); |
| auto alignOpt = op.getAlignment(); |
| unsigned alignment = 0; |
| if (!alignOpt) { |
| mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>()); |
| alignment = (unsigned)layout.getTypeABIAlignment(llvmTy); |
| } else { |
| alignment = *alignOpt; |
| } |
| |
| // TODO: nontemporal, invariant, syncscope. |
| rewriter.replaceOpWithNewOp<mlir::LLVM::LoadOp>( |
| op, llvmTy, adaptor.getAddr(), /* alignment */ alignment, |
| op.getIsVolatile(), /* nontemporal */ false, |
| /* invariant */ false, ordering); |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
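| /// Lowers cir::StoreOp to llvm.store, deriving the alignment from the data |
| /// layout when the operation does not specify one. |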
| class CIRStoreLowering : public mlir::OpConversionPattern<mlir::cir::StoreOp> { |
| public: |
| using OpConversionPattern<mlir::cir::StoreOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::StoreOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto memorder = op.getMemOrder(); |
| auto ordering = getLLVMMemOrder(memorder); |
| auto alignOpt = op.getAlignment(); |
| unsigned alignment = 0; |
| if (!alignOpt) { |
| const auto llvmTy = |
| getTypeConverter()->convertType(op.getValue().getType()); |
| mlir::DataLayout layout(op->getParentOfType<mlir::ModuleOp>()); |
| alignment = (unsigned)layout.getTypeABIAlignment(llvmTy); |
| } else { |
| alignment = *alignOpt; |
| } |
| |
| // TODO: nontemporal, syncscope. |
| rewriter.replaceOpWithNewOp<mlir::LLVM::StoreOp>( |
| op, adaptor.getValue(), adaptor.getAddr(), alignment, |
| op.getIsVolatile(), /* nontemporal */ false, ordering); |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
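| /// Returns true if the constant array attribute, or any constant array |
| /// nested within it, records trailing zero elements. |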
| bool hasTrailingZeros(mlir::cir::ConstArrayAttr attr) { |
|   auto array = mlir::dyn_cast<mlir::ArrayAttr>(attr.getElts()); |
|   return attr.hasTrailingZeros() || |
|          (array && llvm::any_of(array, [](mlir::Attribute elt) { |
|             auto ar = dyn_cast<mlir::cir::ConstArrayAttr>(elt); |
|             return ar && hasTrailingZeros(ar); |
|           })); |
| } |
| |
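| /// Lowers a CIR data member pointer constant to its integer representation: |
| /// the byte offset of the member within the class, or (for now) -1 for a |
| /// null member pointer. |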
| static mlir::Attribute |
| lowerDataMemberAttr(mlir::ModuleOp moduleOp, mlir::cir::DataMemberAttr attr, |
| const mlir::TypeConverter &typeConverter) { |
| mlir::DataLayout layout{moduleOp}; |
| |
| uint64_t memberOffset; |
| if (attr.isNullPtr()) { |
| // TODO(cir): the numerical value of a null data member pointer is |
| // ABI-specific and should be queried through ABI. |
| assert(!MissingFeatures::targetCodeGenInfoGetNullPointer()); |
| memberOffset = -1ull; |
| } else { |
| auto memberIndex = attr.getMemberIndex().value(); |
| memberOffset = |
| attr.getType().getClsTy().getElementOffset(layout, memberIndex); |
| } |
| |
| auto underlyingIntTy = mlir::IntegerType::get( |
| moduleOp->getContext(), layout.getTypeSizeInBits(attr.getType())); |
| return mlir::IntegerAttr::get(underlyingIntTy, memberOffset); |
| } |
| |
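| /// Lowers cir::ConstantOp, either folding the value into an LLVM constant |
| /// attribute or materializing it with the attribute visitors above. |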
| class CIRConstantLowering |
| : public mlir::OpConversionPattern<mlir::cir::ConstantOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ConstantOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ConstantOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| mlir::Attribute attr = op.getValue(); |
| |
| if (mlir::isa<mlir::cir::BoolType>(op.getType())) { |
| int value = |
| (op.getValue() == |
| mlir::cir::BoolAttr::get( |
| getContext(), ::mlir::cir::BoolType::get(getContext()), true)); |
| attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()), |
| value); |
| } else if (mlir::isa<mlir::cir::IntType>(op.getType())) { |
| attr = rewriter.getIntegerAttr( |
| typeConverter->convertType(op.getType()), |
| mlir::cast<mlir::cir::IntAttr>(op.getValue()).getValue()); |
| } else if (mlir::isa<mlir::cir::CIRFPTypeInterface>(op.getType())) { |
| attr = rewriter.getFloatAttr( |
| typeConverter->convertType(op.getType()), |
| mlir::cast<mlir::cir::FPAttr>(op.getValue()).getValue()); |
| } else if (auto complexTy = |
| mlir::dyn_cast<mlir::cir::ComplexType>(op.getType())) { |
| auto complexAttr = mlir::cast<mlir::cir::ComplexAttr>(op.getValue()); |
| auto complexElemTy = complexTy.getElementTy(); |
| auto complexElemLLVMTy = typeConverter->convertType(complexElemTy); |
| |
| mlir::Attribute components[2]; |
| if (mlir::isa<mlir::cir::IntType>(complexElemTy)) { |
| components[0] = rewriter.getIntegerAttr( |
| complexElemLLVMTy, |
| mlir::cast<mlir::cir::IntAttr>(complexAttr.getReal()).getValue()); |
| components[1] = rewriter.getIntegerAttr( |
| complexElemLLVMTy, |
| mlir::cast<mlir::cir::IntAttr>(complexAttr.getImag()).getValue()); |
| } else { |
| components[0] = rewriter.getFloatAttr( |
| complexElemLLVMTy, |
| mlir::cast<mlir::cir::FPAttr>(complexAttr.getReal()).getValue()); |
| components[1] = rewriter.getFloatAttr( |
| complexElemLLVMTy, |
| mlir::cast<mlir::cir::FPAttr>(complexAttr.getImag()).getValue()); |
| } |
| |
| attr = rewriter.getArrayAttr(components); |
| } else if (mlir::isa<mlir::cir::PointerType>(op.getType())) { |
| // Optimize with dedicated LLVM op for null pointers. |
| if (mlir::isa<mlir::cir::ConstPtrAttr>(op.getValue())) { |
| if (mlir::cast<mlir::cir::ConstPtrAttr>(op.getValue()).isNullValue()) { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ZeroOp>( |
| op, typeConverter->convertType(op.getType())); |
| return mlir::success(); |
| } |
| } |
| // Lower GlobalViewAttr to llvm.mlir.addressof |
| if (auto gv = mlir::dyn_cast<mlir::cir::GlobalViewAttr>(op.getValue())) { |
| auto newOp = lowerCirAttrAsValue(op, gv, rewriter, getTypeConverter()); |
| rewriter.replaceOp(op, newOp); |
| return mlir::success(); |
| } |
| attr = op.getValue(); |
| } else if (mlir::isa<mlir::cir::DataMemberType>(op.getType())) { |
| auto dataMember = mlir::cast<mlir::cir::DataMemberAttr>(op.getValue()); |
| attr = lowerDataMemberAttr(op->getParentOfType<mlir::ModuleOp>(), |
| dataMember, *typeConverter); |
| } |
|     // TODO(cir): constant arrays are currently just pushed onto the stack |
|     // using store instructions, instead of being emitted as global variables |
|     // and then memcopied onto the stack (as done in Clang). |
| else if (auto arrTy = mlir::dyn_cast<mlir::cir::ArrayType>(op.getType())) { |
| // Fetch operation constant array initializer. |
| |
| auto constArr = mlir::dyn_cast<mlir::cir::ConstArrayAttr>(op.getValue()); |
| if (!constArr && !isa<mlir::cir::ZeroAttr>(op.getValue())) |
| return op.emitError() << "array does not have a constant initializer"; |
| |
| std::optional<mlir::Attribute> denseAttr; |
| if (constArr && hasTrailingZeros(constArr)) { |
| auto newOp = |
| lowerCirAttrAsValue(op, constArr, rewriter, getTypeConverter()); |
| rewriter.replaceOp(op, newOp); |
| return mlir::success(); |
| } else if (constArr && |
| (denseAttr = lowerConstArrayAttr(constArr, typeConverter))) { |
| attr = denseAttr.value(); |
| } else { |
| auto initVal = |
| lowerCirAttrAsValue(op, op.getValue(), rewriter, typeConverter); |
| rewriter.replaceAllUsesWith(op, initVal); |
| rewriter.eraseOp(op); |
| return mlir::success(); |
| } |
| } else if (const auto structAttr = |
| mlir::dyn_cast<mlir::cir::ConstStructAttr>(op.getValue())) { |
| // TODO(cir): this diverges from traditional lowering. Normally the |
| // initializer would be a global constant that is memcopied. Here we just |
| // define a local constant with llvm.undef that will be stored into the |
| // stack. |
| auto initVal = |
| lowerCirAttrAsValue(op, structAttr, rewriter, typeConverter); |
| rewriter.replaceAllUsesWith(op, initVal); |
| rewriter.eraseOp(op); |
| return mlir::success(); |
| } else if (auto strTy = |
| mlir::dyn_cast<mlir::cir::StructType>(op.getType())) { |
| if (auto zero = mlir::dyn_cast<mlir::cir::ZeroAttr>(op.getValue())) { |
| auto initVal = lowerCirAttrAsValue(op, zero, rewriter, typeConverter); |
| rewriter.replaceAllUsesWith(op, initVal); |
| rewriter.eraseOp(op); |
| return mlir::success(); |
| } |
| |
| return op.emitError() << "unsupported lowering for struct constant type " |
| << op.getType(); |
| } else if (const auto vecTy = |
| mlir::dyn_cast<mlir::cir::VectorType>(op.getType())) { |
| rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter, |
| getTypeConverter())); |
| return mlir::success(); |
| } else |
| return op.emitError() << "unsupported constant type " << op.getType(); |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>( |
| op, getTypeConverter()->convertType(op.getType()), attr); |
| |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRVectorCreateLowering |
| : public mlir::OpConversionPattern<mlir::cir::VecCreateOp> { |
| public: |
| using OpConversionPattern<mlir::cir::VecCreateOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::VecCreateOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // Start with an 'undef' value for the vector. Then 'insertelement' for |
| // each of the vector elements. |
| auto vecTy = mlir::dyn_cast<mlir::cir::VectorType>(op.getType()); |
| assert(vecTy && "result type of cir.vec.create op is not VectorType"); |
| auto llvmTy = typeConverter->convertType(vecTy); |
| auto loc = op.getLoc(); |
| mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy); |
| assert(vecTy.getSize() == op.getElements().size() && |
| "cir.vec.create op count doesn't match vector type elements count"); |
| for (uint64_t i = 0; i < vecTy.getSize(); ++i) { |
| mlir::Value indexValue = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, rewriter.getI64Type(), i); |
| result = rewriter.create<mlir::LLVM::InsertElementOp>( |
| loc, result, adaptor.getElements()[i], indexValue); |
| } |
| rewriter.replaceOp(op, result); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRVectorCmpOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::VecCmpOp> { |
| public: |
| using OpConversionPattern<mlir::cir::VecCmpOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::VecCmpOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| assert(mlir::isa<mlir::cir::VectorType>(op.getType()) && |
| mlir::isa<mlir::cir::VectorType>(op.getLhs().getType()) && |
| mlir::isa<mlir::cir::VectorType>(op.getRhs().getType()) && |
| "Vector compare with non-vector type"); |
| // LLVM IR vector comparison returns a vector of i1. This one-bit vector |
| // must be sign-extended to the correct result type. |
| auto elementType = elementTypeIfVector(op.getLhs().getType()); |
| mlir::Value bitResult; |
| if (auto intType = mlir::dyn_cast<mlir::cir::IntType>(elementType)) { |
| bitResult = rewriter.create<mlir::LLVM::ICmpOp>( |
| op.getLoc(), |
| convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()), |
| adaptor.getLhs(), adaptor.getRhs()); |
| } else if (mlir::isa<mlir::cir::CIRFPTypeInterface>(elementType)) { |
| bitResult = rewriter.create<mlir::LLVM::FCmpOp>( |
| op.getLoc(), convertCmpKindToFCmpPredicate(op.getKind()), |
| adaptor.getLhs(), adaptor.getRhs()); |
| } else { |
| return op.emitError() << "unsupported type for VecCmpOp: " << elementType; |
| } |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>( |
| op, typeConverter->convertType(op.getType()), bitResult); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRVectorSplatLowering |
| : public mlir::OpConversionPattern<mlir::cir::VecSplatOp> { |
| public: |
| using OpConversionPattern<mlir::cir::VecSplatOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::VecSplatOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // Vector splat can be implemented with an `insertelement` and a |
| // `shufflevector`, which is better than an `insertelement` for each |
| // element in the vector. Start with an undef vector. Insert the value into |
| // the first element. Then use a `shufflevector` with a mask of all 0 to |
| // fill out the entire vector with that value. |
| auto vecTy = mlir::dyn_cast<mlir::cir::VectorType>(op.getType()); |
| assert(vecTy && "result type of cir.vec.splat op is not VectorType"); |
| auto llvmTy = typeConverter->convertType(vecTy); |
| auto loc = op.getLoc(); |
| mlir::Value undef = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy); |
| mlir::Value indexValue = |
| rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0); |
| mlir::Value elementValue = adaptor.getValue(); |
| mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>( |
| loc, undef, elementValue, indexValue); |
| SmallVector<int32_t> zeroValues(vecTy.getSize(), 0); |
| mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>( |
| loc, oneElement, undef, zeroValues); |
| rewriter.replaceOp(op, shuffled); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRVectorTernaryLowering |
| : public mlir::OpConversionPattern<mlir::cir::VecTernaryOp> { |
| public: |
| using OpConversionPattern<mlir::cir::VecTernaryOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::VecTernaryOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| assert(mlir::isa<mlir::cir::VectorType>(op.getType()) && |
| mlir::isa<mlir::cir::VectorType>(op.getCond().getType()) && |
| mlir::isa<mlir::cir::VectorType>(op.getVec1().getType()) && |
| mlir::isa<mlir::cir::VectorType>(op.getVec2().getType()) && |
| "Vector ternary op with non-vector type"); |
| // Convert `cond` into a vector of i1, then use that in a `select` op. |
| mlir::Value bitVec = rewriter.create<mlir::LLVM::ICmpOp>( |
| op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(), |
| rewriter.create<mlir::LLVM::ZeroOp>( |
| op.getCond().getLoc(), |
| typeConverter->convertType(op.getCond().getType()))); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>( |
| op, bitVec, adaptor.getVec1(), adaptor.getVec2()); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRVectorShuffleIntsLowering |
| : public mlir::OpConversionPattern<mlir::cir::VecShuffleOp> { |
| public: |
| using OpConversionPattern<mlir::cir::VecShuffleOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::VecShuffleOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // LLVM::ShuffleVectorOp takes an ArrayRef of int for the list of indices. |
| // Convert the ClangIR ArrayAttr of IntAttr constants into a |
| // SmallVector<int>. |
| SmallVector<int, 8> indices; |
| std::transform(op.getIndices().begin(), op.getIndices().end(), |
| std::back_inserter(indices), [](mlir::Attribute intAttr) { |
| return mlir::cast<mlir::cir::IntAttr>(intAttr) |
| .getValue() |
| .getSExtValue(); |
| }); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>( |
| op, adaptor.getVec1(), adaptor.getVec2(), indices); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRVectorShuffleVecLowering |
| : public mlir::OpConversionPattern<mlir::cir::VecShuffleDynamicOp> { |
| public: |
| using OpConversionPattern< |
| mlir::cir::VecShuffleDynamicOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::VecShuffleDynamicOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // LLVM IR does not have an operation that corresponds to this form of |
| // the built-in. |
| // __builtin_shufflevector(V, I) |
| // is implemented as this pseudocode, where the for loop is unrolled |
| // and N is the number of elements: |
| // masked = I & (N-1) |
| // for (i in 0 <= i < N) |
| // result[i] = V[masked[i]] |
| auto loc = op.getLoc(); |
| mlir::Value input = adaptor.getVec(); |
| mlir::Type llvmIndexVecType = |
| getTypeConverter()->convertType(op.getIndices().getType()); |
| mlir::Type llvmIndexType = getTypeConverter()->convertType( |
| elementTypeIfVector(op.getIndices().getType())); |
| uint64_t numElements = |
| mlir::cast<mlir::cir::VectorType>(op.getVec().getType()).getSize(); |
| mlir::Value maskValue = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmIndexType, |
| mlir::IntegerAttr::get(llvmIndexType, numElements - 1)); |
| mlir::Value maskVector = |
| rewriter.create<mlir::LLVM::UndefOp>(loc, llvmIndexVecType); |
| for (uint64_t i = 0; i < numElements; ++i) { |
| mlir::Value iValue = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, rewriter.getI64Type(), i); |
| maskVector = rewriter.create<mlir::LLVM::InsertElementOp>( |
| loc, maskVector, maskValue, iValue); |
| } |
| mlir::Value maskedIndices = rewriter.create<mlir::LLVM::AndOp>( |
| loc, llvmIndexVecType, adaptor.getIndices(), maskVector); |
| mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>( |
| loc, getTypeConverter()->convertType(op.getVec().getType())); |
| for (uint64_t i = 0; i < numElements; ++i) { |
| mlir::Value iValue = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, rewriter.getI64Type(), i); |
| mlir::Value indexValue = rewriter.create<mlir::LLVM::ExtractElementOp>( |
| loc, maskedIndices, iValue); |
| mlir::Value valueAtIndex = |
| rewriter.create<mlir::LLVM::ExtractElementOp>(loc, input, indexValue); |
| result = rewriter.create<mlir::LLVM::InsertElementOp>( |
| loc, result, valueAtIndex, iValue); |
| } |
| rewriter.replaceOp(op, result); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRVAStartLowering |
| : public mlir::OpConversionPattern<mlir::cir::VAStartOp> { |
| public: |
| using OpConversionPattern<mlir::cir::VAStartOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::VAStartOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto opaquePtr = mlir::LLVM::LLVMPointerType::get(getContext()); |
| auto vaList = rewriter.create<mlir::LLVM::BitcastOp>( |
| op.getLoc(), opaquePtr, adaptor.getOperands().front()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::VaStartOp>(op, vaList); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRVAEndLowering : public mlir::OpConversionPattern<mlir::cir::VAEndOp> { |
| public: |
| using OpConversionPattern<mlir::cir::VAEndOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::VAEndOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto opaquePtr = mlir::LLVM::LLVMPointerType::get(getContext()); |
| auto vaList = rewriter.create<mlir::LLVM::BitcastOp>( |
| op.getLoc(), opaquePtr, adaptor.getOperands().front()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::VaEndOp>(op, vaList); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRVACopyLowering |
| : public mlir::OpConversionPattern<mlir::cir::VACopyOp> { |
| public: |
| using OpConversionPattern<mlir::cir::VACopyOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::VACopyOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto opaquePtr = mlir::LLVM::LLVMPointerType::get(getContext()); |
| auto dstList = rewriter.create<mlir::LLVM::BitcastOp>( |
| op.getLoc(), opaquePtr, adaptor.getOperands().front()); |
| auto srcList = rewriter.create<mlir::LLVM::BitcastOp>( |
| op.getLoc(), opaquePtr, adaptor.getOperands().back()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::VaCopyOp>(op, dstList, srcList); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRVAArgLowering : public mlir::OpConversionPattern<mlir::cir::VAArgOp> { |
| public: |
| using OpConversionPattern<mlir::cir::VAArgOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::VAArgOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| return op.emitError("cir.vaarg lowering is NYI"); |
| } |
| }; |
| |
| class CIRFuncLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> { |
| public: |
| using OpConversionPattern<mlir::cir::FuncOp>::OpConversionPattern; |
| |
| /// Returns the name used for the linkage attribute. This *must* correspond |
| /// to the name of the attribute in ODS. |
| static StringRef getLinkageAttrNameString() { return "linkage"; } |
| |
| /// Convert the `cir.func` attributes to `llvm.func` attributes. |
| /// Only retain those attributes that are not constructed by |
  /// `LLVMFuncOp::build`. If `filterArgAndResAttrs` is set, also filter out
  /// argument and result attributes.
| void |
| lowerFuncAttributes(mlir::cir::FuncOp func, bool filterArgAndResAttrs, |
| SmallVectorImpl<mlir::NamedAttribute> &result) const { |
| for (auto attr : func->getAttrs()) { |
| if (attr.getName() == mlir::SymbolTable::getSymbolAttrName() || |
| attr.getName() == func.getFunctionTypeAttrName() || |
| attr.getName() == getLinkageAttrNameString() || |
| attr.getName() == func.getCallingConvAttrName() || |
| (filterArgAndResAttrs && |
| (attr.getName() == func.getArgAttrsAttrName() || |
| attr.getName() == func.getResAttrsAttrName()))) |
| continue; |
| |
      // `CIRDialectLLVMIRTranslationInterface` requires a "cir." prefix for
      // dialect-specific attributes; rename them here.
| if (attr.getName() == func.getExtraAttrsAttrName()) { |
| std::string cirName = "cir." + func.getExtraAttrsAttrName().str(); |
| attr.setName(mlir::StringAttr::get(getContext(), cirName)); |
| |
| lowerFuncOpenCLKernelMetadata(attr); |
| } |
| result.push_back(attr); |
| } |
| } |
| |
  /// During module translation we can only translate LLVM-compatible types.
  /// Here we lower a possible OpenCLKernelMetadataAttr to use the converted
  /// type.
| void |
| lowerFuncOpenCLKernelMetadata(mlir::NamedAttribute &extraAttrsEntry) const { |
| const auto attrKey = mlir::cir::OpenCLKernelMetadataAttr::getMnemonic(); |
| auto oldExtraAttrs = |
| cast<mlir::cir::ExtraFuncAttributesAttr>(extraAttrsEntry.getValue()); |
| if (!oldExtraAttrs.getElements().contains(attrKey)) |
| return; |
| |
| mlir::NamedAttrList newExtraAttrs; |
| for (auto entry : oldExtraAttrs.getElements()) { |
| if (entry.getName() == attrKey) { |
| auto clKernelMetadata = |
| cast<mlir::cir::OpenCLKernelMetadataAttr>(entry.getValue()); |
| if (auto vecTypeHint = clKernelMetadata.getVecTypeHint()) { |
| auto newType = typeConverter->convertType(vecTypeHint.getValue()); |
| auto newTypeHint = mlir::TypeAttr::get(newType); |
| auto newCLKMAttr = mlir::cir::OpenCLKernelMetadataAttr::get( |
| getContext(), clKernelMetadata.getWorkGroupSizeHint(), |
| clKernelMetadata.getReqdWorkGroupSize(), newTypeHint, |
| clKernelMetadata.getVecTypeHintSignedness(), |
| clKernelMetadata.getIntelReqdSubGroupSize()); |
| entry.setValue(newCLKMAttr); |
| } |
| } |
| newExtraAttrs.push_back(entry); |
| } |
| extraAttrsEntry.setValue(mlir::cir::ExtraFuncAttributesAttr::get( |
| getContext(), newExtraAttrs.getDictionary(getContext()))); |
| } |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::FuncOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| |
| auto fnType = op.getFunctionType(); |
| auto isDsoLocal = op.getDsolocal(); |
| mlir::TypeConverter::SignatureConversion signatureConversion( |
| fnType.getNumInputs()); |
| |
| for (const auto &argType : enumerate(fnType.getInputs())) { |
| auto convertedType = typeConverter->convertType(argType.value()); |
| if (!convertedType) |
| return mlir::failure(); |
| signatureConversion.addInputs(argType.index(), convertedType); |
| } |
| |
| mlir::Type resultType = |
| getTypeConverter()->convertType(fnType.getReturnType()); |
| |
| // Create the LLVM function operation. |
| auto llvmFnTy = mlir::LLVM::LLVMFunctionType::get( |
| resultType ? resultType : mlir::LLVM::LLVMVoidType::get(getContext()), |
| signatureConversion.getConvertedTypes(), |
| /*isVarArg=*/fnType.isVarArg()); |
    // LLVMFuncOp expects a single FileLineColLoc (or UnknownLoc) rather than
    // a fused location.
| auto Loc = op.getLoc(); |
| if (mlir::isa<mlir::FusedLoc>(Loc)) { |
| auto FusedLoc = mlir::cast<mlir::FusedLoc>(Loc); |
| Loc = FusedLoc.getLocations()[0]; |
| } |
| assert((mlir::isa<mlir::FileLineColLoc>(Loc) || |
| mlir::isa<mlir::UnknownLoc>(Loc)) && |
| "expected single location or unknown location here"); |
| |
| auto linkage = convertLinkage(op.getLinkage()); |
| auto cconv = convertCallingConv(op.getCallingConv()); |
| SmallVector<mlir::NamedAttribute, 4> attributes; |
| lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes); |
| |
| auto fn = rewriter.create<mlir::LLVM::LLVMFuncOp>( |
| Loc, op.getName(), llvmFnTy, linkage, isDsoLocal, cconv, |
| mlir::SymbolRefAttr(), attributes); |
| |
| fn.setVisibility_Attr(mlir::LLVM::VisibilityAttr::get( |
| getContext(), lowerCIRVisibilityToLLVMVisibility( |
| op.getGlobalVisibilityAttr().getValue()))); |
| |
| rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end()); |
| if (failed(rewriter.convertRegionTypes(&fn.getBody(), *typeConverter, |
| &signatureConversion))) |
| return mlir::failure(); |
| |
| rewriter.eraseOp(op); |
| |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
| class CIRGetGlobalOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::GetGlobalOp> { |
| public: |
| using OpConversionPattern<mlir::cir::GetGlobalOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::GetGlobalOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // FIXME(cir): Premature DCE to avoid lowering stuff we're not using. |
| // CIRGen should mitigate this and not emit the get_global. |
| if (op->getUses().empty()) { |
| rewriter.eraseOp(op); |
| return mlir::success(); |
| } |
| |
| auto type = getTypeConverter()->convertType(op.getType()); |
| auto symbol = op.getName(); |
| mlir::Operation *newop = |
| rewriter.create<mlir::LLVM::AddressOfOp>(op.getLoc(), type, symbol); |
| |
| if (op.getTls()) { |
| // Handle access to TLS via intrinsic. |
| newop = rewriter.create<mlir::LLVM::ThreadlocalAddressOp>( |
| op.getLoc(), type, newop->getResult(0)); |
| } |
| |
| rewriter.replaceOp(op, newop); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRComplexCreateOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::ComplexCreateOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ComplexCreateOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ComplexCreateOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto complexLLVMTy = |
| getTypeConverter()->convertType(op.getResult().getType()); |
| auto initialComplex = |
| rewriter.create<mlir::LLVM::UndefOp>(op->getLoc(), complexLLVMTy); |
| |
| int64_t position[1]{0}; |
| auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>( |
| op->getLoc(), initialComplex, adaptor.getReal(), position); |
| |
| position[0] = 1; |
| auto complex = rewriter.create<mlir::LLVM::InsertValueOp>( |
| op->getLoc(), realComplex, adaptor.getImag(), position); |
| |
| rewriter.replaceOp(op, complex); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRComplexRealOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::ComplexRealOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ComplexRealOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ComplexRealOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto resultLLVMTy = |
| getTypeConverter()->convertType(op.getResult().getType()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>( |
| op, resultLLVMTy, adaptor.getOperand(), |
| llvm::ArrayRef<std::int64_t>{0}); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRComplexImagOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::ComplexImagOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ComplexImagOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ComplexImagOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto resultLLVMTy = |
| getTypeConverter()->convertType(op.getResult().getType()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>( |
| op, resultLLVMTy, adaptor.getOperand(), |
| llvm::ArrayRef<std::int64_t>{1}); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRComplexRealPtrOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::ComplexRealPtrOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ComplexRealPtrOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ComplexRealPtrOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto operandTy = |
| mlir::cast<mlir::cir::PointerType>(op.getOperand().getType()); |
| auto resultLLVMTy = |
| getTypeConverter()->convertType(op.getResult().getType()); |
| auto elementLLVMTy = |
| getTypeConverter()->convertType(operandTy.getPointee()); |
| |
| mlir::LLVM::GEPArg gepIndices[2]{{0}, {0}}; |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( |
| op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices, |
| /*inbounds=*/true); |
| |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRComplexImagPtrOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::ComplexImagPtrOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ComplexImagPtrOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ComplexImagPtrOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto operandTy = |
| mlir::cast<mlir::cir::PointerType>(op.getOperand().getType()); |
| auto resultLLVMTy = |
| getTypeConverter()->convertType(op.getResult().getType()); |
| auto elementLLVMTy = |
| getTypeConverter()->convertType(operandTy.getPointee()); |
| |
| mlir::LLVM::GEPArg gepIndices[2]{{0}, {1}}; |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( |
| op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices, |
| /*inbounds=*/true); |
| |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRSwitchFlatOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::SwitchFlatOp> { |
| public: |
| using OpConversionPattern<mlir::cir::SwitchFlatOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::SwitchFlatOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| |
| llvm::SmallVector<mlir::APInt, 8> caseValues; |
| if (op.getCaseValues()) { |
| for (auto val : op.getCaseValues()) { |
        auto intAttr = mlir::cast<mlir::cir::IntAttr>(val);
        caseValues.push_back(intAttr.getValue());
| } |
| } |
| |
| llvm::SmallVector<mlir::Block *, 8> caseDestinations; |
| llvm::SmallVector<mlir::ValueRange, 8> caseOperands; |
| |
| for (auto x : op.getCaseDestinations()) { |
| caseDestinations.push_back(x); |
| } |
| |
| for (auto x : op.getCaseOperands()) { |
| caseOperands.push_back(x); |
| } |
| |
| // Set switch op to branch to the newly created blocks. |
| rewriter.setInsertionPoint(op); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>( |
| op, adaptor.getCondition(), op.getDefaultDestination(), |
| op.getDefaultOperands(), caseValues, caseDestinations, caseOperands); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRGlobalOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::GlobalOp> { |
| public: |
| using OpConversionPattern<mlir::cir::GlobalOp>::OpConversionPattern; |
| |
  // Get the address space by converting a pointer type.
  // TODO: The approach here is a little hacky. We should access the target
  // info directly to convert the address space of the global op, similar to
  // what we do for the type converter.
| unsigned getGlobalOpTargetAddrSpace(mlir::cir::GlobalOp op) const { |
| auto tempPtrTy = mlir::cir::PointerType::get(getContext(), op.getSymType(), |
| op.getAddrSpaceAttr()); |
| return cast<mlir::LLVM::LLVMPointerType>( |
| typeConverter->convertType(tempPtrTy)) |
| .getAddressSpace(); |
| } |
| |
| /// Replace CIR global with a region initialized LLVM global and update |
| /// insertion point to the end of the initializer block. |
| inline void setupRegionInitializedLLVMGlobalOp( |
| mlir::cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const { |
| const auto llvmType = getTypeConverter()->convertType(op.getSymType()); |
| SmallVector<mlir::NamedAttribute> attributes; |
| auto newGlobalOp = rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>( |
| op, llvmType, op.getConstant(), convertLinkage(op.getLinkage()), |
| op.getSymName(), nullptr, |
| /*alignment*/ op.getAlignment().value_or(0), |
| /*addrSpace*/ getGlobalOpTargetAddrSpace(op), |
| /*dsoLocal*/ false, /*threadLocal*/ (bool)op.getTlsModelAttr(), |
| /*comdat*/ mlir::SymbolRefAttr(), attributes); |
| newGlobalOp.getRegion().push_back(new mlir::Block()); |
| rewriter.setInsertionPointToEnd(newGlobalOp.getInitializerBlock()); |
| } |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::GlobalOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| |
| // Fetch required values to create LLVM op. |
| const auto llvmType = getTypeConverter()->convertType(op.getSymType()); |
| const auto isConst = op.getConstant(); |
| const auto isDsoLocal = op.getDsolocal(); |
| const auto linkage = convertLinkage(op.getLinkage()); |
| const auto symbol = op.getSymName(); |
| const auto loc = op.getLoc(); |
| std::optional<mlir::StringRef> section = op.getSection(); |
| std::optional<mlir::Attribute> init = op.getInitialValue(); |
| mlir::LLVM::VisibilityAttr visibility = mlir::LLVM::VisibilityAttr::get( |
| getContext(), lowerCIRVisibilityToLLVMVisibility( |
| op.getGlobalVisibilityAttr().getValue())); |
| |
| SmallVector<mlir::NamedAttribute> attributes; |
| if (section.has_value()) |
| attributes.push_back(rewriter.getNamedAttr( |
| "section", rewriter.getStringAttr(section.value()))); |
| |
| attributes.push_back(rewriter.getNamedAttr("visibility_", visibility)); |
| |
    // A global without an initializer can be emitted directly.
| if (!init.has_value()) { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>( |
| op, llvmType, isConst, linkage, symbol, mlir::Attribute(), |
| /*alignment*/ 0, /*addrSpace*/ getGlobalOpTargetAddrSpace(op), |
| /*dsoLocal*/ isDsoLocal, /*threadLocal*/ (bool)op.getTlsModelAttr(), |
| /*comdat*/ mlir::SymbolRefAttr(), attributes); |
| return mlir::success(); |
| } |
| |
| // Initializer is a constant array: convert it to a compatible llvm init. |
| if (auto constArr = |
| mlir::dyn_cast<mlir::cir::ConstArrayAttr>(init.value())) { |
| if (auto attr = mlir::dyn_cast<mlir::StringAttr>(constArr.getElts())) { |
| init = rewriter.getStringAttr(attr.getValue()); |
| } else if (auto attr = |
| mlir::dyn_cast<mlir::ArrayAttr>(constArr.getElts())) { |
        // Try to lower to a compact LLVM attribute; if that fails, fall back
        // to initializing the elements individually in a region initializer.
| if (!(init = lowerConstArrayAttr(constArr, getTypeConverter()))) { |
| setupRegionInitializedLLVMGlobalOp(op, rewriter); |
| rewriter.create<mlir::LLVM::ReturnOp>( |
| op->getLoc(), |
| lowerCirAttrAsValue(op, constArr, rewriter, typeConverter)); |
| return mlir::success(); |
| } |
| } else { |
| op.emitError() |
| << "unsupported lowering for #cir.const_array with value " |
| << constArr.getElts(); |
| return mlir::failure(); |
| } |
| } else if (auto fltAttr = mlir::dyn_cast<mlir::cir::FPAttr>(init.value())) { |
| // Initializer is a constant floating-point number: convert to MLIR |
| // builtin constant. |
| init = rewriter.getFloatAttr(llvmType, fltAttr.getValue()); |
| } |
| // Initializer is a constant integer: convert to MLIR builtin constant. |
| else if (auto intAttr = mlir::dyn_cast<mlir::cir::IntAttr>(init.value())) { |
| init = rewriter.getIntegerAttr(llvmType, intAttr.getValue()); |
| } else if (auto boolAttr = |
| mlir::dyn_cast<mlir::cir::BoolAttr>(init.value())) { |
| init = rewriter.getBoolAttr(boolAttr.getValue()); |
| } else if (isa<mlir::cir::ZeroAttr, mlir::cir::ConstPtrAttr>( |
| init.value())) { |
| // TODO(cir): once LLVM's dialect has a proper zeroinitializer attribute |
| // this should be updated. For now, we use a custom op to initialize |
| // globals to zero. |
| setupRegionInitializedLLVMGlobalOp(op, rewriter); |
| auto value = |
| lowerCirAttrAsValue(op, init.value(), rewriter, typeConverter); |
| rewriter.create<mlir::LLVM::ReturnOp>(loc, value); |
| return mlir::success(); |
| } else if (auto dataMemberAttr = |
| mlir::dyn_cast<mlir::cir::DataMemberAttr>(init.value())) { |
| init = lowerDataMemberAttr(op->getParentOfType<mlir::ModuleOp>(), |
| dataMemberAttr, *typeConverter); |
| } else if (const auto structAttr = |
| mlir::dyn_cast<mlir::cir::ConstStructAttr>(init.value())) { |
| setupRegionInitializedLLVMGlobalOp(op, rewriter); |
| rewriter.create<mlir::LLVM::ReturnOp>( |
| op->getLoc(), |
| lowerCirAttrAsValue(op, structAttr, rewriter, typeConverter)); |
| return mlir::success(); |
| } else if (auto attr = |
| mlir::dyn_cast<mlir::cir::GlobalViewAttr>(init.value())) { |
| setupRegionInitializedLLVMGlobalOp(op, rewriter); |
| rewriter.create<mlir::LLVM::ReturnOp>( |
| loc, lowerCirAttrAsValue(op, attr, rewriter, typeConverter)); |
| return mlir::success(); |
| } else if (const auto vtableAttr = |
| mlir::dyn_cast<mlir::cir::VTableAttr>(init.value())) { |
| setupRegionInitializedLLVMGlobalOp(op, rewriter); |
| rewriter.create<mlir::LLVM::ReturnOp>( |
| op->getLoc(), |
| lowerCirAttrAsValue(op, vtableAttr, rewriter, typeConverter)); |
| return mlir::success(); |
| } else if (const auto typeinfoAttr = |
| mlir::dyn_cast<mlir::cir::TypeInfoAttr>(init.value())) { |
| setupRegionInitializedLLVMGlobalOp(op, rewriter); |
| rewriter.create<mlir::LLVM::ReturnOp>( |
| op->getLoc(), |
| lowerCirAttrAsValue(op, typeinfoAttr, rewriter, typeConverter)); |
| return mlir::success(); |
| } else { |
      op.emitError() << "unsupported initializer '" << init.value() << "'";
| return mlir::failure(); |
| } |
| |
| // Rewrite op. |
| auto llvmGlobalOp = rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>( |
| op, llvmType, isConst, linkage, symbol, init.value(), |
        /*alignment*/ op.getAlignment().value_or(0),
| /*addrSpace*/ getGlobalOpTargetAddrSpace(op), |
| /*dsoLocal*/ false, /*threadLocal*/ (bool)op.getTlsModelAttr(), |
| /*comdat*/ mlir::SymbolRefAttr(), attributes); |
| |
| auto mod = op->getParentOfType<mlir::ModuleOp>(); |
| if (op.getComdat()) |
| addComdat(llvmGlobalOp, comdatOp, rewriter, mod); |
| |
| return mlir::success(); |
| } |
| |
| private: |
| mutable mlir::LLVM::ComdatOp comdatOp = nullptr; |
| static void addComdat(mlir::LLVM::GlobalOp &op, |
| mlir::LLVM::ComdatOp &comdatOp, |
| mlir::OpBuilder &builder, mlir::ModuleOp &module) { |
| StringRef comdatName("__llvm_comdat_globals"); |
| if (!comdatOp) { |
| builder.setInsertionPointToStart(module.getBody()); |
| comdatOp = |
| builder.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName); |
| } |
| builder.setInsertionPointToStart(&comdatOp.getBody().back()); |
| auto selectorOp = builder.create<mlir::LLVM::ComdatSelectorOp>( |
| comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any); |
| op.setComdatAttr(mlir::SymbolRefAttr::get( |
| builder.getContext(), comdatName, |
| mlir::FlatSymbolRefAttr::get(selectorOp.getSymNameAttr()))); |
| } |
| }; |
| |
| class CIRUnaryOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::UnaryOp> { |
| public: |
| using OpConversionPattern<mlir::cir::UnaryOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::UnaryOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| assert(op.getType() == op.getInput().getType() && |
| "Unary operation's operand type and result type are different"); |
| mlir::Type type = op.getType(); |
| mlir::Type elementType = elementTypeIfVector(type); |
| bool IsVector = mlir::isa<mlir::cir::VectorType>(type); |
| auto llvmType = getTypeConverter()->convertType(type); |
| auto loc = op.getLoc(); |
| |
| // Integer unary operations: + - ~ ++ -- |
| if (mlir::isa<mlir::cir::IntType>(elementType)) { |
| switch (op.getKind()) { |
| case mlir::cir::UnaryOpKind::Inc: { |
| assert(!IsVector && "++ not allowed on vector types"); |
| auto One = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmType, mlir::IntegerAttr::get(llvmType, 1)); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmType, |
| adaptor.getInput(), One); |
| return mlir::success(); |
| } |
| case mlir::cir::UnaryOpKind::Dec: { |
| assert(!IsVector && "-- not allowed on vector types"); |
| auto One = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmType, mlir::IntegerAttr::get(llvmType, 1)); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmType, |
| adaptor.getInput(), One); |
| return mlir::success(); |
| } |
| case mlir::cir::UnaryOpKind::Plus: { |
| rewriter.replaceOp(op, adaptor.getInput()); |
| return mlir::success(); |
| } |
| case mlir::cir::UnaryOpKind::Minus: { |
| mlir::Value Zero; |
| if (IsVector) |
| Zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType); |
| else |
| Zero = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmType, mlir::IntegerAttr::get(llvmType, 0)); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmType, Zero, |
| adaptor.getInput()); |
| return mlir::success(); |
| } |
| case mlir::cir::UnaryOpKind::Not: { |
        // Bitwise complement operator, implemented as an XOR with -1.
| mlir::Value MinusOne; |
| if (IsVector) { |
| // Creating a vector object with all -1 values is easier said than |
| // done. It requires a series of insertelement ops. |
| mlir::Type llvmElementType = |
| getTypeConverter()->convertType(elementType); |
| auto MinusOneInt = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmElementType, |
| mlir::IntegerAttr::get(llvmElementType, -1)); |
| MinusOne = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmType); |
| auto NumElements = |
| mlir::dyn_cast<mlir::cir::VectorType>(type).getSize(); |
| for (uint64_t i = 0; i < NumElements; ++i) { |
| mlir::Value indexValue = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, rewriter.getI64Type(), i); |
| MinusOne = rewriter.create<mlir::LLVM::InsertElementOp>( |
| loc, MinusOne, MinusOneInt, indexValue); |
| } |
| } else { |
| MinusOne = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmType, mlir::IntegerAttr::get(llvmType, -1)); |
| } |
| rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, llvmType, MinusOne, |
| adaptor.getInput()); |
| return mlir::success(); |
| } |
| } |
| } |
| |
| // Floating point unary operations: + - ++ -- |
| if (mlir::isa<mlir::cir::CIRFPTypeInterface>(elementType)) { |
| switch (op.getKind()) { |
| case mlir::cir::UnaryOpKind::Inc: { |
| assert(!IsVector && "++ not allowed on vector types"); |
| auto oneAttr = rewriter.getFloatAttr(llvmType, 1.0); |
| auto oneConst = |
| rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, oneAttr); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, oneConst, |
| adaptor.getInput()); |
| return mlir::success(); |
| } |
| case mlir::cir::UnaryOpKind::Dec: { |
| assert(!IsVector && "-- not allowed on vector types"); |
| auto negOneAttr = rewriter.getFloatAttr(llvmType, -1.0); |
| auto negOneConst = |
| rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, negOneAttr); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>( |
| op, llvmType, negOneConst, adaptor.getInput()); |
| return mlir::success(); |
| } |
| case mlir::cir::UnaryOpKind::Plus: |
| rewriter.replaceOp(op, adaptor.getInput()); |
| return mlir::success(); |
| case mlir::cir::UnaryOpKind::Minus: { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FNegOp>(op, llvmType, |
| adaptor.getInput()); |
| return mlir::success(); |
| } |
| default: |
| return op.emitError() |
| << "Unknown floating-point unary operation during CIR lowering"; |
| } |
| } |
| |
| // Boolean unary operations: ! only. (For all others, the operand has |
| // already been promoted to int.) |
| if (mlir::isa<mlir::cir::BoolType>(elementType)) { |
| switch (op.getKind()) { |
| case mlir::cir::UnaryOpKind::Not: |
| assert(!IsVector && "NYI: op! on vector mask"); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>( |
| op, llvmType, adaptor.getInput(), |
| rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, llvmType, mlir::IntegerAttr::get(llvmType, 1))); |
| return mlir::success(); |
| default: |
| return op.emitError() |
| << "Unknown boolean unary operation during CIR lowering"; |
| } |
| } |
| |
| // Pointer unary operations: + only. (++ and -- of pointers are implemented |
| // with cir.ptr_stride, not cir.unary.) |
| if (mlir::isa<mlir::cir::PointerType>(elementType)) { |
| switch (op.getKind()) { |
| case mlir::cir::UnaryOpKind::Plus: |
| rewriter.replaceOp(op, adaptor.getInput()); |
| return mlir::success(); |
| default: |
| op.emitError() << "Unknown pointer unary operation during CIR lowering"; |
| return mlir::failure(); |
| } |
| } |
| |
| return op.emitError() << "Unary operation has unsupported type: " |
| << elementType; |
| } |
| }; |
| |
| class CIRBinOpLowering : public mlir::OpConversionPattern<mlir::cir::BinOp> { |
| |
| mlir::LLVM::IntegerOverflowFlags |
| getIntOverflowFlag(mlir::cir::BinOp op) const { |
| if (op.getNoUnsignedWrap()) |
| return mlir::LLVM::IntegerOverflowFlags::nuw; |
| |
| if (op.getNoSignedWrap()) |
| return mlir::LLVM::IntegerOverflowFlags::nsw; |
| |
| return mlir::LLVM::IntegerOverflowFlags::none; |
| } |
| |
| public: |
| using OpConversionPattern<mlir::cir::BinOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::BinOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| assert((op.getLhs().getType() == op.getRhs().getType()) && |
| "inconsistent operands' types not supported yet"); |
| mlir::Type type = op.getRhs().getType(); |
| assert((mlir::isa<mlir::cir::IntType, mlir::cir::CIRFPTypeInterface, |
| mlir::cir::VectorType>(type)) && |
| "operand type not supported yet"); |
| |
| auto llvmTy = getTypeConverter()->convertType(op.getType()); |
| auto rhs = adaptor.getRhs(); |
| auto lhs = adaptor.getLhs(); |
| |
| type = elementTypeIfVector(type); |
| |
| switch (op.getKind()) { |
| case mlir::cir::BinOpKind::Add: |
| if (mlir::isa<mlir::cir::IntType>(type)) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmTy, lhs, rhs, |
| getIntOverflowFlag(op)); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmTy, lhs, rhs); |
| break; |
| case mlir::cir::BinOpKind::Sub: |
| if (mlir::isa<mlir::cir::IntType>(type)) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmTy, lhs, rhs, |
| getIntOverflowFlag(op)); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, llvmTy, lhs, rhs); |
| break; |
| case mlir::cir::BinOpKind::Mul: |
| if (mlir::isa<mlir::cir::IntType>(type)) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::MulOp>(op, llvmTy, lhs, rhs, |
| getIntOverflowFlag(op)); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, llvmTy, lhs, rhs); |
| break; |
| case mlir::cir::BinOpKind::Div: |
| if (auto ty = mlir::dyn_cast<mlir::cir::IntType>(type)) { |
| if (ty.isUnsigned()) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, llvmTy, lhs, rhs); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, llvmTy, lhs, rhs); |
| } else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, llvmTy, lhs, rhs); |
| break; |
| case mlir::cir::BinOpKind::Rem: |
| if (auto ty = mlir::dyn_cast<mlir::cir::IntType>(type)) { |
| if (ty.isUnsigned()) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, llvmTy, lhs, rhs); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, llvmTy, lhs, rhs); |
| } else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, llvmTy, lhs, rhs); |
| break; |
| case mlir::cir::BinOpKind::And: |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, llvmTy, lhs, rhs); |
| break; |
| case mlir::cir::BinOpKind::Or: |
| rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, llvmTy, lhs, rhs); |
| break; |
| case mlir::cir::BinOpKind::Xor: |
| rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, llvmTy, lhs, rhs); |
| break; |
| } |
| |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
| class CIRBinOpOverflowOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::BinOpOverflowOp> { |
| public: |
| using OpConversionPattern<mlir::cir::BinOpOverflowOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::BinOpOverflowOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| auto arithKind = op.getKind(); |
| auto operandTy = op.getLhs().getType(); |
| auto resultTy = op.getResult().getType(); |
| |
| auto encompassedTyInfo = computeEncompassedTypeWidth(operandTy, resultTy); |
| auto encompassedLLVMTy = rewriter.getIntegerType(encompassedTyInfo.width); |
| |
| auto lhs = adaptor.getLhs(); |
| auto rhs = adaptor.getRhs(); |
| if (operandTy.getWidth() < encompassedTyInfo.width) { |
| if (operandTy.isSigned()) { |
| lhs = rewriter.create<mlir::LLVM::SExtOp>(loc, encompassedLLVMTy, lhs); |
| rhs = rewriter.create<mlir::LLVM::SExtOp>(loc, encompassedLLVMTy, rhs); |
| } else { |
| lhs = rewriter.create<mlir::LLVM::ZExtOp>(loc, encompassedLLVMTy, lhs); |
| rhs = rewriter.create<mlir::LLVM::ZExtOp>(loc, encompassedLLVMTy, rhs); |
| } |
| } |
| |
| auto intrinName = getLLVMIntrinName(arithKind, encompassedTyInfo.sign, |
| encompassedTyInfo.width); |
| auto intrinNameAttr = mlir::StringAttr::get(op.getContext(), intrinName); |
| |
| auto overflowLLVMTy = rewriter.getI1Type(); |
| auto intrinRetTy = mlir::LLVM::LLVMStructType::getLiteral( |
| rewriter.getContext(), {encompassedLLVMTy, overflowLLVMTy}); |
| |
| auto callLLVMIntrinOp = rewriter.create<mlir::LLVM::CallIntrinsicOp>( |
| loc, intrinRetTy, intrinNameAttr, mlir::ValueRange{lhs, rhs}); |
| auto intrinRet = callLLVMIntrinOp.getResult(0); |
| |
| auto result = rewriter |
| .create<mlir::LLVM::ExtractValueOp>(loc, intrinRet, |
| ArrayRef<int64_t>{0}) |
| .getResult(); |
| auto overflow = rewriter |
| .create<mlir::LLVM::ExtractValueOp>( |
| loc, intrinRet, ArrayRef<int64_t>{1}) |
| .getResult(); |
| |
| if (resultTy.getWidth() < encompassedTyInfo.width) { |
| auto resultLLVMTy = getTypeConverter()->convertType(resultTy); |
| auto truncResult = |
| rewriter.create<mlir::LLVM::TruncOp>(loc, resultLLVMTy, result); |
| |
| // Extend the truncated result back to the encompassing type to check for |
| // any overflows during the truncation. |
| mlir::Value truncResultExt; |
| if (resultTy.isSigned()) |
| truncResultExt = rewriter.create<mlir::LLVM::SExtOp>( |
| loc, encompassedLLVMTy, truncResult); |
| else |
| truncResultExt = rewriter.create<mlir::LLVM::ZExtOp>( |
| loc, encompassedLLVMTy, truncResult); |
| auto truncOverflow = rewriter.create<mlir::LLVM::ICmpOp>( |
| loc, mlir::LLVM::ICmpPredicate::ne, truncResultExt, result); |
| |
| result = truncResult; |
| overflow = |
| rewriter.create<mlir::LLVM::OrOp>(loc, overflow, truncOverflow); |
| } |
| |
| auto boolLLVMTy = |
| getTypeConverter()->convertType(op.getOverflow().getType()); |
| if (boolLLVMTy != rewriter.getI1Type()) |
| overflow = rewriter.create<mlir::LLVM::ZExtOp>(loc, boolLLVMTy, overflow); |
| |
| rewriter.replaceOp(op, mlir::ValueRange{result, overflow}); |
| |
| return mlir::success(); |
| } |
| |
| private: |
| static std::string getLLVMIntrinName(mlir::cir::BinOpOverflowKind opKind, |
| bool isSigned, unsigned width) { |
| // The intrinsic name is `@llvm.{s|u}{opKind}.with.overflow.i{width}` |
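    // For example, a signed add checked in a 32-bit encompassing type becomes
    // "llvm.sadd.with.overflow.i32".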
| |
| std::string name = "llvm."; |
| |
| if (isSigned) |
| name.push_back('s'); |
| else |
| name.push_back('u'); |
| |
| switch (opKind) { |
| case mlir::cir::BinOpOverflowKind::Add: |
| name.append("add."); |
| break; |
| case mlir::cir::BinOpOverflowKind::Sub: |
| name.append("sub."); |
| break; |
| case mlir::cir::BinOpOverflowKind::Mul: |
| name.append("mul."); |
| break; |
| } |
| |
| name.append("with.overflow.i"); |
| name.append(std::to_string(width)); |
| |
| return name; |
| } |
| |
| struct EncompassedTypeInfo { |
| bool sign; |
| unsigned width; |
| }; |
| |
| static EncompassedTypeInfo |
| computeEncompassedTypeWidth(mlir::cir::IntType operandTy, |
| mlir::cir::IntType resultTy) { |
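    // The encompassing type must represent every value of both types; e.g. a
    // u32 operand with an s32 result is signed and needs width
    // max(32 + 1, 32 + 0) = 33.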
| auto sign = operandTy.getIsSigned() || resultTy.getIsSigned(); |
| auto width = |
| std::max(operandTy.getWidth() + (sign && operandTy.isUnsigned()), |
| resultTy.getWidth() + (sign && resultTy.isUnsigned())); |
| return {sign, width}; |
| } |
| }; |
| |
| class CIRShiftOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::ShiftOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ShiftOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ShiftOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto cirAmtTy = |
| mlir::dyn_cast<mlir::cir::IntType>(op.getAmount().getType()); |
| auto cirValTy = mlir::dyn_cast<mlir::cir::IntType>(op.getValue().getType()); |
| auto llvmTy = getTypeConverter()->convertType(op.getType()); |
| mlir::Value amt = adaptor.getAmount(); |
| mlir::Value val = adaptor.getValue(); |
| |
| assert(cirValTy && cirAmtTy && "non-integer shift is NYI"); |
| assert(cirValTy == op.getType() && "inconsistent operands' types NYI"); |
| |
| // Ensure shift amount is the same type as the value. Some undefined |
| // behavior might occur in the casts below as per [C99 6.5.7.3]. |
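    // For example, a 64-bit shift amount applied to a 32-bit value is
    // truncated to i32, while a narrower unsigned amount is zero-extended.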
| amt = getLLVMIntCast(rewriter, amt, mlir::cast<mlir::IntegerType>(llvmTy), |
| !cirAmtTy.isSigned(), cirValTy.getWidth()); |
| |
| // Lower to the proper LLVM shift operation. |
| if (op.getIsShiftleft()) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt); |
| else { |
| if (cirValTy.isUnsigned()) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt); |
| } |
| |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> { |
| public: |
| using OpConversionPattern<mlir::cir::CmpOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::CmpOp cmpOp, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto type = cmpOp.getLhs().getType(); |
| mlir::Value llResult; |
| |
| // Lower to LLVM comparison op. |
| if (auto intTy = mlir::dyn_cast<mlir::cir::IntType>(type)) { |
| auto kind = |
| convertCmpKindToICmpPredicate(cmpOp.getKind(), intTy.isSigned()); |
| llResult = rewriter.create<mlir::LLVM::ICmpOp>( |
| cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); |
| } else if (auto ptrTy = mlir::dyn_cast<mlir::cir::PointerType>(type)) { |
| auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(), |
| /* isSigned=*/false); |
| llResult = rewriter.create<mlir::LLVM::ICmpOp>( |
| cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); |
| } else if (mlir::isa<mlir::cir::CIRFPTypeInterface>(type)) { |
| auto kind = convertCmpKindToFCmpPredicate(cmpOp.getKind()); |
| llResult = rewriter.create<mlir::LLVM::FCmpOp>( |
| cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs()); |
| } else { |
| return cmpOp.emitError() << "unsupported type for CmpOp: " << type; |
| } |
| |
| // LLVM comparison ops return i1, but cir::CmpOp returns the same type as |
| // the LHS value. Since this return value can be used later, we need to |
| // restore the type with the extension below. |
| auto llResultTy = getTypeConverter()->convertType(cmpOp.getType()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(cmpOp, llResultTy, |
| llResult); |
| |
| return mlir::success(); |
| } |
| }; |
| |
| static mlir::LLVM::CallIntrinsicOp |
| createCallLLVMIntrinsicOp(mlir::ConversionPatternRewriter &rewriter, |
| mlir::Location loc, const llvm::Twine &intrinsicName, |
| mlir::Type resultTy, mlir::ValueRange operands) { |
| auto intrinsicNameAttr = |
| mlir::StringAttr::get(rewriter.getContext(), intrinsicName); |
| return rewriter.create<mlir::LLVM::CallIntrinsicOp>( |
| loc, resultTy, intrinsicNameAttr, operands); |
| } |
| |
| static mlir::LLVM::CallIntrinsicOp replaceOpWithCallLLVMIntrinsicOp( |
| mlir::ConversionPatternRewriter &rewriter, mlir::Operation *op, |
| const llvm::Twine &intrinsicName, mlir::Type resultTy, |
| mlir::ValueRange operands) { |
| auto callIntrinOp = createCallLLVMIntrinsicOp( |
| rewriter, op->getLoc(), intrinsicName, resultTy, operands); |
| rewriter.replaceOp(op, callIntrinOp.getOperation()); |
| return callIntrinOp; |
| } |
| |
| static mlir::Value createLLVMBitOp(mlir::Location loc, |
| const llvm::Twine &llvmIntrinBaseName, |
| mlir::Type resultTy, mlir::Value operand, |
| std::optional<bool> poisonZeroInputFlag, |
| mlir::ConversionPatternRewriter &rewriter) { |
| auto operandIntTy = mlir::cast<mlir::IntegerType>(operand.getType()); |
| auto resultIntTy = mlir::cast<mlir::IntegerType>(resultTy); |
| |
| std::string llvmIntrinName = |
| llvmIntrinBaseName.concat(".i") |
| .concat(std::to_string(operandIntTy.getWidth())) |
| .str(); |
| |
  // Note that the LLVM bit intrinsics return the same type as their operand;
  // the result is cast to the requested type below.
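  // For example, a 32-bit ctlz request becomes a call to "llvm.ctlz.i32",
  // with an extra i1 operand when a poison-on-zero flag is required.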
| mlir::LLVM::CallIntrinsicOp op; |
| if (poisonZeroInputFlag.has_value()) { |
| auto poisonZeroInputValue = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, rewriter.getI1Type(), static_cast<int64_t>(*poisonZeroInputFlag)); |
| op = createCallLLVMIntrinsicOp(rewriter, loc, llvmIntrinName, |
| operand.getType(), |
| {operand, poisonZeroInputValue}); |
| } else { |
| op = createCallLLVMIntrinsicOp(rewriter, loc, llvmIntrinName, |
| operand.getType(), operand); |
| } |
| |
| return getLLVMIntCast(rewriter, op->getResult(0), |
| mlir::cast<mlir::IntegerType>(resultTy), |
| /*isUnsigned=*/true, resultIntTy.getWidth()); |
| } |
| |
| class CIRBitClrsbOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::BitClrsbOp> { |
| public: |
| using OpConversionPattern<mlir::cir::BitClrsbOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::BitClrsbOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
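    // clrsb(x) is the number of leading redundant sign bits. The lowering
    // below flips negative inputs first, i.e.
    //   clrsb(x) == ctlz(x < 0 ? ~x : x) - 1.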
| auto zero = rewriter.create<mlir::LLVM::ConstantOp>( |
| op.getLoc(), adaptor.getInput().getType(), 0); |
| auto isNeg = rewriter.create<mlir::LLVM::ICmpOp>( |
| op.getLoc(), |
| mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), |
| mlir::LLVM::ICmpPredicate::slt), |
| adaptor.getInput(), zero); |
| |
| auto negOne = rewriter.create<mlir::LLVM::ConstantOp>( |
| op.getLoc(), adaptor.getInput().getType(), -1); |
| auto flipped = rewriter.create<mlir::LLVM::XOrOp>( |
| op.getLoc(), adaptor.getInput(), negOne); |
| |
| auto select = rewriter.create<mlir::LLVM::SelectOp>( |
| op.getLoc(), isNeg, flipped, adaptor.getInput()); |
| |
| auto resTy = getTypeConverter()->convertType(op.getType()); |
| auto clz = createLLVMBitOp(op.getLoc(), "llvm.ctlz", resTy, select, |
| /*poisonZeroInputFlag=*/false, rewriter); |
| |
| auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1); |
| auto res = rewriter.create<mlir::LLVM::SubOp>(op.getLoc(), clz, one); |
| rewriter.replaceOp(op, res); |
| |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
| class CIRObjSizeOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::ObjSizeOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ObjSizeOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ObjSizeOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto llvmResTy = getTypeConverter()->convertType(op.getType()); |
| auto loc = op->getLoc(); |
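    // The llvm.objectsize intrinsic takes the pointer plus three i1 flags
    // (min, nullunknown, dynamic); "min" is set unless the maximum size was
    // requested.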
| |
| mlir::cir::SizeInfoType kindInfo = op.getKind(); |
| auto falseValue = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, rewriter.getI1Type(), false); |
| auto trueValue = rewriter.create<mlir::LLVM::ConstantOp>( |
| loc, rewriter.getI1Type(), true); |
| |
| replaceOpWithCallLLVMIntrinsicOp( |
| rewriter, op, "llvm.objectsize", llvmResTy, |
| mlir::ValueRange{adaptor.getPtr(), |
| kindInfo == mlir::cir::SizeInfoType::max ? falseValue |
| : trueValue, |
| trueValue, op.getDynamic() ? trueValue : falseValue}); |
| |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
| class CIRBitClzOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::BitClzOp> { |
| public: |
| using OpConversionPattern<mlir::cir::BitClzOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::BitClzOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto resTy = getTypeConverter()->convertType(op.getType()); |
| auto llvmOp = |
| createLLVMBitOp(op.getLoc(), "llvm.ctlz", resTy, adaptor.getInput(), |
| /*poisonZeroInputFlag=*/true, rewriter); |
| rewriter.replaceOp(op, llvmOp); |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
| class CIRBitCtzOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::BitCtzOp> { |
| public: |
| using OpConversionPattern<mlir::cir::BitCtzOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::BitCtzOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto resTy = getTypeConverter()->convertType(op.getType()); |
| auto llvmOp = |
| createLLVMBitOp(op.getLoc(), "llvm.cttz", resTy, adaptor.getInput(), |
| /*poisonZeroInputFlag=*/true, rewriter); |
| rewriter.replaceOp(op, llvmOp); |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
| class CIRBitFfsOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::BitFfsOp> { |
| public: |
| using OpConversionPattern<mlir::cir::BitFfsOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::BitFfsOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
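    // ffs(x) is one plus the index of the least significant set bit, or 0
    // when the input is zero, i.e. ffs(x) == (x == 0) ? 0 : cttz(x) + 1.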
| auto resTy = getTypeConverter()->convertType(op.getType()); |
| auto ctz = |
| createLLVMBitOp(op.getLoc(), "llvm.cttz", resTy, adaptor.getInput(), |
| /*poisonZeroInputFlag=*/false, rewriter); |
| |
| auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1); |
| auto ctzAddOne = rewriter.create<mlir::LLVM::AddOp>(op.getLoc(), ctz, one); |
| |
| auto zeroInputTy = rewriter.create<mlir::LLVM::ConstantOp>( |
| op.getLoc(), adaptor.getInput().getType(), 0); |
| auto isZero = rewriter.create<mlir::LLVM::ICmpOp>( |
| op.getLoc(), |
| mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), |
| mlir::LLVM::ICmpPredicate::eq), |
| adaptor.getInput(), zeroInputTy); |
| |
| auto zero = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 0); |
| auto res = rewriter.create<mlir::LLVM::SelectOp>(op.getLoc(), isZero, zero, |
| ctzAddOne); |
| rewriter.replaceOp(op, res); |
| |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
| class CIRBitParityOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::BitParityOp> { |
| public: |
| using OpConversionPattern<mlir::cir::BitParityOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::BitParityOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
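    // parity(x) is the XOR of all bits of x, i.e. popcount(x) & 1.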
| auto resTy = getTypeConverter()->convertType(op.getType()); |
| auto popcnt = |
| createLLVMBitOp(op.getLoc(), "llvm.ctpop", resTy, adaptor.getInput(), |
| /*poisonZeroInputFlag=*/std::nullopt, rewriter); |
| |
| auto one = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), resTy, 1); |
| auto popcntMod2 = |
| rewriter.create<mlir::LLVM::AndOp>(op.getLoc(), popcnt, one); |
| rewriter.replaceOp(op, popcntMod2); |
| |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
| class CIRBitPopcountOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::BitPopcountOp> { |
| public: |
| using OpConversionPattern<mlir::cir::BitPopcountOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::BitPopcountOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto resTy = getTypeConverter()->convertType(op.getType()); |
| auto llvmOp = |
| createLLVMBitOp(op.getLoc(), "llvm.ctpop", resTy, adaptor.getInput(), |
| /*poisonZeroInputFlag=*/std::nullopt, rewriter); |
| rewriter.replaceOp(op, llvmOp); |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
| static mlir::LLVM::AtomicOrdering getLLVMAtomicOrder(mlir::cir::MemOrder memo) { |
| switch (memo) { |
| case mlir::cir::MemOrder::Relaxed: |
| return mlir::LLVM::AtomicOrdering::monotonic; |
| case mlir::cir::MemOrder::Consume: |
| case mlir::cir::MemOrder::Acquire: |
| return mlir::LLVM::AtomicOrdering::acquire; |
| case mlir::cir::MemOrder::Release: |
| return mlir::LLVM::AtomicOrdering::release; |
| case mlir::cir::MemOrder::AcquireRelease: |
| return mlir::LLVM::AtomicOrdering::acq_rel; |
| case mlir::cir::MemOrder::SequentiallyConsistent: |
| return mlir::LLVM::AtomicOrdering::seq_cst; |
| } |
| llvm_unreachable("shouldn't get here"); |
| } |
| |
| class CIRAtomicCmpXchgLowering |
| : public mlir::OpConversionPattern<mlir::cir::AtomicCmpXchg> { |
| public: |
| using OpConversionPattern<mlir::cir::AtomicCmpXchg>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::AtomicCmpXchg op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto expected = adaptor.getExpected(); |
| auto desired = adaptor.getDesired(); |
| |
| // FIXME: add syncscope. |
| auto cmpxchg = rewriter.create<mlir::LLVM::AtomicCmpXchgOp>( |
| op.getLoc(), adaptor.getPtr(), expected, desired, |
| getLLVMAtomicOrder(adaptor.getSuccOrder()), |
| getLLVMAtomicOrder(adaptor.getFailOrder())); |
| cmpxchg.setWeak(adaptor.getWeak()); |
| cmpxchg.setVolatile_(adaptor.getIsVolatile()); |
| |
    // Unpack the result: the old value and the success flag.
| auto old = rewriter.create<mlir::LLVM::ExtractValueOp>( |
| op.getLoc(), cmpxchg.getResult(), 0); |
| auto cmp = rewriter.create<mlir::LLVM::ExtractValueOp>( |
| op.getLoc(), cmpxchg.getResult(), 1); |
| |
| auto extCmp = rewriter.create<mlir::LLVM::ZExtOp>( |
| op.getLoc(), rewriter.getI8Type(), cmp); |
| rewriter.replaceOp(op, {old, extCmp}); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRAtomicXchgLowering |
| : public mlir::OpConversionPattern<mlir::cir::AtomicXchg> { |
| public: |
| using OpConversionPattern<mlir::cir::AtomicXchg>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::AtomicXchg op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // FIXME: add syncscope. |
| auto llvmOrder = getLLVMAtomicOrder(adaptor.getMemOrder()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AtomicRMWOp>( |
| op, mlir::LLVM::AtomicBinOp::xchg, adaptor.getPtr(), adaptor.getVal(), |
| llvmOrder); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRAtomicFetchLowering |
| : public mlir::OpConversionPattern<mlir::cir::AtomicFetch> { |
| public: |
| using OpConversionPattern<mlir::cir::AtomicFetch>::OpConversionPattern; |
| |
| mlir::Value buildPostOp(mlir::cir::AtomicFetch op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter, |
| mlir::Value rmwVal, bool isInt) const { |
| SmallVector<mlir::Value> atomicOperands = {rmwVal, adaptor.getVal()}; |
| SmallVector<mlir::Type> atomicResTys = {rmwVal.getType()}; |
| return rewriter |
| .create(op.getLoc(), |
| rewriter.getStringAttr(getLLVMBinop(op.getBinop(), isInt)), |
| atomicOperands, atomicResTys, {}) |
| ->getResult(0); |
| } |
| |
| mlir::Value buildMinMaxPostOp(mlir::cir::AtomicFetch op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter, |
| mlir::Value rmwVal, bool isSigned) const { |
| auto loc = op.getLoc(); |
| mlir::LLVM::ICmpPredicate pred; |
| if (op.getBinop() == mlir::cir::AtomicFetchKind::Max) { |
| pred = isSigned ? mlir::LLVM::ICmpPredicate::sgt |
| : mlir::LLVM::ICmpPredicate::ugt; |
| } else { // Min |
| pred = isSigned ? mlir::LLVM::ICmpPredicate::slt |
| : mlir::LLVM::ICmpPredicate::ult; |
| } |
| |
| auto cmp = rewriter.create<mlir::LLVM::ICmpOp>( |
| loc, mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(), pred), |
| rmwVal, adaptor.getVal()); |
| return rewriter.create<mlir::LLVM::SelectOp>(loc, cmp, rmwVal, |
| adaptor.getVal()); |
| } |
| |
| llvm::StringLiteral getLLVMBinop(mlir::cir::AtomicFetchKind k, |
| bool isInt) const { |
| switch (k) { |
| case mlir::cir::AtomicFetchKind::Add: |
| return isInt ? mlir::LLVM::AddOp::getOperationName() |
| : mlir::LLVM::FAddOp::getOperationName(); |
| case mlir::cir::AtomicFetchKind::Sub: |
| return isInt ? mlir::LLVM::SubOp::getOperationName() |
| : mlir::LLVM::FSubOp::getOperationName(); |
| case mlir::cir::AtomicFetchKind::And: |
| return mlir::LLVM::AndOp::getOperationName(); |
| case mlir::cir::AtomicFetchKind::Xor: |
| return mlir::LLVM::XOrOp::getOperationName(); |
| case mlir::cir::AtomicFetchKind::Or: |
| return mlir::LLVM::OrOp::getOperationName(); |
| case mlir::cir::AtomicFetchKind::Nand: |
      // There's no nand binop in LLVM; it is fixed up later with a not.
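      // i.e. nand(a, b) == ~(a & b): emit an 'and' here and xor the result
      // with -1 afterwards (see matchAndRewrite).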
| return mlir::LLVM::AndOp::getOperationName(); |
| case mlir::cir::AtomicFetchKind::Max: |
| case mlir::cir::AtomicFetchKind::Min: |
| llvm_unreachable("handled in buildMinMaxPostOp"); |
| } |
| llvm_unreachable("Unknown atomic fetch opcode"); |
| } |
| |
| mlir::LLVM::AtomicBinOp getLLVMAtomicBinOp(mlir::cir::AtomicFetchKind k, |
| bool isInt, |
| bool isSignedInt) const { |
| switch (k) { |
| case mlir::cir::AtomicFetchKind::Add: |
| return isInt ? mlir::LLVM::AtomicBinOp::add |
| : mlir::LLVM::AtomicBinOp::fadd; |
| case mlir::cir::AtomicFetchKind::Sub: |
| return isInt ? mlir::LLVM::AtomicBinOp::sub |
| : mlir::LLVM::AtomicBinOp::fsub; |
| case mlir::cir::AtomicFetchKind::And: |
| return mlir::LLVM::AtomicBinOp::_and; |
| case mlir::cir::AtomicFetchKind::Xor: |
| return mlir::LLVM::AtomicBinOp::_xor; |
| case mlir::cir::AtomicFetchKind::Or: |
| return mlir::LLVM::AtomicBinOp::_or; |
| case mlir::cir::AtomicFetchKind::Nand: |
| return mlir::LLVM::AtomicBinOp::nand; |
| case mlir::cir::AtomicFetchKind::Max: { |
| if (!isInt) |
| return mlir::LLVM::AtomicBinOp::fmax; |
| return isSignedInt ? mlir::LLVM::AtomicBinOp::max |
| : mlir::LLVM::AtomicBinOp::umax; |
| } |
| case mlir::cir::AtomicFetchKind::Min: { |
| if (!isInt) |
| return mlir::LLVM::AtomicBinOp::fmin; |
| return isSignedInt ? mlir::LLVM::AtomicBinOp::min |
| : mlir::LLVM::AtomicBinOp::umin; |
| } |
| } |
| llvm_unreachable("Unknown atomic fetch opcode"); |
| } |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::AtomicFetch op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| |
| bool isInt = false, isSignedInt = false; // otherwise the value is a float. |
| if (auto intTy = |
| mlir::dyn_cast<mlir::cir::IntType>(op.getVal().getType())) { |
| isInt = true; |
| isSignedInt = intTy.isSigned(); |
| } else if (mlir::isa<mlir::cir::SingleType, mlir::cir::DoubleType>( |
| op.getVal().getType())) |
| isInt = false; |
| else { |
| return op.emitError() |
| << "Unsupported type: " << adaptor.getVal().getType(); |
| } |
| |
| // FIXME: add syncscope. |
| auto llvmOrder = getLLVMAtomicOrder(adaptor.getMemOrder()); |
| auto llvmBinOpc = getLLVMAtomicBinOp(op.getBinop(), isInt, isSignedInt); |
| auto rmwVal = rewriter.create<mlir::LLVM::AtomicRMWOp>( |
| op.getLoc(), llvmBinOpc, adaptor.getPtr(), adaptor.getVal(), llvmOrder); |
| |
| mlir::Value result = rmwVal.getRes(); |
| if (!op.getFetchFirst()) { |
| if (op.getBinop() == mlir::cir::AtomicFetchKind::Max || |
| op.getBinop() == mlir::cir::AtomicFetchKind::Min) |
| result = buildMinMaxPostOp(op, adaptor, rewriter, rmwVal.getRes(), |
| isSignedInt); |
| else |
| result = buildPostOp(op, adaptor, rewriter, rmwVal.getRes(), isInt); |
| |
| // Compensate for the lack of a nand binop in LLVM IR. |
| if (op.getBinop() == mlir::cir::AtomicFetchKind::Nand) { |
| auto negOne = rewriter.create<mlir::LLVM::ConstantOp>( |
| op.getLoc(), result.getType(), -1); |
| result = |
| rewriter.create<mlir::LLVM::XOrOp>(op.getLoc(), result, negOne); |
| } |
| } |
| |
| rewriter.replaceOp(op, result); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRByteswapOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::ByteswapOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ByteswapOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ByteswapOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // Note that LLVM intrinsic calls to @llvm.bswap.i* have the same type as |
| // the operand. |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ByteSwapOp>(op, adaptor.getInput()); |
| |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
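| // Lowers cir.rotate to the LLVM funnel-shift intrinsics: rotate-left maps |
| // to llvm.fshl and rotate-right to llvm.fshr, passing the source value |
| // twice so the funnel shift degenerates into a rotate. |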
| class CIRRotateOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::RotateOp> { |
| public: |
| using OpConversionPattern<mlir::cir::RotateOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::RotateOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // Note that LLVM intrinsic calls to @llvm.fsh{r,l}.i* have the same type as |
| // the operand. |
| auto src = adaptor.getSrc(); |
| if (op.getLeft()) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FshlOp>(op, src, src, |
| adaptor.getAmt()); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::FshrOp>(op, src, src, |
| adaptor.getAmt()); |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
| class CIRSelectOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::SelectOp> { |
| public: |
| using OpConversionPattern<mlir::cir::SelectOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::SelectOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto getConstantBool = [](mlir::Value value) -> std::optional<bool> { |
| auto definingOp = mlir::dyn_cast_if_present<mlir::cir::ConstantOp>( |
| value.getDefiningOp()); |
| if (!definingOp) |
| return std::nullopt; |
| |
| auto constValue = |
| mlir::dyn_cast<mlir::cir::BoolAttr>(definingOp.getValue()); |
| if (!constValue) |
| return std::nullopt; |
| |
| return constValue.getValue(); |
| }; |
| |
| // Two special cases in the LLVMIR codegen of select op: |
| // - select %0, %1, false => and %0, %1 |
| // - select %0, true, %1 => or %0, %1 |
| auto trueValue = op.getTrueValue(); |
| auto falseValue = op.getFalseValue(); |
| if (mlir::isa<mlir::cir::BoolType>(trueValue.getType())) { |
| if (std::optional<bool> falseValueBool = getConstantBool(falseValue); |
| falseValueBool.has_value() && !*falseValueBool) { |
| // select %0, %1, false => and %0, %1 |
| rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>( |
| op, adaptor.getCondition(), adaptor.getTrueValue()); |
| return mlir::success(); |
| } |
| if (std::optional<bool> trueValueBool = getConstantBool(trueValue); |
| trueValueBool.has_value() && *trueValueBool) { |
| // select %0, true, %1 => or %0, %1 |
| rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>( |
| op, adaptor.getCondition(), adaptor.getFalseValue()); |
| return mlir::success(); |
| } |
| } |
| |
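| // Generic case: the lowered condition is an i8 (the lowering of cir.bool), |
| // while llvm.select expects an i1, so truncate it first. |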
| auto llvmCondition = rewriter.create<mlir::LLVM::TruncOp>( |
| op.getLoc(), mlir::IntegerType::get(op->getContext(), 1), |
| adaptor.getCondition()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>( |
| op, llvmCondition, adaptor.getTrueValue(), adaptor.getFalseValue()); |
| |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRBrOpLowering : public mlir::OpConversionPattern<mlir::cir::BrOp> { |
| public: |
| using OpConversionPattern<mlir::cir::BrOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::BrOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(op, adaptor.getOperands(), |
| op.getDest()); |
| return mlir::LogicalResult::success(); |
| } |
| }; |
| |
| class CIRGetMemberOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::GetMemberOp> { |
| public: |
| using mlir::OpConversionPattern<mlir::cir::GetMemberOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::GetMemberOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto llResTy = getTypeConverter()->convertType(op.getType()); |
| const auto structTy = |
| mlir::cast<mlir::cir::StructType>(op.getAddrTy().getPointee()); |
| assert(structTy && "expected struct type"); |
| |
| switch (structTy.getKind()) { |
| case mlir::cir::StructType::Struct: |
| case mlir::cir::StructType::Class: { |
| // Since the base address is a pointer to an aggregate, the first offset |
| // is always zero. The second offset tells us which member to access. |
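| // For example (illustrative): accessing member #2 of a struct becomes an |
| // llvm.getelementptr with indices [0, 2] over the converted struct type. |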
| llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, op.getIndex()}; |
| const auto elementTy = getTypeConverter()->convertType(structTy); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, llResTy, elementTy, |
| adaptor.getAddr(), offset); |
| return mlir::success(); |
| } |
| case mlir::cir::StructType::Union: |
| // Union members share the same address, so we just need a bitcast to |
| // satisfy type checking. |
| rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy, |
| adaptor.getAddr()); |
| return mlir::success(); |
| } |
| } |
| }; |
| |
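| // Lowers cir.get_runtime_member: the member offset comes from a value (a |
| // lowered data member pointer, treated as a byte offset), so it is applied |
| // with an i8-typed llvm.getelementptr on the object address. |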
| class CIRGetRuntimeMemberOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::GetRuntimeMemberOp> { |
| public: |
| using mlir::OpConversionPattern< |
| mlir::cir::GetRuntimeMemberOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::GetRuntimeMemberOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto llvmResTy = getTypeConverter()->convertType(op.getType()); |
| auto llvmElementTy = mlir::IntegerType::get(op.getContext(), 8); |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( |
| op, llvmResTy, llvmElementTy, adaptor.getAddr(), adaptor.getMember()); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRPtrDiffOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::PtrDiffOp> { |
| public: |
| using OpConversionPattern<mlir::cir::PtrDiffOp>::OpConversionPattern; |
| |
| uint64_t getTypeSize(mlir::Type type, mlir::Operation &op) const { |
| mlir::DataLayout layout(op.getParentOfType<mlir::ModuleOp>()); |
| // For LLVM purposes we treat void as u8. |
| if (isa<mlir::cir::VoidType>(type)) |
| type = mlir::cir::IntType::get(type.getContext(), 8, /*isSigned=*/false); |
| return llvm::divideCeil(layout.getTypeSizeInBits(type), 8); |
| } |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::PtrDiffOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto dstTy = mlir::cast<mlir::cir::IntType>(op.getType()); |
| auto llvmDstTy = getTypeConverter()->convertType(dstTy); |
| |
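| // Lower ptrdiff as (lhs - rhs) / sizeof(pointee): convert both pointers |
| // with ptrtoint, subtract, and divide by the pointee size (the division is |
| // skipped below when the size is 1). |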
| auto lhs = rewriter.create<mlir::LLVM::PtrToIntOp>(op.getLoc(), llvmDstTy, |
| adaptor.getLhs()); |
| auto rhs = rewriter.create<mlir::LLVM::PtrToIntOp>(op.getLoc(), llvmDstTy, |
| adaptor.getRhs()); |
| |
| auto diff = |
| rewriter.create<mlir::LLVM::SubOp>(op.getLoc(), llvmDstTy, lhs, rhs); |
| |
| auto ptrTy = mlir::cast<mlir::cir::PointerType>(op.getLhs().getType()); |
| auto typeSize = getTypeSize(ptrTy.getPointee(), *op); |
| |
| // Avoid silly division by 1. |
| auto resultVal = diff.getResult(); |
| if (typeSize != 1) { |
| auto typeSizeVal = rewriter.create<mlir::LLVM::ConstantOp>( |
| op.getLoc(), llvmDstTy, mlir::IntegerAttr::get(llvmDstTy, typeSize)); |
| |
| if (dstTy.isUnsigned()) |
| resultVal = rewriter.create<mlir::LLVM::UDivOp>(op.getLoc(), llvmDstTy, |
| diff, typeSizeVal); |
| else |
| resultVal = rewriter.create<mlir::LLVM::SDivOp>(op.getLoc(), llvmDstTy, |
| diff, typeSizeVal); |
| } |
| rewriter.replaceOp(op, resultVal); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRExpectOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::ExpectOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ExpectOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ExpectOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| std::optional<llvm::APFloat> prob = op.getProb(); |
| if (!prob) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ExpectOp>(op, adaptor.getVal(), |
| adaptor.getExpected()); |
| else |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ExpectWithProbabilityOp>( |
| op, adaptor.getVal(), adaptor.getExpected(), prob.value()); |
| return mlir::success(); |
| } |
| }; |
| |
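| // Lowers cir.vtable.address_point. If no symbol address operand is given, |
| // the address point is computed from the vtable global itself via |
| // llvm.mlir.addressof plus a GEP over [0, vtable_index, |
| // address_point_index]; otherwise a GEP with just the address point index |
| // is applied to the provided address. |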
| class CIRVTableAddrPointOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::VTableAddrPointOp> { |
| public: |
| using OpConversionPattern<mlir::cir::VTableAddrPointOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::VTableAddrPointOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| const auto *converter = getTypeConverter(); |
| auto targetType = converter->convertType(op.getType()); |
| mlir::Value symAddr = op.getSymAddr(); |
| llvm::SmallVector<mlir::LLVM::GEPArg> offsets; |
| mlir::Type eltType; |
| if (!symAddr) { |
| // Get the vtable address point from a global variable |
| auto module = op->getParentOfType<mlir::ModuleOp>(); |
| auto *symbol = |
| mlir::SymbolTable::lookupSymbolIn(module, op.getNameAttr()); |
| if (auto llvmSymbol = dyn_cast<mlir::LLVM::GlobalOp>(symbol)) { |
| eltType = llvmSymbol.getType(); |
| } else if (auto cirSymbol = dyn_cast<mlir::cir::GlobalOp>(symbol)) { |
| eltType = converter->convertType(cirSymbol.getSymType()); |
| } |
| symAddr = rewriter.create<mlir::LLVM::AddressOfOp>( |
| op.getLoc(), mlir::LLVM::LLVMPointerType::get(getContext()), |
| *op.getName()); |
| offsets = llvm::SmallVector<mlir::LLVM::GEPArg>{ |
| 0, op.getVtableIndex(), op.getAddressPointIndex()}; |
| } else { |
| // Retrieve the vtable address point indirectly through the given address. |
| symAddr = adaptor.getSymAddr(); |
| eltType = converter->convertType(symAddr.getType()); |
| offsets = |
| llvm::SmallVector<mlir::LLVM::GEPArg>{op.getAddressPointIndex()}; |
| } |
| |
| if (eltType) |
| rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, targetType, eltType, |
| symAddr, offsets, true); |
| else |
| llvm_unreachable("Shouldn't ever be missing an eltType here"); |
| |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRStackSaveLowering |
| : public mlir::OpConversionPattern<mlir::cir::StackSaveOp> { |
| public: |
| using OpConversionPattern<mlir::cir::StackSaveOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::StackSaveOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto ptrTy = getTypeConverter()->convertType(op.getType()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::StackSaveOp>(op, ptrTy); |
| return mlir::success(); |
| } |
| }; |
| |
| #define GET_BUILTIN_LOWERING_CLASSES |
| #include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc" |
| |
| class CIRUnreachableLowering |
| : public mlir::OpConversionPattern<mlir::cir::UnreachableOp> { |
| public: |
| using OpConversionPattern<mlir::cir::UnreachableOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::UnreachableOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::UnreachableOp>(op); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRTrapLowering : public mlir::OpConversionPattern<mlir::cir::TrapOp> { |
| public: |
| using OpConversionPattern<mlir::cir::TrapOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::TrapOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto loc = op->getLoc(); |
| rewriter.eraseOp(op); |
| |
| rewriter.create<mlir::LLVM::Trap>(loc); |
| |
| // Note that the call to llvm.trap is not a terminator in the LLVM dialect, |
| // so we must emit an additional llvm.unreachable to terminate the current |
| // block. |
| rewriter.create<mlir::LLVM::UnreachableOp>(loc); |
| |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRInlineAsmOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::InlineAsmOp> { |
| |
| using mlir::OpConversionPattern<mlir::cir::InlineAsmOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::InlineAsmOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| mlir::Type llResTy; |
| if (op.getNumResults()) |
| llResTy = getTypeConverter()->convertType(op.getType(0)); |
| |
| auto dialect = op.getAsmFlavor(); |
| auto llDialect = dialect == mlir::cir::AsmFlavor::x86_att |
| ? mlir::LLVM::AsmDialect::AD_ATT |
| : mlir::LLVM::AsmDialect::AD_Intel; |
| |
| std::vector<mlir::Attribute> opAttrs; |
| auto llvmAttrName = mlir::LLVM::InlineAsmOp::getElementTypeAttrName(); |
| |
| // This matters for the later translation from the LLVM dialect to LLVM IR: |
| // without this placeholder, when the operation has no result (i.e. it |
| // returns void), the element type attribute would be attached to the whole |
| // instruction instead of the intended operand. |
| if (!op.getNumResults()) |
| opAttrs.push_back(mlir::Attribute()); |
| |
| llvm::SmallVector<mlir::Value> llvmOperands; |
| llvm::SmallVector<mlir::Value> cirOperands; |
| for (size_t i = 0; i < op.getOperands().size(); ++i) { |
| auto llvmOps = adaptor.getOperands()[i]; |
| auto cirOps = op.getOperands()[i]; |
| llvmOperands.insert(llvmOperands.end(), llvmOps.begin(), llvmOps.end()); |
| cirOperands.insert(cirOperands.end(), cirOps.begin(), cirOps.end()); |
| } |
| |
| // So far we infer the LLVM dialect element type attribute from the |
| // CIR operand type. |
| for (std::size_t i = 0; i < op.getOperandAttrs().size(); ++i) { |
| if (!op.getOperandAttrs()[i]) { |
| opAttrs.push_back(mlir::Attribute()); |
| continue; |
| } |
| |
| std::vector<mlir::NamedAttribute> attrs; |
| auto typ = cast<mlir::cir::PointerType>(cirOperands[i].getType()); |
| auto typAttr = mlir::TypeAttr::get( |
| getTypeConverter()->convertType(typ.getPointee())); |
| |
| attrs.push_back(rewriter.getNamedAttr(llvmAttrName, typAttr)); |
| auto newDict = rewriter.getDictionaryAttr(attrs); |
| opAttrs.push_back(newDict); |
| } |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::InlineAsmOp>( |
| op, llResTy, llvmOperands, op.getAsmStringAttr(), |
| op.getConstraintsAttr(), op.getSideEffectsAttr(), |
| /*is_align_stack*/ mlir::UnitAttr(), |
| mlir::LLVM::AsmDialectAttr::get(getContext(), llDialect), |
| rewriter.getArrayAttr(opAttrs)); |
| |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRPrefetchLowering |
| : public mlir::OpConversionPattern<mlir::cir::PrefetchOp> { |
| public: |
| using OpConversionPattern<mlir::cir::PrefetchOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::PrefetchOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| rewriter.replaceOpWithNewOp<mlir::LLVM::Prefetch>( |
| op, adaptor.getAddr(), adaptor.getIsWrite(), adaptor.getLocality(), |
| /*DataCache*/ 1); |
| return mlir::success(); |
| } |
| }; |
| |
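| // Lowers cir.set_bitfield. When the field does not occupy the whole storage |
| // unit, this is a read-modify-write: the source value is masked to the |
| // field width and shifted into position, the target bits of the loaded |
| // storage are cleared, and the two are or'ed together before the store. |
| // The op result is the source value adjusted to the field's width and |
| // signedness. |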
| class CIRSetBitfieldLowering |
| : public mlir::OpConversionPattern<mlir::cir::SetBitfieldOp> { |
| public: |
| using OpConversionPattern<mlir::cir::SetBitfieldOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::SetBitfieldOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| mlir::OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPoint(op); |
| |
| auto info = op.getBitfieldInfo(); |
| auto size = info.getSize(); |
| auto offset = info.getOffset(); |
| auto storageType = info.getStorageType(); |
| auto context = storageType.getContext(); |
| |
| unsigned storageSize = 0; |
| |
| if (auto arTy = mlir::dyn_cast<mlir::cir::ArrayType>(storageType)) |
| storageSize = arTy.getSize() * 8; |
| else if (auto intTy = mlir::dyn_cast<mlir::cir::IntType>(storageType)) |
| storageSize = intTy.getWidth(); |
| else |
| llvm_unreachable( |
| "Either ArrayType or IntType expected for bitfields storage"); |
| |
| auto intType = mlir::IntegerType::get(context, storageSize); |
| auto srcVal = createIntCast(rewriter, adaptor.getSrc(), intType); |
| auto srcWidth = storageSize; |
| auto resultVal = srcVal; |
| |
| if (storageSize != size) { |
| assert(storageSize > size && "Invalid bitfield size."); |
| |
| mlir::Value val = rewriter.create<mlir::LLVM::LoadOp>( |
| op.getLoc(), intType, adaptor.getAddr(), /* alignment */ 0, |
| op.getIsVolatile()); |
| |
| srcVal = createAnd(rewriter, srcVal, |
| llvm::APInt::getLowBitsSet(srcWidth, size)); |
| resultVal = srcVal; |
| srcVal = createShL(rewriter, srcVal, offset); |
| |
| // Mask out the original value. |
| val = |
| createAnd(rewriter, val, |
| ~llvm::APInt::getBitsSet(srcWidth, offset, offset + size)); |
| |
| // Or together the unchanged values and the source value. |
| srcVal = rewriter.create<mlir::LLVM::OrOp>(op.getLoc(), val, srcVal); |
| } |
| |
| rewriter.create<mlir::LLVM::StoreOp>(op.getLoc(), srcVal, adaptor.getAddr(), |
| /* alignment */ 0, op.getIsVolatile()); |
| |
| auto resultTy = getTypeConverter()->convertType(op.getType()); |
| |
| resultVal = createIntCast(rewriter, resultVal, |
| mlir::cast<mlir::IntegerType>(resultTy)); |
| |
| if (info.getIsSigned()) { |
| assert(size <= storageSize); |
| unsigned highBits = storageSize - size; |
| |
| if (highBits) { |
| resultVal = createShL(rewriter, resultVal, highBits); |
| resultVal = createAShR(rewriter, resultVal, highBits); |
| } |
| } |
| |
| rewriter.replaceOp(op, resultVal); |
| return mlir::success(); |
| } |
| }; |
| |
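| // Lowers cir.get_bitfield: load the storage unit, then extract the field. |
| // Signed fields are extracted with a shl/ashr pair so the value is |
| // sign-extended; unsigned fields use lshr plus a mask of the low bits. |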
| class CIRGetBitfieldLowering |
| : public mlir::OpConversionPattern<mlir::cir::GetBitfieldOp> { |
| public: |
| using OpConversionPattern<mlir::cir::GetBitfieldOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::GetBitfieldOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| |
| mlir::OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPoint(op); |
| |
| auto info = op.getBitfieldInfo(); |
| auto size = info.getSize(); |
| auto offset = info.getOffset(); |
| auto storageType = info.getStorageType(); |
| auto context = storageType.getContext(); |
| unsigned storageSize = 0; |
| |
| if (auto arTy = mlir::dyn_cast<mlir::cir::ArrayType>(storageType)) |
| storageSize = arTy.getSize() * 8; |
| else if (auto intTy = mlir::dyn_cast<mlir::cir::IntType>(storageType)) |
| storageSize = intTy.getWidth(); |
| else |
| llvm_unreachable( |
| "Either ArrayType or IntType expected for bitfields storage"); |
| |
| auto intType = mlir::IntegerType::get(context, storageSize); |
| |
| mlir::Value val = rewriter.create<mlir::LLVM::LoadOp>( |
| op.getLoc(), intType, adaptor.getAddr(), 0, op.getIsVolatile()); |
| val = rewriter.create<mlir::LLVM::BitcastOp>(op.getLoc(), intType, val); |
| |
| if (info.getIsSigned()) { |
| assert(static_cast<unsigned>(offset + size) <= storageSize); |
| unsigned highBits = storageSize - offset - size; |
| val = createShL(rewriter, val, highBits); |
| val = createAShR(rewriter, val, offset + highBits); |
| } else { |
| val = createLShR(rewriter, val, offset); |
| |
| if (static_cast<unsigned>(offset) + size < storageSize) |
| val = createAnd(rewriter, val, |
| llvm::APInt::getLowBitsSet(storageSize, size)); |
| } |
| |
| auto resTy = getTypeConverter()->convertType(op.getType()); |
| auto newOp = |
| createIntCast(rewriter, val, mlir::cast<mlir::IntegerType>(resTy), |
| info.getIsSigned()); |
| rewriter.replaceOp(op, newOp); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRIsConstantOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::IsConstantOp> { |
| |
| using mlir::OpConversionPattern<mlir::cir::IsConstantOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::IsConstantOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // FIXME(cir): llvm.intr.is.constant returns an i1 value, but the LLVM |
| // lowering expects the cir.bool type to be lowered as i8, so we have to |
| // insert a zext here. |
| auto isConstantOP = rewriter.create<mlir::LLVM::IsConstantOp>( |
| op.getLoc(), adaptor.getVal()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(op, rewriter.getI8Type(), |
| isConstantOP); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRCmpThreeWayOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::CmpThreeWayOp> { |
| public: |
| using mlir::OpConversionPattern< |
| mlir::cir::CmpThreeWayOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::CmpThreeWayOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| if (!op.isIntegralComparison() || !op.isStrongOrdering()) { |
| op.emitError() << "unsupported three-way comparison type"; |
| return mlir::failure(); |
| } |
| |
| auto cmpInfo = op.getInfo(); |
| assert(cmpInfo.getLt() == -1 && cmpInfo.getEq() == 0 && |
| cmpInfo.getGt() == 1); |
| |
| auto operandTy = mlir::cast<mlir::cir::IntType>(op.getLhs().getType()); |
| auto resultTy = op.getType(); |
| auto llvmIntrinsicName = getLLVMIntrinsicName( |
| operandTy.isSigned(), operandTy.getWidth(), resultTy.getWidth()); |
| |
| rewriter.setInsertionPoint(op); |
| |
| auto llvmLhs = adaptor.getLhs(); |
| auto llvmRhs = adaptor.getRhs(); |
| auto llvmResultTy = getTypeConverter()->convertType(resultTy); |
| auto callIntrinsicOp = |
| createCallLLVMIntrinsicOp(rewriter, op.getLoc(), llvmIntrinsicName, |
| llvmResultTy, {llvmLhs, llvmRhs}); |
| |
| rewriter.replaceOp(op, callIntrinsicOp); |
| return mlir::success(); |
| } |
| |
| private: |
| static std::string getLLVMIntrinsicName(bool signedCmp, unsigned operandWidth, |
| unsigned resultWidth) { |
| // The intrinsic's name takes the form: |
| // `llvm.<scmp|ucmp>.i<resultWidth>.i<operandWidth>` |
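| // e.g. (illustrative) a signed comparison of i32 operands producing an i8 |
| // result uses `llvm.scmp.i8.i32`. |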
| |
| std::string result = "llvm."; |
| |
| if (signedCmp) |
| result.append("scmp."); |
| else |
| result.append("ucmp."); |
| |
| // Result type part. |
| result.push_back('i'); |
| result.append(std::to_string(resultWidth)); |
| result.push_back('.'); |
| |
| // Operand type part. |
| result.push_back('i'); |
| result.append(std::to_string(operandWidth)); |
| |
| return result; |
| } |
| }; |
| |
| class CIRClearCacheOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::ClearCacheOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ClearCacheOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ClearCacheOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto begin = adaptor.getBegin(); |
| auto end = adaptor.getEnd(); |
| auto intrinNameAttr = |
| mlir::StringAttr::get(op.getContext(), "llvm.clear_cache"); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::CallIntrinsicOp>( |
| op, mlir::Type{}, intrinNameAttr, mlir::ValueRange{begin, end}); |
| |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRUndefOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::UndefOp> { |
| |
| using mlir::OpConversionPattern<mlir::cir::UndefOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::UndefOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto typ = getTypeConverter()->convertType(op.getRes().getType()); |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::UndefOp>(op, typ); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIREhTypeIdOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::EhTypeIdOp> { |
| public: |
| using OpConversionPattern<mlir::cir::EhTypeIdOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::EhTypeIdOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| mlir::Value addrOp = rewriter.create<mlir::LLVM::AddressOfOp>( |
| op.getLoc(), mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), |
| op.getTypeSymAttr()); |
| mlir::LLVM::CallIntrinsicOp newOp = createCallLLVMIntrinsicOp( |
| rewriter, op.getLoc(), "llvm.eh.typeid.for.p0", rewriter.getI32Type(), |
| mlir::ValueRange{addrOp}); |
| rewriter.replaceOp(op, newOp); |
| return mlir::success(); |
| } |
| }; |
| |
| // Make sure the LLVM function we are about to call actually exists; if it |
| // does not, declare it in the module. |
| void getOrCreateLLVMFuncOp(mlir::ConversionPatternRewriter &rewriter, |
| mlir::Location loc, mlir::ModuleOp mod, |
| mlir::LLVM::LLVMFuncOp enclosingFnOp, |
| llvm::StringRef fnName, mlir::Type fnTy) { |
| auto *sourceSymbol = mlir::SymbolTable::lookupSymbolIn(mod, fnName); |
| if (!sourceSymbol) { |
| mlir::OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPoint(enclosingFnOp); |
| rewriter.create<mlir::LLVM::LLVMFuncOp>(loc, fnName, fnTy); |
| } |
| } |
| |
| class CIRCatchParamOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::CatchParamOp> { |
| public: |
| using OpConversionPattern<mlir::cir::CatchParamOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::CatchParamOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| auto modOp = op->getParentOfType<mlir::ModuleOp>(); |
| auto enclosingFnOp = op->getParentOfType<mlir::LLVM::LLVMFuncOp>(); |
| if (op.isBegin()) { |
| // Get or create `declare ptr @__cxa_begin_catch(ptr)` |
| StringRef fnName = "__cxa_begin_catch"; |
| auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); |
| auto fnTy = mlir::LLVM::LLVMFunctionType::get(llvmPtrTy, {llvmPtrTy}, |
| /*isVarArg=*/false); |
| getOrCreateLLVMFuncOp(rewriter, op.getLoc(), modOp, enclosingFnOp, fnName, |
| fnTy); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>( |
| op, mlir::TypeRange{llvmPtrTy}, fnName, |
| mlir::ValueRange{adaptor.getExceptionPtr()}); |
| return mlir::success(); |
| } else if (op.isEnd()) { |
| StringRef fnName = "__cxa_end_catch"; |
| auto fnTy = mlir::LLVM::LLVMFunctionType::get( |
| mlir::LLVM::LLVMVoidType::get(rewriter.getContext()), {}, |
| /*isVarArg=*/false); |
| getOrCreateLLVMFuncOp(rewriter, op.getLoc(), modOp, enclosingFnOp, fnName, |
| fnTy); |
| rewriter.create<mlir::LLVM::CallOp>(op.getLoc(), mlir::TypeRange{}, |
| fnName, mlir::ValueRange{}); |
| rewriter.eraseOp(op); |
| return mlir::success(); |
| } |
| llvm_unreachable("only begin/end supposed to make to lowering stage"); |
| return mlir::failure(); |
| } |
| }; |
| |
| class CIRResumeOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::ResumeOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ResumeOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ResumeOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // %lpad.val = insertvalue { ptr, i32 } poison, ptr %exception_ptr, 0 |
| // %lpad.val2 = insertvalue { ptr, i32 } %lpad.val, i32 %selector, 1 |
| // resume { ptr, i32 } %lpad.val2 |
| SmallVector<int64_t> slotIdx = {0}; |
| SmallVector<int64_t> selectorIdx = {1}; |
| auto llvmLandingPadStructTy = getLLVMLandingPadStructTy(rewriter); |
| mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>( |
| op.getLoc(), llvmLandingPadStructTy); |
| |
| mlir::Value slot = rewriter.create<mlir::LLVM::InsertValueOp>( |
| op.getLoc(), poison, adaptor.getExceptionPtr(), slotIdx); |
| mlir::Value selector = rewriter.create<mlir::LLVM::InsertValueOp>( |
| op.getLoc(), slot, adaptor.getTypeId(), selectorIdx); |
| |
| rewriter.replaceOpWithNewOp<mlir::LLVM::ResumeOp>(op, selector); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRAllocExceptionOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::AllocExceptionOp> { |
| public: |
| using OpConversionPattern<mlir::cir::AllocExceptionOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::AllocExceptionOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // Get or create `declare ptr @__cxa_allocate_exception(i64)` |
| StringRef fnName = "__cxa_allocate_exception"; |
| auto modOp = op->getParentOfType<mlir::ModuleOp>(); |
| auto enclosingFnOp = op->getParentOfType<mlir::LLVM::LLVMFuncOp>(); |
| auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); |
| auto int64Ty = mlir::IntegerType::get(rewriter.getContext(), 64); |
| auto fnTy = mlir::LLVM::LLVMFunctionType::get(llvmPtrTy, {int64Ty}, |
| /*isVarArg=*/false); |
| getOrCreateLLVMFuncOp(rewriter, op.getLoc(), modOp, enclosingFnOp, fnName, |
| fnTy); |
| auto size = rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), |
| adaptor.getSizeAttr()); |
| rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>( |
| op, mlir::TypeRange{llvmPtrTy}, fnName, mlir::ValueRange{size}); |
| return mlir::success(); |
| } |
| }; |
| |
| class CIRThrowOpLowering |
| : public mlir::OpConversionPattern<mlir::cir::ThrowOp> { |
| public: |
| using OpConversionPattern<mlir::cir::ThrowOp>::OpConversionPattern; |
| |
| mlir::LogicalResult |
| matchAndRewrite(mlir::cir::ThrowOp op, OpAdaptor adaptor, |
| mlir::ConversionPatternRewriter &rewriter) const override { |
| // Get or create `declare void @__cxa_throw(ptr, ptr, ptr)` |
| StringRef fnName = "__cxa_throw"; |
| auto modOp = op->getParentOfType<mlir::ModuleOp>(); |
| auto enclosingFnOp = op->getParentOfType<mlir::LLVM::LLVMFuncOp>(); |
| auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); |
| auto voidTy = mlir::LLVM::LLVMVoidType::get(rewriter.getContext()); |
| auto fnTy = mlir::LLVM::LLVMFunctionType::get( |
| voidTy, {llvmPtrTy, llvmPtrTy, llvmPtrTy}, |
| /*isVarArg=*/false); |
| getOrCreateLLVMFuncOp(rewriter, op.getLoc(), modOp, enclosingFnOp, fnName, |
| fnTy); |
| mlir::Value typeInfo = rewriter.create<mlir::LLVM::AddressOfOp>( |
| op.getLoc(), mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), |
| adaptor.getTypeInfoAttr()); |
| |
| mlir::Value dtor; |
| if (op.getDtor()) { |
| dtor = rewriter.create<mlir::LLVM::AddressOfOp>( |
| op.getLoc(), mlir::LLVM::LLVMPointerType::get(rewriter.getContext()), |
| adaptor.getDtorAttr()); |
| } else { |
| dtor = rewriter.create<mlir::LLVM::ZeroOp>( |
| op.getLoc(), mlir::LLVM::LLVMPointerType::get(rewriter.getContext())); |
| } |
| rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>( |
| op, mlir::TypeRange{}, fnName, |
| mlir::ValueRange{adaptor.getExceptionPtr(), typeInfo, dtor}); |
| return mlir::success(); |
| } |
| }; |
| |
| void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns, |
| mlir::TypeConverter &converter, |
| mlir::DataLayout &dataLayout) { |
| patterns.add<CIRReturnLowering>(patterns.getContext()); |
| patterns.add<CIRAllocaLowering>(converter, dataLayout, patterns.getContext()); |
| patterns.add< |
| CIRCmpOpLowering, CIRSelectOpLowering, CIRBitClrsbOpLowering, |
| CIRBitClzOpLowering, CIRBitCtzOpLowering, CIRBitFfsOpLowering, |
| CIRBitParityOpLowering, CIRBitPopcountOpLowering, |
| CIRAtomicCmpXchgLowering, CIRAtomicXchgLowering, CIRAtomicFetchLowering, |
| CIRByteswapOpLowering, CIRRotateOpLowering, CIRBrCondOpLowering, |
| CIRPtrStrideOpLowering, CIRCallLowering, CIRTryCallLowering, |
| CIREhInflightOpLowering, CIRUnaryOpLowering, CIRBinOpLowering, |
| CIRBinOpOverflowOpLowering, CIRShiftOpLowering, CIRLoadLowering, |
| CIRConstantLowering, CIRStoreLowering, CIRFuncLowering, CIRCastOpLowering, |
| CIRGlobalOpLowering, CIRGetGlobalOpLowering, CIRComplexCreateOpLowering, |
| CIRComplexRealOpLowering, CIRComplexImagOpLowering, |
| CIRComplexRealPtrOpLowering, CIRComplexImagPtrOpLowering, |
| CIRVAStartLowering, CIRVAEndLowering, CIRVACopyLowering, CIRVAArgLowering, |
| CIRBrOpLowering, CIRGetMemberOpLowering, CIRGetRuntimeMemberOpLowering, |
| CIRSwitchFlatOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering, |
| CIRMemCpyOpLowering, CIRFAbsOpLowering, CIRExpectOpLowering, |
| CIRVTableAddrPointOpLowering, CIRVectorCreateLowering, |
| CIRVectorCmpOpLowering, CIRVectorSplatLowering, CIRVectorTernaryLowering, |
| CIRVectorShuffleIntsLowering, CIRVectorShuffleVecLowering, |
| CIRStackSaveLowering, CIRUnreachableLowering, CIRTrapLowering, |
| CIRInlineAsmOpLowering, CIRSetBitfieldLowering, CIRGetBitfieldLowering, |
| CIRPrefetchLowering, CIRObjSizeOpLowering, CIRIsConstantOpLowering, |
| CIRCmpThreeWayOpLowering, CIRClearCacheOpLowering, CIRUndefOpLowering, |
| CIREhTypeIdOpLowering, CIRCatchParamOpLowering, CIRResumeOpLowering, |
| CIRAllocExceptionOpLowering, CIRThrowOpLowering |
| #define GET_BUILTIN_LOWERING_LIST |
| #include "clang/CIR/Dialect/IR/CIRBuiltinsLowering.inc" |
| #undef GET_BUILTIN_LOWERING_LIST |
| >(converter, patterns.getContext()); |
| } |
| |
| namespace { |
| |
| std::unique_ptr<mlir::cir::LowerModule> |
| prepareLowerModule(mlir::ModuleOp module) { |
| mlir::PatternRewriter rewriter{module->getContext()}; |
| // If the triple is not present, e.g. CIR modules parsed from text, we |
| // cannot init LowerModule properly. |
| assert(!::cir::MissingFeatures::makeTripleAlwaysPresent()); |
| if (!module->hasAttr("cir.triple")) |
| return {}; |
| return mlir::cir::createLowerModule(module, rewriter); |
| } |
| |
| // FIXME: change the type of lowerModule to `LowerModule &` for a better |
| // lambda-capture experience. Also blocked by makeTripleAlwaysPresent. |
| void prepareTypeConverter(mlir::LLVMTypeConverter &converter, |
| mlir::DataLayout &dataLayout, |
| mlir::cir::LowerModule *lowerModule) { |
| converter.addConversion([&, lowerModule]( |
| mlir::cir::PointerType type) -> mlir::Type { |
| // Drop the pointee type since the LLVM dialect only allows opaque pointers. |
| |
| auto addrSpace = |
| mlir::cast_if_present<mlir::cir::AddressSpaceAttr>(type.getAddrSpace()); |
| // Null addrspace attribute indicates the default addrspace. |
| if (!addrSpace) |
| return mlir::LLVM::LLVMPointerType::get(type.getContext()); |
| |
| assert(lowerModule && "CIR AS map is not available"); |
| // Pass through target addrspace and map CIR addrspace to LLVM addrspace by |
| // querying the target info. |
| unsigned targetAS = |
| addrSpace.isTarget() |
| ? addrSpace.getTargetValue() |
| : lowerModule->getTargetLoweringInfo() |
| .getTargetAddrSpaceFromCIRAddrSpace(addrSpace); |
| |
| return mlir::LLVM::LLVMPointerType::get(type.getContext(), targetAS); |
| }); |
| converter.addConversion([&](mlir::cir::DataMemberType type) -> mlir::Type { |
| return mlir::IntegerType::get(type.getContext(), |
| dataLayout.getTypeSizeInBits(type)); |
| }); |
| converter.addConversion([&](mlir::cir::ArrayType type) -> mlir::Type { |
| auto ty = converter.convertType(type.getEltType()); |
| return mlir::LLVM::LLVMArrayType::get(ty, type.getSize()); |
| }); |
| converter.addConversion([&](mlir::cir::VectorType type) -> mlir::Type { |
| auto ty = converter.convertType(type.getEltType()); |
| return mlir::LLVM::getFixedVectorType(ty, type.getSize()); |
| }); |
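| // cir.bool is lowered to i8; patterns bridge to/from i1 where LLVM requires |
| // it (e.g. the trunc in the select lowering and the zext in the is_constant |
| // lowering above). |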
| converter.addConversion([&](mlir::cir::BoolType type) -> mlir::Type { |
| return mlir::IntegerType::get(type.getContext(), 8, |
| mlir::IntegerType::Signless); |
| }); |
| converter.addConversion([&](mlir::cir::IntType type) -> mlir::Type { |
| // LLVM integer types are signless, so we drop the CIR signedness here. |
| return mlir::IntegerType::get(type.getContext(), type.getWidth()); |
| }); |
| converter.addConversion([&](mlir::cir::SingleType type) -> mlir::Type { |
| return mlir::FloatType::getF32(type.getContext()); |
| }); |
| converter.addConversion([&](mlir::cir::DoubleType type) -> mlir::Type { |
| return mlir::FloatType::getF64(type.getContext()); |
| }); |
| converter.addConversion([&](mlir::cir::FP80Type type) -> mlir::Type { |
| return mlir::FloatType::getF80(type.getContext()); |
| }); |
| converter.addConversion([&](mlir::cir::LongDoubleType type) -> mlir::Type { |
| return converter.convertType(type.getUnderlying()); |
| }); |
| converter.addConversion([&](mlir::cir::FP16Type type) -> mlir::Type { |
| return mlir::FloatType::getF16(type.getContext()); |
| }); |
| converter.addConversion([&](mlir::cir::BF16Type type) -> mlir::Type { |
| return mlir::FloatType::getBF16(type.getContext()); |
| }); |
| converter.addConversion([&](mlir::cir::ComplexType type) -> mlir::Type { |
| // A complex type is lowered to an LLVM struct that contains the real and |
| // imaginary part as data fields. |
| mlir::Type elementTy = converter.convertType(type.getElementTy()); |
| mlir::Type structFields[2] = {elementTy, elementTy}; |
| return mlir::LLVM::LLVMStructType::getLiteral(type.getContext(), |
| structFields); |
| }); |
| converter.addConversion([&](mlir::cir::FuncType type) -> mlir::Type { |
| auto result = converter.convertType(type.getReturnType()); |
| llvm::SmallVector<mlir::Type> arguments; |
| if (converter.convertTypes(type.getInputs(), arguments).failed()) |
| llvm_unreachable("Failed to convert function type parameters"); |
| auto varArg = type.isVarArg(); |
| return mlir::LLVM::LLVMFunctionType::get(result, arguments, varArg); |
| }); |
| converter.addConversion([&](mlir::cir::StructType type) -> mlir::Type { |
| // FIXME(cir): create separate union, struct, and class types. |
| // Convert struct members. |
| llvm::SmallVector<mlir::Type> llvmMembers; |
| switch (type.getKind()) { |
| case mlir::cir::StructType::Class: |
| // TODO(cir): This should be properly validated. |
| case mlir::cir::StructType::Struct: |
| for (auto ty : type.getMembers()) |
| llvmMembers.push_back(converter.convertType(ty)); |
| break; |
| // Unions are lowered as only the largest member. |
| case mlir::cir::StructType::Union: { |
| auto largestMember = type.getLargestMember(dataLayout); |
| if (largestMember) |
| llvmMembers.push_back(converter.convertType(largestMember)); |
| break; |
| } |
| } |
| |
| // Struct has a name: lower as an identified struct. |
| mlir::LLVM::LLVMStructType llvmStruct; |
| if (type.getName()) { |
| llvmStruct = mlir::LLVM::LLVMStructType::getIdentified( |
| type.getContext(), type.getPrefixedName()); |
| if (llvmStruct.setBody(llvmMembers, /*isPacked=*/type.getPacked()) |
| .failed()) |
| llvm_unreachable("Failed to set body of struct"); |
| } else { // Struct has no name: lower as literal struct. |
| llvmStruct = mlir::LLVM::LLVMStructType::getLiteral( |
| type.getContext(), llvmMembers, /*isPacked=*/type.getPacked()); |
| } |
| |
| return llvmStruct; |
| }); |
| converter.addConversion([&](mlir::cir::VoidType type) -> mlir::Type { |
| return mlir::LLVM::LLVMVoidType::get(type.getContext()); |
| }); |
| } |
| } // namespace |
| |
| static void buildCtorDtorList( |
| mlir::ModuleOp module, StringRef globalXtorName, StringRef llvmXtorName, |
| llvm::function_ref<std::pair<StringRef, int>(mlir::Attribute)> createXtor) { |
| llvm::SmallVector<std::pair<StringRef, int>, 2> globalXtors; |
| for (auto namedAttr : module->getAttrs()) { |
| if (namedAttr.getName() == globalXtorName) { |
| for (auto attr : mlir::cast<mlir::ArrayAttr>(namedAttr.getValue())) |
| globalXtors.emplace_back(createXtor(attr)); |
| break; |
| } |
| } |
| |
| if (globalXtors.empty()) |
| return; |
| |
| mlir::OpBuilder builder(module.getContext()); |
| builder.setInsertionPointToEnd(&module.getBodyRegion().back()); |
| |
| // Create a global array (llvm.global_ctors or llvm.global_dtors) with |
| // element type struct { i32, ptr, ptr }. |
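| // Each entry holds the priority, the ctor/dtor function address, and an |
| // associated-data pointer (currently always null, see the TODO below). |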
| auto CtorPFTy = mlir::LLVM::LLVMPointerType::get(builder.getContext()); |
| llvm::SmallVector<mlir::Type> CtorStructFields; |
| CtorStructFields.push_back(builder.getI32Type()); |
| CtorStructFields.push_back(CtorPFTy); |
| CtorStructFields.push_back(CtorPFTy); |
| |
| auto CtorStructTy = mlir::LLVM::LLVMStructType::getLiteral( |
| builder.getContext(), CtorStructFields); |
| auto CtorStructArrayTy = |
| mlir::LLVM::LLVMArrayType::get(CtorStructTy, globalXtors.size()); |
| |
| auto loc = module.getLoc(); |
| auto newGlobalOp = builder.create<mlir::LLVM::GlobalOp>( |
| loc, CtorStructArrayTy, true, mlir::LLVM::Linkage::Appending, |
| llvmXtorName, mlir::Attribute()); |
| |
| newGlobalOp.getRegion().push_back(new mlir::Block()); |
| builder.setInsertionPointToEnd(newGlobalOp.getInitializerBlock()); |
| |
| mlir::Value result = |
| builder.create<mlir::LLVM::UndefOp>(loc, CtorStructArrayTy); |
| |
| for (uint64_t I = 0; I < globalXtors.size(); I++) { |
| auto fn = globalXtors[I]; |
| mlir::Value structInit = |
| builder.create<mlir::LLVM::UndefOp>(loc, CtorStructTy); |
| mlir::Value initPriority = builder.create<mlir::LLVM::ConstantOp>( |
| loc, CtorStructFields[0], fn.second); |
| mlir::Value initFuncAddr = builder.create<mlir::LLVM::AddressOfOp>( |
| loc, CtorStructFields[1], fn.first); |
| mlir::Value initAssociate = |
| builder.create<mlir::LLVM::ZeroOp>(loc, CtorStructFields[2]); |
| structInit = builder.create<mlir::LLVM::InsertValueOp>(loc, structInit, |
| initPriority, 0); |
| structInit = builder.create<mlir::LLVM::InsertValueOp>(loc, structInit, |
| initFuncAddr, 1); |
| // TODO: handle associated data for initializers. |
| structInit = builder.create<mlir::LLVM::InsertValueOp>(loc, structInit, |
| initAssociate, 2); |
| result = |
| builder.create<mlir::LLVM::InsertValueOp>(loc, result, structInit, I); |
| } |
| |
| builder.create<mlir::LLVM::ReturnOp>(loc, result); |
| } |
| |
| // Unreachable code is not lowered by the applyPartialConversion function |
| // since it traverses blocks in dominance order. At the same time we do |
| // need to lower such code - otherwise verification errors occur. |
| // For instance, the next CIR code: |
| // |
| // cir.func @foo(%arg0: !s32i) -> !s32i { |
| // %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool |
| // cir.if %4 { |
| // %5 = cir.const #cir.int<1> : !s32i |
| // cir.return %5 : !s32i |
| // } else { |
| // %5 = cir.const #cir.int<0> : !s32i |
| // cir.return %5 : !s32i |
| // } |
| // cir.return %arg0 : !s32i |
| // } |
| // |
| // contains an unreachable return operation (the last one). After the |
| // flattening pass it will be placed into an unreachable block, and a |
| // possible error after the lowering pass is: error: 'cir.return' op expects |
| // parent op to be one of 'cir.func, cir.scope, cir.if ...'. The reason is |
| // that this operation was not lowered while its new parent is llvm.func. |
| // |
| // In the future we may want to get rid of this function and use a DCE pass |
| // or something similar. But for now we need to guarantee the absence of |
| // dialect verification errors. |
| void collect_unreachable(mlir::Operation *parent, |
| llvm::SmallVector<mlir::Operation *> &ops) { |
| |
| llvm::SmallVector<mlir::Block *> unreachable_blocks; |
| parent->walk([&](mlir::Block *blk) { |
| if (blk->hasNoPredecessors() && !blk->isEntryBlock()) |
| unreachable_blocks.push_back(blk); |
| }); |
| |
| std::set<mlir::Block *> visited; |
| for (auto *root : unreachable_blocks) { |
| // We create a worklist for each unreachable block and walk its successor |
| // blocks, collecting every operation we encounter. |
| std::deque<mlir::Block *> workList; |
| workList.push_back(root); |
| |
| while (!workList.empty()) { |
| auto *blk = workList.back(); |
| workList.pop_back(); |
| if (visited.count(blk)) |
| continue; |
| visited.emplace(blk); |
| |
| for (auto &op : *blk) |
| ops.push_back(&op); |
| |
| for (auto it = blk->succ_begin(); it != blk->succ_end(); ++it) |
| workList.push_back(*it); |
| } |
| } |
| } |
| |
| void ConvertCIRToLLVMPass::runOnOperation() { |
| auto module = getOperation(); |
| mlir::DataLayout dataLayout(module); |
| mlir::LLVMTypeConverter converter(&getContext()); |
| std::unique_ptr<mlir::cir::LowerModule> lowerModule = |
| prepareLowerModule(module); |
| prepareTypeConverter(converter, dataLayout, lowerModule.get()); |
| |
| mlir::RewritePatternSet patterns(&getContext()); |
| |
| populateCIRToLLVMConversionPatterns(patterns, converter, dataLayout); |
| mlir::populateFuncToLLVMConversionPatterns(converter, patterns); |
| |
| mlir::ConversionTarget target(getContext()); |
| using namespace mlir::cir; |
| // clang-format off |
| target.addLegalOp<mlir::ModuleOp |
| // ,AllocaOp |
| // ,BrCondOp |
| // ,BrOp |
| // ,CallOp |
| // ,CastOp |
| // ,CmpOp |
| // ,ConstantOp |
| // ,FuncOp |
| // ,LoadOp |
| // ,ReturnOp |
| // ,StoreOp |
| // ,YieldOp |
| >(); |
| // clang-format on |
| target.addLegalDialect<mlir::LLVM::LLVMDialect>(); |
| target.addIllegalDialect<mlir::BuiltinDialect, mlir::cir::CIRDialect, |
| mlir::func::FuncDialect>(); |
| |
| // Allow operations that will be lowered directly to LLVM IR. |
| target.addLegalOp<mlir::LLVM::ZeroOp>(); |
| |
| getOperation()->removeAttr("cir.sob"); |
| getOperation()->removeAttr("cir.lang"); |
| |
| llvm::SmallVector<mlir::Operation *> ops; |
| ops.push_back(module); |
| collect_unreachable(module, ops); |
| |
| if (failed(applyPartialConversion(ops, target, std::move(patterns)))) |
| signalPassFailure(); |
| |
| // Emit the llvm.global_ctors array. |
| buildCtorDtorList( |
| module, "cir.global_ctors", "llvm.global_ctors", |
| [](mlir::Attribute attr) { |
| assert(mlir::isa<mlir::cir::GlobalCtorAttr>(attr) && |
| "must be a GlobalCtorAttr"); |
| auto ctorAttr = mlir::cast<mlir::cir::GlobalCtorAttr>(attr); |
| return std::make_pair(ctorAttr.getName(), ctorAttr.getPriority()); |
| }); |
| // Emit the llvm.global_dtors array. |
| buildCtorDtorList( |
| module, "cir.global_dtors", "llvm.global_dtors", |
| [](mlir::Attribute attr) { |
| assert(mlir::isa<mlir::cir::GlobalDtorAttr>(attr) && |
| "must be a GlobalDtorAttr"); |
| auto dtorAttr = mlir::cast<mlir::cir::GlobalDtorAttr>(attr); |
| return std::make_pair(dtorAttr.getName(), dtorAttr.getPriority()); |
| }); |
| } |
| |
| std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() { |
| return std::make_unique<ConvertCIRToLLVMPass>(); |
| } |
| |
| void populateCIRToLLVMPasses(mlir::OpPassManager &pm) { |
| populateCIRPreLoweringPasses(pm); |
| pm.addPass(createConvertCIRToLLVMPass()); |
| } |
| |
| extern void registerCIRDialectTranslation(mlir::MLIRContext &context); |
| |
| std::unique_ptr<llvm::Module> |
| lowerDirectlyFromCIRToLLVMIR(mlir::ModuleOp theModule, LLVMContext &llvmCtx, |
| bool disableVerifier) { |
| mlir::MLIRContext *mlirCtx = theModule.getContext(); |
| mlir::PassManager pm(mlirCtx); |
| populateCIRToLLVMPasses(pm); |
| |
| // This is necessary to have line tables emitted and a basic debugger |
| // working. In the future we will add proper debug information emission |
| // directly from our frontend. |
| pm.addPass(mlir::LLVM::createDIScopeForLLVMFuncOpPass()); |
| |
| // FIXME(cir): this shouldn't be necessary. It's meant to be a temporary |
| // workaround until we understand why some unrealized casts are being |
| // emitted and how to properly avoid them. |
| pm.addPass(mlir::createReconcileUnrealizedCastsPass()); |
| |
| pm.enableVerifier(!disableVerifier); |
| (void)mlir::applyPassManagerCLOptions(pm); |
| |
| auto result = !mlir::failed(pm.run(theModule)); |
| if (!result) |
| report_fatal_error( |
| "The pass manager failed to lower CIR to LLVMIR dialect!"); |
| |
| // Now that we ran all the lowering passes, verify the final output. |
| if (theModule.verify().failed()) |
| report_fatal_error("Verification of the final LLVMIR dialect failed!"); |
| |
| mlir::registerBuiltinDialectTranslation(*mlirCtx); |
| mlir::registerLLVMDialectTranslation(*mlirCtx); |
| mlir::registerOpenMPDialectTranslation(*mlirCtx); |
| registerCIRDialectTranslation(*mlirCtx); |
| |
| auto ModuleName = theModule.getName(); |
| auto llvmModule = mlir::translateModuleToLLVMIR( |
| theModule, llvmCtx, ModuleName ? *ModuleName : "CIRToLLVMModule"); |
| |
| if (!llvmModule) |
| report_fatal_error("Lowering from LLVMIR dialect to llvm IR failed!"); |
| |
| return llvmModule; |
| } |
| } // namespace direct |
| } // namespace cir |