| //===- LoopInvariantCodeMotionUtils.cpp - LICM Utils ------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file contains the implementation of the core LICM algorithm. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" |
| |
| #include "mlir/IR/Operation.h" |
| #include "mlir/IR/OperationSupport.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Interfaces/LoopLikeInterface.h" |
| #include "mlir/Interfaces/SideEffectInterfaces.h" |
| #include "mlir/Interfaces/SubsetOpInterface.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/DebugLog.h" |
| #include <queue> |
| |
| #define DEBUG_TYPE "licm" |
| |
| using namespace mlir; |
| |
| /// Checks whether the given op can be hoisted by checking that |
| /// - the op and none of its contained operations depend on values inside of the |
| /// loop (by means of calling definedOutside). |
| /// - the op has no side-effects. |
| static bool canBeHoisted(Operation *op, |
| function_ref<bool(OpOperand &)> condition) { |
| // Do not move terminators. |
| if (op->hasTrait<OpTrait::IsTerminator>()) |
| return false; |
| |
| // Walk the nested operations and check that all used values are either |
| // defined outside of the loop or in a nested region, but not at the level of |
| // the loop body. |
| auto walkFn = [&](Operation *child) { |
| for (OpOperand &operand : child->getOpOperands()) { |
| // Ignore values defined in a nested region. |
| if (op->isAncestor(operand.get().getParentRegion()->getParentOp())) |
| continue; |
| if (!condition(operand)) |
| return WalkResult::interrupt(); |
| } |
| return WalkResult::advance(); |
| }; |
| return !op->walk(walkFn).wasInterrupted(); |
| } |
| |
| static bool canBeHoisted(Operation *op, |
| function_ref<bool(Value)> definedOutside) { |
| return canBeHoisted( |
| op, [&](OpOperand &operand) { return definedOutside(operand.get()); }); |
| } |
| |
| size_t mlir::moveLoopInvariantCode( |
| ArrayRef<Region *> regions, |
| function_ref<bool(Value, Region *)> isDefinedOutsideRegion, |
| function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion, |
| function_ref<void(Operation *, Region *)> moveOutOfRegion) { |
| size_t numMoved = 0; |
| |
| for (Region *region : regions) { |
| LDBG() << "Original loop:\n" << *region->getParentOp(); |
| |
| std::queue<Operation *> worklist; |
| // Add top-level operations in the loop body to the worklist. |
| for (Operation &op : region->getOps()) |
| worklist.push(&op); |
| |
| auto definedOutside = [&](Value value) { |
| return isDefinedOutsideRegion(value, region); |
| }; |
| |
| while (!worklist.empty()) { |
| Operation *op = worklist.front(); |
| worklist.pop(); |
| // Skip ops that have already been moved. Check if the op can be hoisted. |
| if (op->getParentRegion() != region) |
| continue; |
| |
| LDBG() << "Checking op: " |
| << OpWithFlags(op, OpPrintingFlags().skipRegions()); |
| if (!shouldMoveOutOfRegion(op, region) || |
| !canBeHoisted(op, definedOutside)) |
| continue; |
| |
| LDBG() << "Moving loop-invariant op: " << *op; |
| moveOutOfRegion(op, region); |
| ++numMoved; |
| |
| // Since the op has been moved, we need to check its users within the |
| // top-level of the loop body. |
| for (Operation *user : op->getUsers()) |
| if (user->getParentRegion() == region) |
| worklist.push(user); |
| } |
| } |
| |
| return numMoved; |
| } |
| |
| size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) { |
| return moveLoopInvariantCode( |
| loopLike.getLoopRegions(), |
| [&](Value value, Region *) { |
| return loopLike.isDefinedOutsideOfLoop(value); |
| }, |
| [&](Operation *op, Region *) { |
| return isMemoryEffectFree(op) && isSpeculatable(op); |
| }, |
| [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); }); |
| } |
| |
| namespace { |
| /// Helper data structure that keeps track of equivalent/disjoint subset ops. |
| class MatchingSubsets { |
| public: |
| /// Insert a subset op. |
| void insert(SubsetOpInterface op, bool collectHoistableOps = true) { |
| allSubsetOps.push_back(op); |
| if (!collectHoistableOps) |
| return; |
| if (auto extractionOp = |
| dyn_cast<SubsetExtractionOpInterface>(op.getOperation())) |
| insertExtractionOp(extractionOp); |
| if (auto insertionOp = |
| dyn_cast<SubsetInsertionOpInterface>(op.getOperation())) |
| insertInsertionOp(insertionOp); |
| } |
| |
| /// Return a range of matching extraction-insertion subset ops. If there is no |
| /// matching extraction/insertion op, the respective value is empty. Ops are |
| /// skipped if there are other subset ops that are not guaranteed to operate |
| /// on disjoint subsets. |
| auto getHoistableSubsetOps() { |
| return llvm::make_filter_range( |
| llvm::zip(extractions, insertions), [&](auto pair) { |
| auto [extractionOp, insertionOp] = pair; |
| // Hoist only if the extracted and inserted values have the same type. |
| if (extractionOp && insertionOp && |
| extractionOp->getResult(0).getType() != |
| insertionOp.getSourceOperand().get().getType()) |
| return false; |
| // Hoist only if there are no conflicting subset ops. |
| return allDisjoint(extractionOp, insertionOp); |
| }); |
| } |
| |
| /// Populate subset ops starting from the given region iter_arg. Return |
| /// "failure" if non-subset ops are found along the path to the loop yielding |
| /// op or if there is no single path to the tied yielded operand. If |
| /// `collectHoistableOps` is set to "false", subset ops are gathered |
| /// throughout the traversal, but not enumerated by `getHoistableSubsetOps`. |
| LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike, |
| BlockArgument iterArg, |
| bool collectHoistableOps = true); |
| |
| private: |
| /// Helper function for equivalence of tensor values. Since only insertion |
| /// subset ops (that are also destination style ops) are followed when |
| /// traversing the SSA use-def chain, all tensor values are equivalent. |
| static bool isEquivalent(Value v1, Value v2) { return true; } |
| |
| /// Return "true" if the subsets of the given extraction and insertion ops |
| /// are operating disjoint from the subsets that all other known subset ops |
| /// are operating on. |
| bool allDisjoint(SubsetExtractionOpInterface extractionOp, |
| SubsetInsertionOpInterface insertionOp) const { |
| for (SubsetOpInterface other : allSubsetOps) { |
| if (other == extractionOp || other == insertionOp) |
| continue; |
| if (extractionOp && |
| !other.operatesOnDisjointSubset(extractionOp, isEquivalent)) |
| return false; |
| if (insertionOp && |
| !other.operatesOnDisjointSubset(insertionOp, isEquivalent)) |
| return false; |
| } |
| return true; |
| } |
| |
| /// Insert a subset extraction op. If the subset is equivalent to an existing |
| /// subset insertion op, pair them up. (If there is already a paired up subset |
| /// extraction op, overwrite the subset extraction op.) |
| void insertExtractionOp(SubsetExtractionOpInterface extractionOp) { |
| for (auto it : llvm::enumerate(insertions)) { |
| if (!it.value()) |
| continue; |
| auto other = cast<SubsetOpInterface>(it.value().getOperation()); |
| if (other.operatesOnEquivalentSubset(extractionOp, isEquivalent)) { |
| extractions[it.index()] = extractionOp; |
| return; |
| } |
| } |
| // There is no known equivalent insertion op. Create a new entry. |
| extractions.push_back(extractionOp); |
| insertions.push_back({}); |
| } |
| |
| /// Insert a subset insertion op. If the subset is equivalent to an existing |
| /// subset extraction op, pair them up. (If there is already a paired up |
| /// subset insertion op, overwrite the subset insertion op.) |
| void insertInsertionOp(SubsetInsertionOpInterface insertionOp) { |
| for (auto it : llvm::enumerate(extractions)) { |
| if (!it.value()) |
| continue; |
| auto other = cast<SubsetOpInterface>(it.value().getOperation()); |
| if (other.operatesOnEquivalentSubset(insertionOp, isEquivalent)) { |
| insertions[it.index()] = insertionOp; |
| return; |
| } |
| } |
| // There is no known equivalent extraction op. Create a new entry. |
| extractions.push_back({}); |
| insertions.push_back(insertionOp); |
| } |
| |
| SmallVector<SubsetExtractionOpInterface> extractions; |
| SmallVector<SubsetInsertionOpInterface> insertions; |
| SmallVector<SubsetOpInterface> allSubsetOps; |
| }; |
| } // namespace |
| |
| /// If the given value has a single use by an op that is a terminator, return |
| /// that use. Otherwise, return nullptr. |
| static OpOperand *getSingleTerminatorUse(Value value) { |
| if (!value.hasOneUse()) |
| return nullptr; |
| OpOperand &use = *value.getUses().begin(); |
| if (use.getOwner()->hasTrait<OpTrait::IsTerminator>()) |
| return &use; |
| return nullptr; |
| } |
| |
| LogicalResult |
| MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike, |
| BlockArgument iterArg, |
| bool collectHoistableOps) { |
| assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg"); |
| Value value = iterArg; |
| |
| // Traverse use-def chain. Subset ops can be hoisted only if all ops along the |
| // use-def chain starting from the region iter_arg are subset extraction or |
| // subset insertion ops. The chain must terminate at the corresponding yield |
| // operand (e.g., no swapping of iter_args). |
| OpOperand *yieldedOperand = nullptr; |
| // Iterate until the single use of the current SSA value is a terminator, |
| // which is expected to be the yielding operation of the loop. |
| while (!(yieldedOperand = getSingleTerminatorUse(value))) { |
| Value nextValue = {}; |
| |
| for (OpOperand &use : value.getUses()) { |
| if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) { |
| // Subset ops in nested loops are collected to check if there are only |
| // disjoint subset ops, but such subset ops are not subject to hoisting. |
| // To hoist subset ops from nested loops, the hoisting transformation |
| // should be run on the nested loop. |
| auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use); |
| if (!nestedIterArg) |
| return failure(); |
| // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA |
| // use-def chain starting at `nestedIterArg` and terminating in the |
| // tied, yielding operand. |
| if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg, |
| /*collectHoistableOps=*/false))) |
| return failure(); |
| nextValue = nestedLoop.getTiedLoopResult(&use); |
| continue; |
| } |
| |
| auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner()); |
| if (!subsetOp) |
| return failure(); |
| insert(subsetOp); |
| |
| if (auto insertionOp = |
| dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) { |
| // Current implementation expects that the insertionOp implement |
| // the DestinationStyleOpInterface and with pure tensor semantics |
| // as well. Abort if that is not the case. |
| auto dstOp = dyn_cast<DestinationStyleOpInterface>(use.getOwner()); |
| if (!dstOp || !dstOp.hasPureTensorSemantics()) |
| return failure(); |
| |
| // The value must be used as a destination. (In case of a source, the |
| // entire tensor would be read, which would prevent any hoisting.) |
| if (&use != &insertionOp.getDestinationOperand()) |
| return failure(); |
| // There must be a single use-def chain from the region iter_arg to the |
| // terminator. I.e., only one insertion op. Branches are not supported. |
| if (nextValue) |
| return failure(); |
| nextValue = insertionOp.getUpdatedDestination(); |
| } |
| } |
| |
| // Nothing can be hoisted if the chain does not continue with loop yielding |
| // op or a subset insertion op. |
| if (!nextValue) |
| return failure(); |
| value = nextValue; |
| } |
| |
| // Hoist only if the SSA use-def chain ends in the yielding terminator of the |
| // loop and the yielded value is the `idx`-th operand. (I.e., there is no |
| // swapping yield.) |
| if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand) |
| return failure(); |
| |
| return success(); |
| } |
| |
| /// Hoist all subset ops that operate on the idx-th region iter_arg of the given |
| /// loop-like op and index into loop-invariant subset locations. Return the |
| /// newly created loop op (that has extra iter_args) or the original loop op if |
| /// nothing was hoisted. |
| static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, |
| LoopLikeOpInterface loopLike, |
| BlockArgument iterArg) { |
| assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg"); |
| BlockArgument *it = llvm::find(loopLike.getRegionIterArgs(), iterArg); |
| int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it); |
| MatchingSubsets subsets; |
| if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg))) |
| return loopLike; |
| |
| // Hoist all matching extraction-insertion pairs one-by-one. |
| for (auto it : subsets.getHoistableSubsetOps()) { |
| auto extractionOp = std::get<0>(it); |
| auto insertionOp = std::get<1>(it); |
| |
| // Ops cannot be hoisted if they depend on loop-variant values. |
| if (extractionOp) { |
| if (!canBeHoisted(extractionOp, [&](OpOperand &operand) { |
| return loopLike.isDefinedOutsideOfLoop(operand.get()) || |
| &operand == &extractionOp.getSourceOperand(); |
| })) |
| extractionOp = {}; |
| } |
| if (insertionOp) { |
| if (!canBeHoisted(insertionOp, [&](OpOperand &operand) { |
| return loopLike.isDefinedOutsideOfLoop(operand.get()) || |
| &operand == &insertionOp.getSourceOperand() || |
| &operand == &insertionOp.getDestinationOperand(); |
| })) |
| insertionOp = {}; |
| } |
| |
| // Only hoist extraction-insertion pairs for now. Standalone extractions/ |
| // insertions that are loop-invariant could be hoisted, but there may be |
| // easier ways to canonicalize the IR. |
| if (extractionOp && insertionOp) { |
| // Create a new loop with an additional iter_arg. |
| NewYieldValuesFn newYieldValuesFn = |
| [&](OpBuilder &b, Location loc, |
| ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> { |
| return {insertionOp.getSourceOperand().get()}; |
| }; |
| FailureOr<LoopLikeOpInterface> newLoop = |
| loopLike.replaceWithAdditionalYields( |
| rewriter, extractionOp.getResult(), |
| /*replaceInitOperandUsesInLoop=*/true, newYieldValuesFn); |
| if (failed(newLoop)) |
| return loopLike; |
| loopLike = *newLoop; |
| |
| // Hoist the extraction/insertion ops. |
| iterArg = loopLike.getRegionIterArgs()[iterArgIdx]; |
| OpResult loopResult = loopLike.getTiedLoopResult(iterArg); |
| OpResult newLoopResult = loopLike.getLoopResults()->back(); |
| rewriter.moveOpBefore(extractionOp, loopLike); |
| rewriter.moveOpAfter(insertionOp, loopLike); |
| rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(), |
| insertionOp.getDestinationOperand().get()); |
| extractionOp.getSourceOperand().set( |
| loopLike.getTiedLoopInit(iterArg)->get()); |
| rewriter.replaceAllUsesWith(loopResult, |
| insertionOp.getUpdatedDestination()); |
| insertionOp.getSourceOperand().set(newLoopResult); |
| insertionOp.getDestinationOperand().set(loopResult); |
| } |
| } |
| |
| return loopLike; |
| } |
| |
| LoopLikeOpInterface |
| mlir::hoistLoopInvariantSubsets(RewriterBase &rewriter, |
| LoopLikeOpInterface loopLike) { |
| // Note: As subset ops are getting hoisted, the number of region iter_args |
| // increases. This can enable further hoisting opportunities on the new |
| // iter_args. |
| for (int64_t i = 0; |
| i < static_cast<int64_t>(loopLike.getRegionIterArgs().size()); ++i) { |
| loopLike = hoistSubsetAtIterArg(rewriter, loopLike, |
| loopLike.getRegionIterArgs()[i]); |
| } |
| return loopLike; |
| } |