| //===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements a commutativity utility pattern and a function to |
| // populate this pattern. The function is intended to be used inside passes to |
| // simplify the matching of commutative operations by fixing the order of their |
| // operands. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Transforms/CommutativityUtils.h" |
| |
| #include <queue> |
| |
| using namespace mlir; |
| |
/// Classifies an "ancestor" of a value, i.e. an op or a block argument that
/// appears in the value's backward slice.
///
/// The numeric order of the enumerators is significant: `AncestorKey` compares
/// ancestors by this value first, so block arguments sort before non-constant
/// ops, which in turn sort before constant-like ops.
enum AncestorType {
  /// The ancestor is a block argument.
  BLOCK_ARGUMENT = 0,

  /// The ancestor is an op that is not constant-like.
  NON_CONSTANT_OP = 1,

  /// The ancestor is a constant-like op.
  CONSTANT_OP = 2
};
| |
| /// Stores the "key" associated with an ancestor. |
| struct AncestorKey { |
| /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on |
| /// the ancestor. |
| AncestorType type; |
| |
| /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or |
| /// `CONSTANT_OP`. Else, holds "". |
| StringRef opName; |
| |
| /// Constructor for `AncestorKey`. |
| AncestorKey(Operation *op) { |
| if (!op) { |
| type = BLOCK_ARGUMENT; |
| } else { |
| type = |
| op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP; |
| opName = op->getName().getStringRef(); |
| } |
| } |
| |
| /// Overloaded operator `<` for `AncestorKey`. |
| /// |
| /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those |
| /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in |
| /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller |
| /// ones are the ones with smaller op names (lexicographically). |
| /// |
| /// TODO: Include other information like attributes, value type, etc., to |
| /// enhance this comparison. For example, currently this comparison doesn't |
| /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and |
| /// `addi (in i64)`. Such an enhancement should only be done if the need |
| /// arises. |
| bool operator<(const AncestorKey &key) const { |
| return std::tie(type, opName) < std::tie(key.type, key.opName); |
| } |
| }; |
| |
| /// Stores a commutative operand along with its BFS traversal information. |
| struct CommutativeOperand { |
| /// Stores the operand. |
| Value operand; |
| |
| /// Stores the queue of ancestors of the operand's BFS traversal at a |
| /// particular point in time. |
| std::queue<Operation *> ancestorQueue; |
| |
| /// Stores the list of ancestors that have been visited by the BFS traversal |
| /// at a particular point in time. |
| DenseSet<Operation *> visitedAncestors; |
| |
| /// Stores the operand's "key". This "key" is defined as a list of the |
| /// "AncestorKeys" associated with the ancestors of this operand, in a |
| /// breadth-first order. |
| /// |
| /// So, if an operand, say `A`, was produced as follows: |
| /// |
| /// `<block argument>` `<block argument>` |
| /// \ / |
| /// \ / |
| /// `arith.subi` `arith.constant` |
| /// \ / |
| /// `arith.addi` |
| /// | |
| /// returns `A` |
| /// |
| /// Then, the ancestors of `A`, in the breadth-first order are: |
| /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and |
| /// `<block argument>`. |
| /// |
| /// Thus, the "key" associated with operand `A` is: |
| /// { |
| /// {type: `NON_CONSTANT_OP`, opName: "arith.addi"}, |
| /// {type: `NON_CONSTANT_OP`, opName: "arith.subi"}, |
| /// {type: `CONSTANT_OP`, opName: "arith.constant"}, |
| /// {type: `BLOCK_ARGUMENT`, opName: ""}, |
| /// {type: `BLOCK_ARGUMENT`, opName: ""} |
| /// } |
| SmallVector<AncestorKey, 4> key; |
| |
| /// Push an ancestor into the operand's BFS information structure. This |
| /// entails it being pushed into the queue (always) and inserted into the |
| /// "visited ancestors" list (iff it is an op rather than a block argument). |
| void pushAncestor(Operation *op) { |
| ancestorQueue.push(op); |
| if (op) |
| visitedAncestors.insert(op); |
| } |
| |
| /// Refresh the key. |
| /// |
| /// Refreshing a key entails making it up-to-date with the operand's BFS |
| /// traversal that has happened till that point in time, i.e, appending the |
| /// existing key with the front ancestor's "AncestorKey". Note that a key |
| /// directly reflects the BFS and thus needs to be refreshed during the |
| /// progression of the traversal. |
| void refreshKey() { |
| if (ancestorQueue.empty()) |
| return; |
| |
| Operation *frontAncestor = ancestorQueue.front(); |
| AncestorKey frontAncestorKey(frontAncestor); |
| key.push_back(frontAncestorKey); |
| } |
| |
| /// Pop the front ancestor, if any, from the queue and then push its adjacent |
| /// unvisited ancestors, if any, to the queue (this is the main body of the |
| /// BFS algorithm). |
| void popFrontAndPushAdjacentUnvisitedAncestors() { |
| if (ancestorQueue.empty()) |
| return; |
| Operation *frontAncestor = ancestorQueue.front(); |
| ancestorQueue.pop(); |
| if (!frontAncestor) |
| return; |
| for (Value operand : frontAncestor->getOperands()) { |
| Operation *operandDefOp = operand.getDefiningOp(); |
| if (!operandDefOp || !visitedAncestors.contains(operandDefOp)) |
| pushAncestor(operandDefOp); |
| } |
| } |
| }; |
| |
| /// Sorts the operands of `op` in ascending order of the "key" associated with |
| /// each operand iff `op` is commutative. This is a stable sort. |
| /// |
| /// After the application of this pattern, since the commutative operands now |
| /// have a deterministic order in which they occur in an op, the matching of |
| /// large DAGs becomes much simpler, i.e., requires much less number of checks |
| /// to be written by a user in her/his pattern matching function. |
| /// |
| /// Some examples of such a sorting: |
| /// |
| /// Assume that the sorting is being applied to `foo.commutative`, which is a |
| /// commutative op. |
| /// |
| /// Example 1: |
| /// |
| /// %1 = foo.const 0 |
| /// %2 = foo.mul <block argument>, <block argument> |
| /// %3 = foo.commutative %1, %2 |
| /// |
| /// Here, |
| /// 1. The key associated with %1 is: |
| /// `{ |
| /// {CONSTANT_OP, "foo.const"} |
| /// }` |
| /// 2. The key associated with %2 is: |
| /// `{ |
| /// {NON_CONSTANT_OP, "foo.mul"}, |
| /// {BLOCK_ARGUMENT, ""}, |
| /// {BLOCK_ARGUMENT, ""} |
| /// }` |
| /// |
| /// The key of %2 < the key of %1 |
| /// Thus, the sorted `foo.commutative` is: |
| /// %3 = foo.commutative %2, %1 |
| /// |
| /// Example 2: |
| /// |
| /// %1 = foo.const 0 |
| /// %2 = foo.mul <block argument>, <block argument> |
| /// %3 = foo.mul %2, %1 |
| /// %4 = foo.add %2, %1 |
| /// %5 = foo.commutative %1, %2, %3, %4 |
| /// |
| /// Here, |
| /// 1. The key associated with %1 is: |
| /// `{ |
| /// {CONSTANT_OP, "foo.const"} |
| /// }` |
| /// 2. The key associated with %2 is: |
| /// `{ |
| /// {NON_CONSTANT_OP, "foo.mul"}, |
| /// {BLOCK_ARGUMENT, ""} |
| /// }` |
| /// 3. The key associated with %3 is: |
| /// `{ |
| /// {NON_CONSTANT_OP, "foo.mul"}, |
| /// {NON_CONSTANT_OP, "foo.mul"}, |
| /// {CONSTANT_OP, "foo.const"}, |
| /// {BLOCK_ARGUMENT, ""}, |
| /// {BLOCK_ARGUMENT, ""} |
| /// }` |
| /// 4. The key associated with %4 is: |
| /// `{ |
| /// {NON_CONSTANT_OP, "foo.add"}, |
| /// {NON_CONSTANT_OP, "foo.mul"}, |
| /// {CONSTANT_OP, "foo.const"}, |
| /// {BLOCK_ARGUMENT, ""}, |
| /// {BLOCK_ARGUMENT, ""} |
| /// }` |
| /// |
| /// Thus, the sorted `foo.commutative` is: |
| /// %5 = foo.commutative %4, %3, %2, %1 |
| class SortCommutativeOperands : public RewritePattern { |
| public: |
| SortCommutativeOperands(MLIRContext *context) |
| : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {} |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| // Custom comparator for two commutative operands, which returns true iff |
| // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`, |
| // i.e., |
| // 1. In the first unequal pair of corresponding AncestorKeys, the |
| // AncestorKey in `constCommOperandA` is smaller, or, |
| // 2. Both the AncestorKeys in every pair are the same and the size of |
| // `constCommOperandA`'s "key" is smaller. |
| auto commutativeOperandComparator = |
| [](const std::unique_ptr<CommutativeOperand> &constCommOperandA, |
| const std::unique_ptr<CommutativeOperand> &constCommOperandB) { |
| if (constCommOperandA->operand == constCommOperandB->operand) |
| return false; |
| |
| auto &commOperandA = |
| const_cast<std::unique_ptr<CommutativeOperand> &>( |
| constCommOperandA); |
| auto &commOperandB = |
| const_cast<std::unique_ptr<CommutativeOperand> &>( |
| constCommOperandB); |
| |
| // Iteratively perform the BFS's of both operands until an order among |
| // them can be determined. |
| unsigned keyIndex = 0; |
| while (true) { |
| if (commOperandA->key.size() <= keyIndex) { |
| if (commOperandA->ancestorQueue.empty()) |
| return true; |
| commOperandA->popFrontAndPushAdjacentUnvisitedAncestors(); |
| commOperandA->refreshKey(); |
| } |
| if (commOperandB->key.size() <= keyIndex) { |
| if (commOperandB->ancestorQueue.empty()) |
| return false; |
| commOperandB->popFrontAndPushAdjacentUnvisitedAncestors(); |
| commOperandB->refreshKey(); |
| } |
| if (commOperandA->ancestorQueue.empty() || |
| commOperandB->ancestorQueue.empty()) |
| return commOperandA->key.size() < commOperandB->key.size(); |
| if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex]) |
| return true; |
| if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex]) |
| return false; |
| keyIndex++; |
| } |
| }; |
| |
| // If `op` is not commutative, do nothing. |
| if (!op->hasTrait<OpTrait::IsCommutative>()) |
| return failure(); |
| |
| // Populate the list of commutative operands. |
| SmallVector<Value, 2> operands = op->getOperands(); |
| SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands; |
| for (Value operand : operands) { |
| std::unique_ptr<CommutativeOperand> commOperand = |
| std::make_unique<CommutativeOperand>(); |
| commOperand->operand = operand; |
| commOperand->pushAncestor(operand.getDefiningOp()); |
| commOperand->refreshKey(); |
| commOperands.push_back(std::move(commOperand)); |
| } |
| |
| // Sort the operands. |
| llvm::stable_sort(commOperands, commutativeOperandComparator); |
| SmallVector<Value, 2> sortedOperands; |
| for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands) |
| sortedOperands.push_back(commOperand->operand); |
| if (sortedOperands == operands) |
| return failure(); |
| rewriter.modifyOpInPlace(op, [&] { op->setOperands(sortedOperands); }); |
| return success(); |
| } |
| }; |
| |
| void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) { |
| patterns.add<SortCommutativeOperands>(patterns.getContext()); |
| } |