| //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements miscellaneous inlining utilities. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Transforms/InliningUtils.h" |
| |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinOps.h" |
| #include "mlir/IR/IRMapping.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/Interfaces/CallInterfaces.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/DebugLog.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include <optional> |
| |
| #define DEBUG_TYPE "inlining" |
| |
| using namespace mlir; |
| |
| /// Combine `callee` location with `caller` location to create a stack that |
| /// represents the call chain. |
| /// If `callee` location is a `CallSiteLoc`, indicating an existing stack of |
| /// locations, the `caller` location is appended to the end of it, extending |
| /// the chain. |
| /// Otherwise, a single `CallSiteLoc` is created, representing a direct call |
| /// from `caller` to `callee`. |
| static LocationAttr stackLocations(Location callee, Location caller) { |
| Location lastCallee = callee; |
| SmallVector<CallSiteLoc> calleeInliningStack; |
| while (auto nextCallSite = dyn_cast<CallSiteLoc>(lastCallee)) { |
| calleeInliningStack.push_back(nextCallSite); |
| lastCallee = nextCallSite.getCaller(); |
| } |
| |
| CallSiteLoc firstCallSite = CallSiteLoc::get(lastCallee, caller); |
| for (CallSiteLoc currentCallSite : reverse(calleeInliningStack)) |
| firstCallSite = |
| CallSiteLoc::get(currentCallSite.getCallee(), firstCallSite); |
| |
| return firstCallSite; |
| } |
| |
| /// Remap all locations reachable from the inlined blocks with CallSiteLoc |
| /// locations with the provided caller location. |
| static void |
| remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks, |
| Location callerLoc) { |
| DenseMap<Location, LocationAttr> mappedLocations; |
| auto remapLoc = [&](Location loc) { |
| auto [it, inserted] = mappedLocations.try_emplace(loc); |
| // Only query the attribute uniquer once per callsite attribute. |
| if (inserted) { |
| LocationAttr newLoc = stackLocations(loc, callerLoc); |
| it->getSecond() = newLoc; |
| } |
| return it->second; |
| }; |
| |
| AttrTypeReplacer attrReplacer; |
| attrReplacer.addReplacement( |
| [&](LocationAttr loc) -> std::pair<LocationAttr, WalkResult> { |
| return {remapLoc(loc), WalkResult::skip()}; |
| }); |
| |
| for (Block &block : inlinedBlocks) { |
| for (BlockArgument &arg : block.getArguments()) |
| if (LocationAttr newLoc = remapLoc(arg.getLoc())) |
| arg.setLoc(newLoc); |
| |
| for (Operation &op : block) |
| attrReplacer.recursivelyReplaceElementsIn(&op, /*replaceAttrs=*/false, |
| /*replaceLocs=*/true); |
| } |
| } |
| |
| static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks, |
| IRMapping &mapper) { |
| auto remapOperands = [&](Operation *op) { |
| for (auto &operand : op->getOpOperands()) |
| if (auto mappedOp = mapper.lookupOrNull(operand.get())) |
| operand.set(mappedOp); |
| }; |
| for (auto &block : inlinedBlocks) |
| block.walk(remapOperands); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // InlinerInterface |
| //===----------------------------------------------------------------------===// |
| |
| bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable, |
| bool wouldBeCloned) const { |
| if (auto *handler = getInterfaceFor(call)) |
| return handler->isLegalToInline(call, callable, wouldBeCloned); |
| return false; |
| } |
| |
| bool InlinerInterface::isLegalToInline(Region *dest, Region *src, |
| bool wouldBeCloned, |
| IRMapping &valueMapping) const { |
| if (auto *handler = getInterfaceFor(dest->getParentOp())) |
| return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping); |
| return false; |
| } |
| |
| bool InlinerInterface::isLegalToInline(Operation *op, Region *dest, |
| bool wouldBeCloned, |
| IRMapping &valueMapping) const { |
| if (auto *handler = getInterfaceFor(op)) |
| return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping); |
| return false; |
| } |
| |
| bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { |
| auto *handler = getInterfaceFor(op); |
| return handler ? handler->shouldAnalyzeRecursively(op) : true; |
| } |
| |
| /// Handle the given inlined terminator by replacing it with a new operation |
| /// as necessary. |
| void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { |
| auto *handler = getInterfaceFor(op); |
| assert(handler && "expected valid dialect handler"); |
| handler->handleTerminator(op, newDest); |
| } |
| |
| /// Handle the given inlined terminator by replacing it with a new operation |
| /// as necessary. |
| void InlinerInterface::handleTerminator(Operation *op, |
| ValueRange valuesToRepl) const { |
| auto *handler = getInterfaceFor(op); |
| assert(handler && "expected valid dialect handler"); |
| handler->handleTerminator(op, valuesToRepl); |
| } |
| |
| /// Returns true if the inliner can assume a fast path of not creating a |
| /// new block, if there is only one block. |
| bool InlinerInterface::allowSingleBlockOptimization( |
| iterator_range<Region::iterator> inlinedBlocks) const { |
| if (inlinedBlocks.empty()) { |
| return true; |
| } |
| auto *handler = getInterfaceFor(inlinedBlocks.begin()->getParentOp()); |
| assert(handler && "expected valid dialect handler"); |
| return handler->allowSingleBlockOptimization(inlinedBlocks); |
| } |
| |
| Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call, |
| Operation *callable, Value argument, |
| DictionaryAttr argumentAttrs) const { |
| auto *handler = getInterfaceFor(callable); |
| assert(handler && "expected valid dialect handler"); |
| return handler->handleArgument(builder, call, callable, argument, |
| argumentAttrs); |
| } |
| |
| Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call, |
| Operation *callable, Value result, |
| DictionaryAttr resultAttrs) const { |
| auto *handler = getInterfaceFor(callable); |
| assert(handler && "expected valid dialect handler"); |
| return handler->handleResult(builder, call, callable, result, resultAttrs); |
| } |
| |
| void InlinerInterface::processInlinedCallBlocks( |
| Operation *call, iterator_range<Region::iterator> inlinedBlocks) const { |
| auto *handler = getInterfaceFor(call); |
| assert(handler && "expected valid dialect handler"); |
| handler->processInlinedCallBlocks(call, inlinedBlocks); |
| } |
| |
| /// Utility to check that all of the operations within 'src' can be inlined. |
| static bool isLegalToInline(InlinerInterface &interface, Region *src, |
| Region *insertRegion, bool shouldCloneInlinedRegion, |
| IRMapping &valueMapping) { |
| for (auto &block : *src) { |
| for (auto &op : block) { |
| // UnrealizedConversionCastOp is inlineable but cannot implement the |
| // inliner interface due to layering constraints. |
| if (isa<UnrealizedConversionCastOp>(op)) |
| continue; |
| |
| // Check this operation. |
| if (!interface.isLegalToInline(&op, insertRegion, |
| shouldCloneInlinedRegion, valueMapping)) { |
| LDBG() << "* Illegal to inline because of op: " |
| << OpWithFlags(&op, OpPrintingFlags().skipRegions()); |
| return false; |
| } |
| // Check any nested regions. |
| if (interface.shouldAnalyzeRecursively(&op) && |
| llvm::any_of(op.getRegions(), [&](Region ®ion) { |
| return !isLegalToInline(interface, ®ion, insertRegion, |
| shouldCloneInlinedRegion, valueMapping); |
| })) |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Inline Methods |
| //===----------------------------------------------------------------------===// |
| |
| static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder, |
| CallOpInterface call, |
| CallableOpInterface callable, |
| IRMapping &mapper) { |
| // Unpack the argument attributes if there are any. |
| SmallVector<DictionaryAttr> argAttrs( |
| callable.getCallableRegion()->getNumArguments(), |
| builder.getDictionaryAttr({})); |
| if (ArrayAttr arrayAttr = callable.getArgAttrsAttr()) { |
| assert(arrayAttr.size() == argAttrs.size()); |
| for (auto [idx, attr] : llvm::enumerate(arrayAttr)) |
| argAttrs[idx] = cast<DictionaryAttr>(attr); |
| } |
| |
| // Run the argument attribute handler for the given argument and attribute. |
| for (auto [blockArg, argAttr] : |
| llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) { |
| Value newArgument = interface.handleArgument( |
| builder, call, callable, mapper.lookup(blockArg), argAttr); |
| assert(newArgument.getType() == mapper.lookup(blockArg).getType() && |
| "expected the argument type to not change"); |
| |
| // Update the mapping to point the new argument returned by the handler. |
| mapper.map(blockArg, newArgument); |
| } |
| } |
| |
| static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder, |
| CallOpInterface call, CallableOpInterface callable, |
| ValueRange results) { |
| // Unpack the result attributes if there are any. |
| SmallVector<DictionaryAttr> resAttrs(results.size(), |
| builder.getDictionaryAttr({})); |
| if (ArrayAttr arrayAttr = callable.getResAttrsAttr()) { |
| assert(arrayAttr.size() == resAttrs.size()); |
| for (auto [idx, attr] : llvm::enumerate(arrayAttr)) |
| resAttrs[idx] = cast<DictionaryAttr>(attr); |
| } |
| |
| // Run the result attribute handler for the given result and attribute. |
| for (auto [result, resAttr] : llvm::zip(results, resAttrs)) { |
| // Store the original result users before running the handler. |
| DenseSet<Operation *> resultUsers(llvm::from_range, result.getUsers()); |
| |
| Value newResult = |
| interface.handleResult(builder, call, callable, result, resAttr); |
| assert(newResult.getType() == result.getType() && |
| "expected the result type to not change"); |
| |
| // Replace the result uses except for the ones introduce by the handler. |
| result.replaceUsesWithIf(newResult, [&](OpOperand &operand) { |
| return resultUsers.count(operand.getOwner()); |
| }); |
| } |
| } |
| |
| static LogicalResult inlineRegionImpl( |
| InlinerInterface &interface, |
| function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback, |
| Region *src, Block *inlineBlock, Block::iterator inlinePoint, |
| IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, |
| std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion, |
| CallOpInterface call = {}) { |
| assert(resultsToReplace.size() == regionResultTypes.size()); |
| // We expect the region to have at least one block. |
| if (src->empty()) |
| return failure(); |
| |
| // Check that all of the region arguments have been mapped. |
| auto *srcEntryBlock = &src->front(); |
| if (llvm::any_of(srcEntryBlock->getArguments(), |
| [&](BlockArgument arg) { return !mapper.contains(arg); })) |
| return failure(); |
| |
| // Check that the operations within the source region are valid to inline. |
| Region *insertRegion = inlineBlock->getParent(); |
| if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion, |
| mapper) || |
| !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion, |
| mapper)) |
| return failure(); |
| |
| // Run the argument attribute handler before inlining the callable region. |
| OpBuilder builder(inlineBlock, inlinePoint); |
| auto callable = dyn_cast<CallableOpInterface>(src->getParentOp()); |
| if (call && callable) |
| handleArgumentImpl(interface, builder, call, callable, mapper); |
| |
| // Clone the callee's source into the caller. |
| Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint); |
| cloneCallback(builder, src, inlineBlock, postInsertBlock, mapper, |
| shouldCloneInlinedRegion); |
| |
| // Get the range of newly inserted blocks. |
| auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()), |
| postInsertBlock->getIterator()); |
| Block *firstNewBlock = &*newBlocks.begin(); |
| |
| // Remap the locations of the inlined operations if a valid source location |
| // was provided. |
| if (inlineLoc && !llvm::isa<UnknownLoc>(*inlineLoc)) |
| remapInlinedLocations(newBlocks, *inlineLoc); |
| |
| // If the blocks were moved in-place, make sure to remap any necessary |
| // operands. |
| if (!shouldCloneInlinedRegion) |
| remapInlinedOperands(newBlocks, mapper); |
| |
| // Process the newly inlined blocks. |
| if (call) |
| interface.processInlinedCallBlocks(call, newBlocks); |
| interface.processInlinedBlocks(newBlocks); |
| |
| bool singleBlockFastPath = interface.allowSingleBlockOptimization(newBlocks); |
| |
| // Handle the case where only a single block was inlined. |
| if (singleBlockFastPath && llvm::hasSingleElement(newBlocks)) { |
| // Run the result attribute handler on the terminator operands. |
| Operation *firstBlockTerminator = firstNewBlock->getTerminator(); |
| builder.setInsertionPoint(firstBlockTerminator); |
| if (call && callable) |
| handleResultImpl(interface, builder, call, callable, |
| firstBlockTerminator->getOperands()); |
| |
| // Have the interface handle the terminator of this block. |
| interface.handleTerminator(firstBlockTerminator, resultsToReplace); |
| firstBlockTerminator->erase(); |
| |
| // Merge the post insert block into the cloned entry block. |
| firstNewBlock->getOperations().splice(firstNewBlock->end(), |
| postInsertBlock->getOperations()); |
| postInsertBlock->erase(); |
| } else { |
| // Otherwise, there were multiple blocks inlined. Add arguments to the post |
| // insertion block to represent the results to replace. |
| for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) { |
| resultToRepl.value().replaceAllUsesWith( |
| postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()], |
| resultToRepl.value().getLoc())); |
| } |
| |
| // Run the result attribute handler on the post insertion block arguments. |
| builder.setInsertionPointToStart(postInsertBlock); |
| if (call && callable) |
| handleResultImpl(interface, builder, call, callable, |
| postInsertBlock->getArguments()); |
| |
| /// Handle the terminators for each of the new blocks. |
| for (auto &newBlock : newBlocks) |
| interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); |
| } |
| |
| // Splice the instructions of the inlined entry block into the insert block. |
| inlineBlock->getOperations().splice(inlineBlock->end(), |
| firstNewBlock->getOperations()); |
| firstNewBlock->erase(); |
| return success(); |
| } |
| |
| static LogicalResult inlineRegionImpl( |
| InlinerInterface &interface, |
| function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback, |
| Region *src, Block *inlineBlock, Block::iterator inlinePoint, |
| ValueRange inlinedOperands, ValueRange resultsToReplace, |
| std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion, |
| CallOpInterface call = {}) { |
| // We expect the region to have at least one block. |
| if (src->empty()) |
| return failure(); |
| |
| auto *entryBlock = &src->front(); |
| if (inlinedOperands.size() != entryBlock->getNumArguments()) |
| return failure(); |
| |
| // Map the provided call operands to the arguments of the region. |
| IRMapping mapper; |
| for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { |
| // Verify that the types of the provided values match the function argument |
| // types. |
| BlockArgument regionArg = entryBlock->getArgument(i); |
| if (inlinedOperands[i].getType() != regionArg.getType()) |
| return failure(); |
| mapper.map(regionArg, inlinedOperands[i]); |
| } |
| |
| // Call into the main region inliner function. |
| return inlineRegionImpl(interface, cloneCallback, src, inlineBlock, |
| inlinePoint, mapper, resultsToReplace, |
| resultsToReplace.getTypes(), inlineLoc, |
| shouldCloneInlinedRegion, call); |
| } |
| |
| LogicalResult mlir::inlineRegion( |
| InlinerInterface &interface, |
| function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback, |
| Region *src, Operation *inlinePoint, IRMapping &mapper, |
| ValueRange resultsToReplace, TypeRange regionResultTypes, |
| std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) { |
| return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(), |
| ++inlinePoint->getIterator(), mapper, resultsToReplace, |
| regionResultTypes, inlineLoc, shouldCloneInlinedRegion); |
| } |
| |
| LogicalResult mlir::inlineRegion( |
| InlinerInterface &interface, |
| function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback, |
| Region *src, Block *inlineBlock, Block::iterator inlinePoint, |
| IRMapping &mapper, ValueRange resultsToReplace, TypeRange regionResultTypes, |
| std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) { |
| return inlineRegionImpl( |
| interface, cloneCallback, src, inlineBlock, inlinePoint, mapper, |
| resultsToReplace, regionResultTypes, inlineLoc, shouldCloneInlinedRegion); |
| } |
| |
| LogicalResult mlir::inlineRegion( |
| InlinerInterface &interface, |
| function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback, |
| Region *src, Operation *inlinePoint, ValueRange inlinedOperands, |
| ValueRange resultsToReplace, std::optional<Location> inlineLoc, |
| bool shouldCloneInlinedRegion) { |
| return inlineRegion(interface, cloneCallback, src, inlinePoint->getBlock(), |
| ++inlinePoint->getIterator(), inlinedOperands, |
| resultsToReplace, inlineLoc, shouldCloneInlinedRegion); |
| } |
| |
| LogicalResult mlir::inlineRegion( |
| InlinerInterface &interface, |
| function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback, |
| Region *src, Block *inlineBlock, Block::iterator inlinePoint, |
| ValueRange inlinedOperands, ValueRange resultsToReplace, |
| std::optional<Location> inlineLoc, bool shouldCloneInlinedRegion) { |
| return inlineRegionImpl(interface, cloneCallback, src, inlineBlock, |
| inlinePoint, inlinedOperands, resultsToReplace, |
| inlineLoc, shouldCloneInlinedRegion); |
| } |
| |
| /// Utility function used to generate a cast operation from the given interface, |
| /// or return nullptr if a cast could not be generated. |
| static Value materializeConversion(const DialectInlinerInterface *interface, |
| SmallVectorImpl<Operation *> &castOps, |
| OpBuilder &castBuilder, Value arg, Type type, |
| Location conversionLoc) { |
| if (!interface) |
| return nullptr; |
| |
| // Check to see if the interface for the call can materialize a conversion. |
| Operation *castOp = interface->materializeCallConversion(castBuilder, arg, |
| type, conversionLoc); |
| if (!castOp) |
| return nullptr; |
| castOps.push_back(castOp); |
| |
| // Ensure that the generated cast is correct. |
| assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && |
| castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); |
| return castOp->getResult(0); |
| } |
| |
| /// This function inlines a given region, 'src', of a callable operation, |
| /// 'callable', into the location defined by the given call operation. This |
| /// function returns failure if inlining is not possible, success otherwise. On |
| /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' |
| /// corresponds to whether the source region should be cloned into the 'call' or |
| /// spliced directly. |
| LogicalResult mlir::inlineCall( |
| InlinerInterface &interface, |
| function_ref<InlinerInterface::CloneCallbackSigTy> cloneCallback, |
| CallOpInterface call, CallableOpInterface callable, Region *src, |
| bool shouldCloneInlinedRegion) { |
| // We expect the region to have at least one block. |
| if (src->empty()) |
| return failure(); |
| auto *entryBlock = &src->front(); |
| ArrayRef<Type> callableResultTypes = callable.getResultTypes(); |
| |
| // Make sure that the number of arguments and results matchup between the call |
| // and the region. |
| SmallVector<Value, 8> callOperands(call.getArgOperands()); |
| SmallVector<Value, 8> callResults(call->getResults()); |
| if (callOperands.size() != entryBlock->getNumArguments() || |
| callResults.size() != callableResultTypes.size()) |
| return failure(); |
| |
| // A set of cast operations generated to matchup the signature of the region |
| // with the signature of the call. |
| SmallVector<Operation *, 4> castOps; |
| castOps.reserve(callOperands.size() + callResults.size()); |
| |
| // Functor used to cleanup generated state on failure. |
| auto cleanupState = [&] { |
| for (auto *op : castOps) { |
| op->getResult(0).replaceAllUsesWith(op->getOperand(0)); |
| op->erase(); |
| } |
| return failure(); |
| }; |
| |
| // Builder used for any conversion operations that need to be materialized. |
| OpBuilder castBuilder(call); |
| Location castLoc = call.getLoc(); |
| const auto *callInterface = interface.getInterfaceFor(call->getDialect()); |
| |
| // Map the provided call operands to the arguments of the region. |
| IRMapping mapper; |
| for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { |
| BlockArgument regionArg = entryBlock->getArgument(i); |
| Value operand = callOperands[i]; |
| |
| // If the call operand doesn't match the expected region argument, try to |
| // generate a cast. |
| Type regionArgType = regionArg.getType(); |
| if (operand.getType() != regionArgType) { |
| if (!(operand = materializeConversion(callInterface, castOps, castBuilder, |
| operand, regionArgType, castLoc))) |
| return cleanupState(); |
| } |
| mapper.map(regionArg, operand); |
| } |
| |
| // Ensure that the resultant values of the call match the callable. |
| castBuilder.setInsertionPointAfter(call); |
| for (unsigned i = 0, e = callResults.size(); i != e; ++i) { |
| Value callResult = callResults[i]; |
| if (callResult.getType() == callableResultTypes[i]) |
| continue; |
| |
| // Generate a conversion that will produce the original type, so that the IR |
| // is still valid after the original call gets replaced. |
| Value castResult = |
| materializeConversion(callInterface, castOps, castBuilder, callResult, |
| callResult.getType(), castLoc); |
| if (!castResult) |
| return cleanupState(); |
| callResult.replaceAllUsesWith(castResult); |
| castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult); |
| } |
| |
| // Check that it is legal to inline the callable into the call. |
| if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion)) |
| return cleanupState(); |
| |
| // Attempt to inline the call. |
| if (failed(inlineRegionImpl(interface, cloneCallback, src, call->getBlock(), |
| ++call->getIterator(), mapper, callResults, |
| callableResultTypes, call.getLoc(), |
| shouldCloneInlinedRegion, call))) |
| return cleanupState(); |
| return success(); |
| } |