| //===- ControlFlowToLLVM.cpp - ControlFlow to LLVM dialect conversion -----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
// This file implements patterns and a pass to convert the MLIR ControlFlow
// (`cf`) dialect into the LLVM IR dialect.
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" |
| |
| #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
| #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
| #include "mlir/Conversion/LLVMCommon/Pattern.h" |
| #include "mlir/Conversion/LLVMCommon/PrintCallHelper.h" |
| #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/DialectConversion.h" |
| |
| namespace mlir { |
| #define GEN_PASS_DEF_CONVERTCONTROLFLOWTOLLVMPASS |
| #include "mlir/Conversion/Passes.h.inc" |
| } // namespace mlir |
| |
| using namespace mlir; |
| |
| #define PASS_NAME "convert-cf-to-llvm" |
| |
| namespace { |
| /// Lower `cf.assert`. The default lowering calls the `abort` function if the |
| /// assertion is violated and has no effect otherwise. The failure message is |
| /// ignored by the default lowering but should be propagated by any custom |
| /// lowering. |
| struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> { |
| explicit AssertOpLowering(const LLVMTypeConverter &typeConverter, |
| bool abortOnFailedAssert = true, |
| SymbolTableCollection *symbolTables = nullptr) |
| : ConvertOpToLLVMPattern<cf::AssertOp>(typeConverter, /*benefit=*/1), |
| abortOnFailedAssert(abortOnFailedAssert), symbolTables(symbolTables) {} |
| |
| LogicalResult |
| matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| auto loc = op.getLoc(); |
| auto module = op->getParentOfType<ModuleOp>(); |
| |
| // Split block at `assert` operation. |
| Block *opBlock = rewriter.getInsertionBlock(); |
| auto opPosition = rewriter.getInsertionPoint(); |
| Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); |
| |
| // Failed block: Generate IR to print the message and call `abort`. |
| Block *failureBlock = rewriter.createBlock(opBlock->getParent()); |
| auto createResult = LLVM::createPrintStrCall( |
| rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(), |
| /*addNewLine=*/false, |
| /*runtimeFunctionName=*/"puts", symbolTables); |
| if (createResult.failed()) |
| return failure(); |
| |
| if (abortOnFailedAssert) { |
| // Insert the `abort` declaration if necessary. |
| auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort"); |
| if (!abortFunc) { |
| OpBuilder::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPointToStart(module.getBody()); |
| auto abortFuncTy = LLVM::LLVMFunctionType::get(getVoidType(), {}); |
| abortFunc = LLVM::LLVMFuncOp::create(rewriter, rewriter.getUnknownLoc(), |
| "abort", abortFuncTy); |
| } |
| LLVM::CallOp::create(rewriter, loc, abortFunc, ValueRange()); |
| LLVM::UnreachableOp::create(rewriter, loc); |
| } else { |
| LLVM::BrOp::create(rewriter, loc, ValueRange(), continuationBlock); |
| } |
| |
| // Generate assertion test. |
| rewriter.setInsertionPointToEnd(opBlock); |
| rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( |
| op, adaptor.getArg(), continuationBlock, failureBlock); |
| |
| return success(); |
| } |
| |
| private: |
| /// If set to `false`, messages are printed but program execution continues. |
| /// This is useful for testing asserts. |
| bool abortOnFailedAssert = true; |
| |
| SymbolTableCollection *symbolTables = nullptr; |
| }; |
| |
| /// Helper function for converting branch ops. This function converts the |
| /// signature of the given block. If the new block signature is different from |
| /// `expectedTypes`, returns "failure". |
| static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter, |
| const TypeConverter *converter, |
| Operation *branchOp, Block *block, |
| TypeRange expectedTypes) { |
| assert(converter && "expected non-null type converter"); |
| assert(!block->isEntryBlock() && "entry blocks have no predecessors"); |
| |
| // There is nothing to do if the types already match. |
| if (block->getArgumentTypes() == expectedTypes) |
| return block; |
| |
| // Compute the new block argument types and convert the block. |
| std::optional<TypeConverter::SignatureConversion> conversion = |
| converter->convertBlockSignature(block); |
| if (!conversion) |
| return rewriter.notifyMatchFailure(branchOp, |
| "could not compute block signature"); |
| if (expectedTypes != conversion->getConvertedTypes()) |
| return rewriter.notifyMatchFailure( |
| branchOp, |
| "mismatch between adaptor operand types and computed block signature"); |
| return rewriter.applySignatureConversion(block, *conversion, converter); |
| } |
| |
| /// Convert the destination block signature (if necessary) and lower the branch |
| /// op to llvm.br. |
| struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { |
| using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| FailureOr<Block *> convertedBlock = |
| getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), |
| TypeRange(adaptor.getOperands())); |
| if (failed(convertedBlock)) |
| return failure(); |
| DictionaryAttr attrs = op->getAttrDictionary(); |
| Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>( |
| op, adaptor.getOperands(), *convertedBlock); |
| // TODO: We should not just forward all attributes like that. But there are |
| // existing Flang tests that depend on this behavior. |
| newOp->setAttrs(attrs); |
| return success(); |
| } |
| }; |
| |
| /// Convert the destination block signatures (if necessary) and lower the |
| /// branch op to llvm.cond_br. |
| struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { |
| using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(cf::CondBranchOp op, |
| typename cf::CondBranchOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| FailureOr<Block *> convertedTrueBlock = |
| getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), |
| TypeRange(adaptor.getTrueDestOperands())); |
| if (failed(convertedTrueBlock)) |
| return failure(); |
| FailureOr<Block *> convertedFalseBlock = |
| getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), |
| TypeRange(adaptor.getFalseDestOperands())); |
| if (failed(convertedFalseBlock)) |
| return failure(); |
| DictionaryAttr attrs = op->getAttrDictionary(); |
| auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( |
| op, adaptor.getCondition(), adaptor.getTrueDestOperands(), |
| adaptor.getFalseDestOperands(), op.getBranchWeightsAttr(), |
| *convertedTrueBlock, *convertedFalseBlock); |
| // TODO: We should not just forward all attributes like that. But there are |
| // existing Flang tests that depend on this behavior. |
| newOp->setAttrs(attrs); |
| return success(); |
| } |
| }; |
| |
| /// Convert the destination block signatures (if necessary) and lower the |
| /// switch op to llvm.switch. |
| struct SwitchOpLowering : public ConvertOpToLLVMPattern<cf::SwitchOp> { |
| using ConvertOpToLLVMPattern<cf::SwitchOp>::ConvertOpToLLVMPattern; |
| |
| LogicalResult |
| matchAndRewrite(cf::SwitchOp op, typename cf::SwitchOp::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| // Get or convert default block. |
| FailureOr<Block *> convertedDefaultBlock = getConvertedBlock( |
| rewriter, getTypeConverter(), op, op.getDefaultDestination(), |
| TypeRange(adaptor.getDefaultOperands())); |
| if (failed(convertedDefaultBlock)) |
| return failure(); |
| |
| // Get or convert all case blocks. |
| SmallVector<Block *> caseDestinations; |
| SmallVector<ValueRange> caseOperands = adaptor.getCaseOperands(); |
| for (auto it : llvm::enumerate(op.getCaseDestinations())) { |
| Block *b = it.value(); |
| FailureOr<Block *> convertedBlock = |
| getConvertedBlock(rewriter, getTypeConverter(), op, b, |
| TypeRange(caseOperands[it.index()])); |
| if (failed(convertedBlock)) |
| return failure(); |
| caseDestinations.push_back(*convertedBlock); |
| } |
| |
| rewriter.replaceOpWithNewOp<LLVM::SwitchOp>( |
| op, adaptor.getFlag(), *convertedDefaultBlock, |
| adaptor.getDefaultOperands(), adaptor.getCaseValuesAttr(), |
| caseDestinations, caseOperands); |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| void mlir::cf::populateControlFlowToLLVMConversionPatterns( |
| const LLVMTypeConverter &converter, RewritePatternSet &patterns) { |
| // clang-format off |
| patterns.add< |
| BranchOpLowering, |
| CondBranchOpLowering, |
| SwitchOpLowering>(converter); |
| // clang-format on |
| } |
| |
| void mlir::cf::populateAssertToLLVMConversionPattern( |
| const LLVMTypeConverter &converter, RewritePatternSet &patterns, |
| bool abortOnFailure, SymbolTableCollection *symbolTables) { |
| patterns.add<AssertOpLowering>(converter, abortOnFailure, symbolTables); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Pass Definition |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// A pass converting MLIR operations into the LLVM IR dialect. |
| struct ConvertControlFlowToLLVM |
| : public impl::ConvertControlFlowToLLVMPassBase<ConvertControlFlowToLLVM> { |
| |
| using Base::Base; |
| |
| /// Run the dialect converter on the module. |
| void runOnOperation() override { |
| MLIRContext *ctx = &getContext(); |
| LLVMConversionTarget target(*ctx); |
| // This pass lowers only CF dialect ops, but it also modifies block |
| // signatures inside other ops. These ops should be treated as legal. They |
| // are lowered by other passes. |
| target.markUnknownOpDynamicallyLegal([&](Operation *op) { |
| return op->getDialect() != |
| ctx->getLoadedDialect<cf::ControlFlowDialect>(); |
| }); |
| |
| LowerToLLVMOptions options(ctx); |
| if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) |
| options.overrideIndexBitwidth(indexBitwidth); |
| |
| LLVMTypeConverter converter(ctx, options); |
| RewritePatternSet patterns(ctx); |
| mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); |
| mlir::cf::populateAssertToLLVMConversionPattern(converter, patterns); |
| |
| if (failed(applyPartialConversion(getOperation(), target, |
| std::move(patterns)))) |
| signalPassFailure(); |
| } |
| }; |
| } // namespace |
| |
| //===----------------------------------------------------------------------===// |
| // ConvertToLLVMPatternInterface implementation |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// Implement the interface to convert MemRef to LLVM. |
| struct ControlFlowToLLVMDialectInterface |
| : public ConvertToLLVMPatternInterface { |
| using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; |
| void loadDependentDialects(MLIRContext *context) const final { |
| context->loadDialect<LLVM::LLVMDialect>(); |
| } |
| |
| /// Hook for derived dialect interface to provide conversion patterns |
| /// and mark dialect legal for the conversion target. |
| void populateConvertToLLVMConversionPatterns( |
| ConversionTarget &target, LLVMTypeConverter &typeConverter, |
| RewritePatternSet &patterns) const final { |
| mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, |
| patterns); |
| mlir::cf::populateAssertToLLVMConversionPattern(typeConverter, patterns); |
| } |
| }; |
| } // namespace |
| |
| void mlir::cf::registerConvertControlFlowToLLVMInterface( |
| DialectRegistry ®istry) { |
| registry.addExtension(+[](MLIRContext *ctx, cf::ControlFlowDialect *dialect) { |
| dialect->addInterfaces<ControlFlowToLLVMDialectInterface>(); |
| }); |
| } |