| //===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements pattern rewrites to lower sequences of |
| // `vector.to_elements` and `vector.from_elements` operations into a tree of |
| // `vector.shuffle` operations. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| #include "mlir/Dialect/Vector/Transforms/Passes.h" |
| #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/MathExtras.h" |
| #include "llvm/Support/raw_ostream.h" |
| |
| namespace mlir { |
| namespace vector { |
| |
| #define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE |
| #include "mlir/Dialect/Vector/Transforms/Passes.h.inc" |
| |
| } // namespace vector |
| } // namespace mlir |
| |
| #define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree" |
| |
| using namespace mlir; |
| using namespace mlir::vector; |
| |
| namespace { |
| |
| // Indentation unit for debug output formatting. |
| [[maybe_unused]] constexpr unsigned kIndScale = 2; |
| |
| /// Represents a closed interval of elements (e.g., [0, 7] = 8 elements). |
| using Interval = std::pair<unsigned, unsigned>; |
| // Sentinel value for uninitialized intervals. |
| constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max(); |
| |
| /// The VectorShuffleTreeBuilder builds a balanced binary tree of |
| /// `vector.shuffle` operations from one or more `vector.to_elements` |
| /// operations feeding a single `vector.from_elements` operation. |
| /// |
| /// The implementation generates hardware-agnostic `vector.shuffle` operations |
| /// that minimize both the number of shuffle operations and the length of |
| /// intermediate vectors (to the extent possible). The tree has the |
| /// following properties: |
| /// |
| /// 1. Vectors are shuffled in pairs by order of appearance in |
| /// the `vector.from_elements` operand list. |
| /// 2. Each vector at each level is used only once. |
| /// 3. The number of levels in the tree is: |
| /// 1 (input vectors) + ceil(max(1,log2(# `vector.to_elements` ops))). |
| /// 4. Vectors at each level of the tree have the same vector length. |
| /// 5. Vector positions that do not need to be shuffled are represented with |
| /// poison in the shuffle mask. |
| /// |
| /// Examples #1: Concatenation of 3x vector<4xf32> to vector<12xf32>: |
| /// |
| /// %0:4 = vector.to_elements %a : vector<4xf32> |
| /// %1:4 = vector.to_elements %b : vector<4xf32> |
| /// %2:4 = vector.to_elements %c : vector<4xf32> |
| /// %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, |
| /// %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 |
| /// : vector<12xf32> |
| /// => |
| /// |
| /// %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7] |
| /// : vector<4xf32>, vector<4xf32> |
| /// %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1] |
| /// : vector<4xf32>, vector<4xf32> |
| /// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5, |
| /// 6, 7, 8, 9, 10, 11] |
| /// : vector<8xf32>, vector<8xf32> |
| /// |
| /// Comments: |
| /// * The shuffle tree has three levels: |
| /// - Level 0 = (%a, %b, %c, %c) |
| /// - Level 1 = (%shuffle0, %shuffle1) |
| /// - Level 2 = (%result) |
| /// * `%a` and `%b` are shuffled first because they appear first in the |
| /// `vector.from_elements` operand list (`%0#0` and `%1#0`). |
| /// * `%c` is shuffled with itself because the number of |
| /// `vector.from_elements` operands is odd. |
| /// * The vector length for level 1 and level 2 are 8 and 16, respectively. |
| /// * `%shuffle1` uses poison values to match the vector length of its |
| /// tree level (8). |
| /// |
| /// |
| /// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| /// |
| /// %0:5 = vector.to_elements %a : vector<5xf32> |
| /// %1:5 = vector.to_elements %b : vector<5xf32> |
| /// %2:5 = vector.to_elements %c : vector<5xf32> |
| /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, |
| /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> |
| /// => |
| /// |
| /// %shuffle0 = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6] |
| /// : vector<5xf32>, vector<5xf32> |
| /// %shuffle1 = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1] |
| /// : vector<5xf32>, vector<5xf32> |
| /// %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14] |
| /// : vector<8xf32>, vector<8xf32> |
| /// |
| /// Comments: |
| /// * `%c` and `%b` are shuffled first because they appear first in the |
| /// `vector.from_elements` operand list (`%2#2` and `%1#1`). |
| /// * `%a` is shuffled with itself because the number of |
| /// `vector.from_elements` operands is odd. |
| /// * The vector length for level 1 and level 2 are 8 and 9, respectively. |
| /// * `%shuffle0` uses poison values to mark unused vector positions and |
| /// match the vector length of its tree level (8). |
| /// |
| /// TODO: Implement mask compression to reduce the number of intermediate poison |
| /// values. |
| class VectorShuffleTreeBuilder { |
| public: |
| VectorShuffleTreeBuilder() = delete; |
| VectorShuffleTreeBuilder(FromElementsOp fromElemOp, |
| ArrayRef<ToElementsOp> toElemDefs); |
| |
| /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence |
| /// and compute the shuffle tree configuration. This method does not generate |
| /// any IR. |
| LogicalResult computeShuffleTree(); |
| |
| /// Materialize the shuffle tree configuration computed by |
| /// `computeShuffleTree` in the IR. |
| Value generateShuffleTree(PatternRewriter &rewriter); |
| |
| private: |
| // IR input information. |
| FromElementsOp fromElemsOp; |
| SmallVector<ToElementsOp> toElemsDefs; |
| |
| // Shuffle tree configuration. |
| unsigned numLevels; |
| SmallVector<unsigned> vectorSizePerLevel; |
| /// Holds the range of positions each vector in the tree contributes to in the |
| /// final output vector. |
| SmallVector<SmallVector<Interval>> intervalsPerLevel; |
| |
| // Utility methods to compute the shuffle tree configuration. |
| void computeShuffleTreeIntervals(); |
| void computeShuffleTreeVectorSizes(); |
| |
| /// Dump the shuffle tree configuration. |
| void dump(); |
| }; |
| |
| VectorShuffleTreeBuilder::VectorShuffleTreeBuilder( |
| FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs) |
| : fromElemsOp(fromElemOp), toElemsDefs(toElemDefs) { |
| assert(fromElemsOp && "from_elements op is required"); |
| assert(!toElemsDefs.empty() && "At least one to_elements op is required"); |
| } |
| |
| /// Duplicate the last operation, value or interval if the total number of them |
| /// is odd. This is useful to simplify the shuffle tree algorithm given that |
| /// vectors are shuffled in pairs and duplication would lead to the last shuffle |
| /// to have a single (duplicated) input vector. |
| template <typename T> |
| static void duplicateLastIfOdd(SmallVectorImpl<T> &values) { |
| if (values.size() % 2 != 0) |
| values.push_back(values.back()); |
| } |
| |
| // ===---------------------------------------------------------------------===// |
| // Shuffle Tree Analysis Utilities. |
| // ===---------------------------------------------------------------------===// |
| |
| /// Compute the intervals for all the vectors in the shuffle tree. The interval |
| /// of a vector is the range of positions that the vector contributes to in the |
| /// final output vector. |
| /// |
| /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| /// |
| /// %0:5 = vector.to_elements %a : vector<5xf32> |
| /// %1:5 = vector.to_elements %b : vector<5xf32> |
| /// %2:5 = vector.to_elements %c : vector<5xf32> |
| /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, |
| /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> |
| /// |
| /// The shuffle tree has 3 levels. Level 0 has 4 vectors (%2, %1, %0, %0, the |
| /// last one is duplicated to make the number of inputs even) so we compute the |
| /// interval for each vector: |
| /// |
| /// * intervalsPerLevel[0][0] = interval(%2) = [0,6] |
| /// * intervalsPerLevel[0][1] = interval(%1) = [1,7] |
| /// * intervalsPerLevel[0][2] = interval(%0) = [2,8] |
| /// * intervalsPerLevel[0][3] = interval(%0) = [2,8] |
| /// |
| /// Level 1 has 2 vectors, resulting from the shuffling of %2 + %1 and %0 + %0 |
| /// so we compute the intervals for each vector at level 1 as: |
| /// * intervalsPerLevel[1][0] = intervalsPerLevel[0][0] U |
| /// intervalsPerLevel[0][1] = [0,7] |
| /// * intervalsPerLevel[1][1] = intervalsPerLevel[0][2] U |
| /// intervalsPerLevel[0][3] = [2,8] |
| /// |
| /// Level 2 is the last level and only contains the output vector so the |
| /// interval should be the whole output vector: |
| /// * intervalsPerLevel[2][0] = intervalsPerLevel[1][0] U |
| /// intervalsPerLevel[1][1] = [0,8] |
| /// |
| void VectorShuffleTreeBuilder::computeShuffleTreeIntervals() { |
| // Map `vector.to_elements` ops to their ordinal position in the |
| // `vector.from_elements` operand list. Make sure duplicated |
| // `vector.to_elements` ops are mapped to the its first occurrence. |
| DenseMap<ToElementsOp, unsigned> toElemsToInputOrdinal; |
| for (const auto &[idx, toElemsOp] : llvm::enumerate(toElemsDefs)) |
| toElemsToInputOrdinal.insert({toElemsOp, idx}); |
| |
| // Compute intervals for each vector in the shuffle tree. The first |
| // level computation is special-cased to keep the implementation simpler. |
| |
| SmallVector<Interval> firstLevelIntervals(toElemsDefs.size(), |
| {kMaxUnsigned, kMaxUnsigned}); |
| |
| for (const auto &[idx, element] : |
| llvm::enumerate(fromElemsOp.getElements())) { |
| auto toElemsOp = cast<ToElementsOp>(element.getDefiningOp()); |
| unsigned inputIdx = toElemsToInputOrdinal[toElemsOp]; |
| Interval ¤tInterval = firstLevelIntervals[inputIdx]; |
| |
| // Set lower bound to the first occurrence of the `vector.to_elements`. |
| if (currentInterval.first == kMaxUnsigned) |
| currentInterval.first = idx; |
| |
| // Set upper bound to the last occurrence of the `vector.to_elements`. |
| currentInterval.second = idx; |
| } |
| |
| duplicateLastIfOdd(toElemsDefs); |
| duplicateLastIfOdd(firstLevelIntervals); |
| intervalsPerLevel.push_back(std::move(firstLevelIntervals)); |
| |
| // Compute intervals for the remaining levels. |
| for (unsigned level = 1; level < numLevels; ++level) { |
| bool isLastLevel = level == numLevels - 1; |
| const auto &prevLevelIntervals = intervalsPerLevel[level - 1]; |
| SmallVector<Interval> currentLevelIntervals( |
| llvm::divideCeil(prevLevelIntervals.size(), 2), |
| {kMaxUnsigned, kMaxUnsigned}); |
| |
| size_t currentNumLevels = currentLevelIntervals.size(); |
| for (size_t inputIdx = 0; inputIdx < currentNumLevels; ++inputIdx) { |
| auto &interval = currentLevelIntervals[inputIdx]; |
| const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2]; |
| const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1]; |
| |
| // The interval of a vector at the current level is the union of the |
| // intervals of the two vectors from the previous level being shuffled at |
| // this level. |
| interval.first = prevLhsInterval.first; |
| interval.second = |
| std::max(prevLhsInterval.second, prevRhsInterval.second); |
| } |
| |
| // Duplicate the last interval if the number of intervals is odd, except for |
| // the last level as it only contains the output vector, which doesn't have |
| // to be shuffled. |
| if (!isLastLevel) |
| duplicateLastIfOdd(currentLevelIntervals); |
| |
| intervalsPerLevel.push_back(std::move(currentLevelIntervals)); |
| } |
| } |
| |
| /// Compute the uniform vector size for each level of the shuffle tree, given |
| /// the intervals of the vectors at each level. The vector size of a level is |
| /// the size of the widest interval at that level. |
| /// |
| /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| /// |
| /// Intervals: |
| /// * Level 0: [0,6], [1,7], [2,8], [2,8] |
| /// * Level 1: [0,7], [2,8] |
| /// * Level 2: [0,8] |
| /// |
| /// Vector sizes: |
| /// * Level 0: Arbitrary sizes from input vectors. |
| /// * Level 1: max(size_of([0,7]) = 8, size_of([2,8]) = 7) = 8 |
| /// * Level 2: max(size_of([0,8]) = 9) = 9 |
| /// |
| void VectorShuffleTreeBuilder::computeShuffleTreeVectorSizes() { |
| // Compute vector size for each level. There are two direct cases: |
| // * First level: the vector size depends on the actual size of the input |
| // vectors and it's allowed to be non-uniform. We set it to 0. |
| // * Last level: the vector size is the output vector size so it doesn't |
| // have to be computed using intervals. |
| vectorSizePerLevel.front() = 0; |
| vectorSizePerLevel.back() = |
| cast<VectorType>(fromElemsOp.getResult().getType()).getNumElements(); |
| |
| for (unsigned level = 1; level < numLevels - 1; ++level) { |
| const auto ¤tLevelIntervals = intervalsPerLevel[level]; |
| unsigned currentVectorSize = 1; |
| size_t numIntervals = currentLevelIntervals.size(); |
| for (size_t i = 0; i < numIntervals; ++i) { |
| const auto &interval = currentLevelIntervals[i]; |
| unsigned intervalSize = interval.second - interval.first + 1; |
| currentVectorSize = std::max(currentVectorSize, intervalSize); |
| } |
| assert(currentVectorSize > 0 && "vector size must be positive"); |
| vectorSizePerLevel[level] = currentVectorSize; |
| } |
| } |
| |
| void VectorShuffleTreeBuilder::dump() { |
| LLVM_DEBUG({ |
| unsigned indLv = 0; |
| |
| llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n"; |
| ++indLv; |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n"; |
| ++indLv; |
| for (const auto &toElemsOp : toElemsDefs) |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElemsOp << "\n"; |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElemsOp << "\n\n"; |
| --indLv; |
| |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| << "* Total levels: " << numLevels << "\n"; |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| << "* Vector sizes per level: "; |
| llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs()); |
| llvm::dbgs() << "\n"; |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| << "* Input intervals per level:\n"; |
| ++indLv; |
| for (const auto &[level, intervals] : llvm::enumerate(intervalsPerLevel)) { |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level |
| << ": "; |
| llvm::interleaveComma(intervals, llvm::dbgs(), |
| [](const Interval &interval) { |
| llvm::dbgs() << "[" << interval.first << "," |
| << interval.second << "]"; |
| }); |
| llvm::dbgs() << "\n"; |
| } |
| }); |
| } |
| |
| /// Compute the shuffle tree configuration for the given `vector.to_elements` + |
| /// `vector.from_elements` input sequence. This method builds a balanced binary |
| /// shuffle tree that combines pairs of vectors at each level. |
| /// |
| /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| /// |
| /// %0:5 = vector.to_elements %a : vector<5xf32> |
| /// %1:5 = vector.to_elements %b : vector<5xf32> |
| /// %2:5 = vector.to_elements %c : vector<5xf32> |
| /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, |
| /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> |
| /// |
| /// build a tree that looks like: |
| /// |
| /// %2 %1 %0 %0 |
| /// \ / \ / |
| /// %2_1 = vector.shuffle %0_0 = vector.shuffle |
| /// \ / |
| /// %2_1_0_0 =vector.shuffle |
| /// |
| /// The actual representation of the shuffle tree configuration is based on |
| /// intervals of each vector at each level of the shuffle tree (i.e., %2, %1, |
| /// %0, %0, %2_1, %0_0 and %2_1_0_0) and the vector size for each level. For |
| /// further details on intervals and vector size computation, please, take a |
| /// look at the corresponding utility functions. |
| LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() { |
| // Initialize shuffle tree information based on its size. For the number of |
| // levels, we add one to account for the input `vector.to_elements` as one |
| // tree level. We need the std::max(1) to account for a single element input. |
| numLevels = 1u + std::max(1u, llvm::Log2_64_Ceil(toElemsDefs.size())); |
| vectorSizePerLevel.resize(numLevels, 0); |
| intervalsPerLevel.reserve(numLevels); |
| |
| computeShuffleTreeIntervals(); |
| computeShuffleTreeVectorSizes(); |
| dump(); |
| |
| return success(); |
| } |
| |
| // ===---------------------------------------------------------------------===// |
| // Shuffle Tree Code Generation Utilities. |
| // ===---------------------------------------------------------------------===// |
| |
| /// Compute the permutation mask for shuffling two input `vector.to_elements` |
| /// ops. The permutation mask is the mapping of the vector elements to their |
| /// final position in the output vector, relative to the intermediate output |
| /// vector of the `vector.shuffle` operation combining the two inputs. |
| /// |
| /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| /// |
| /// %0:5 = vector.to_elements %a : vector<5xf32> |
| /// %1:5 = vector.to_elements %b : vector<5xf32> |
| /// %2:5 = vector.to_elements %c : vector<5xf32> |
| /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, |
| /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> |
| /// |
| /// => |
| /// |
| /// // Level 1, vector length = 8 |
| /// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6] |
| /// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1] |
| /// |
| /// TODO: Implement mask compression to reduce the number of intermediate poison |
| /// values. |
| static SmallVector<int64_t> computePermutationShuffleMask( |
| ToElementsOp toElementOp0, const Interval &interval0, |
| ToElementsOp toElementOp1, const Interval &interval1, |
| FromElementsOp fromElemsOp, unsigned outputVectorSize) { |
| SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex); |
| unsigned inputVectorSize = |
| toElementOp0.getSource().getType().getNumElements(); |
| |
| for (const auto &[inputIdx, element] : |
| llvm::enumerate(fromElemsOp.getElements())) { |
| auto currentToElemOp = cast<ToElementsOp>(element.getDefiningOp()); |
| // Match `vector.from_elements` operands to the two input ops. |
| if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1) |
| continue; |
| |
| // The permutation value for a particular operand is the ordinal position of |
| // the operand in the `vector.to_elements` list of results. |
| unsigned permVal = cast<OpResult>(element).getResultNumber(); |
| unsigned maskIdx = inputIdx; |
| |
| // The mask index is the ordinal position of the operand in |
| // `vector.from_elements` operand list. We make this position relative to |
| // the output interval resulting from combining the two input intervals. |
| if (currentToElemOp == toElementOp0) { |
| maskIdx -= interval0.first; |
| } else { |
| // currentToElemOp == toElementOp1 |
| unsigned intervalOffset = interval1.first - interval0.first; |
| maskIdx += intervalOffset - interval1.first; |
| permVal += inputVectorSize; |
| } |
| |
| mask[maskIdx] = permVal; |
| } |
| |
| LLVM_DEBUG({ |
| unsigned indLv = 1; |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Permutation mask: ["; |
| llvm::interleaveComma(mask, llvm::dbgs()); |
| llvm::dbgs() << "]\n"; |
| ++indLv; |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| << "* Combining: " << toElementOp0 << " and " << toElementOp1 |
| << "\n"; |
| }); |
| |
| return mask; |
| } |
| |
| /// Compute the propagation shuffle mask for combining two intermediate shuffle |
| /// operations of the tree. The propagation shuffle mask is the mapping of the |
| /// intermediate vector elements, which have already been shuffled to their |
| /// relative output position using the mask generated by |
| /// `computePermutationShuffleMask`, to their next position in the tree. |
| /// |
| /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| /// |
| /// %0:5 = vector.to_elements %a : vector<5xf32> |
| /// %1:5 = vector.to_elements %b : vector<5xf32> |
| /// %2:5 = vector.to_elements %c : vector<5xf32> |
| /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, |
| /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> |
| /// |
| /// // Level 1, vector length = 8 |
| /// %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6] |
| /// %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1] |
| /// |
| /// => |
| /// |
| /// // Level 2, vector length = 9 |
| /// PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14] |
| /// |
| /// TODO: Implement mask compression to reduce the number of intermediate poison |
| /// values. |
| static SmallVector<int64_t> computePropagationShuffleMask( |
| ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp, |
| const Interval &rhsInterval, unsigned outputVectorSize) { |
| ArrayRef<int64_t> lhsShuffleMask = lhsShuffleOp.getMask(); |
| ArrayRef<int64_t> rhsShuffleMask = rhsShuffleOp.getMask(); |
| unsigned inputVectorSize = lhsShuffleMask.size(); |
| assert(inputVectorSize == rhsShuffleMask.size() && |
| "Expected both shuffle masks to have the same size"); |
| |
| bool hasSameInput = lhsShuffleOp == rhsShuffleOp; |
| unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first; |
| SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex); |
| |
| // Propagate any element from the input mask that is not poison. For the RHS |
| // vector, offset mask index by the distance between the intervals. |
| for (unsigned i = 0; i < inputVectorSize; ++i) { |
| if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex) |
| mask[i] = i; |
| |
| if (hasSameInput) |
| continue; |
| |
| unsigned rhsIdx = i + lhsRhsOffset; |
| if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) { |
| assert(rhsIdx < outputVectorSize && "RHS index out of bounds"); |
| assert(mask[rhsIdx] == ShuffleOp::kPoisonIndex && "mask already set"); |
| mask[rhsIdx] = i + inputVectorSize; |
| } |
| } |
| |
| LLVM_DEBUG({ |
| unsigned indLv = 1; |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| << "* Propagation shuffle mask computation:\n"; |
| ++indLv; |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| << "* LHS shuffle op: " << lhsShuffleOp << "\n"; |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) |
| << "* RHS shuffle op: " << rhsShuffleOp << "\n"; |
| llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Result mask: ["; |
| llvm::interleaveComma(mask, llvm::dbgs()); |
| llvm::dbgs() << "]\n"; |
| }); |
| |
| return mask; |
| } |
| |
| /// Materialize the pre-computed shuffle tree configuration in the IR by |
| /// generating the corresponding `vector.shuffle` ops. |
| /// |
| /// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>: |
| /// |
| /// %0:5 = vector.to_elements %a : vector<5xf32> |
| /// %1:5 = vector.to_elements %b : vector<5xf32> |
| /// %2:5 = vector.to_elements %c : vector<5xf32> |
| /// %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, |
| /// %2#2, %2#0, %1#1, %0#4 : vector<9xf32> |
| /// |
| /// with the pre-computed shuffle tree configuration: |
| /// |
| /// * Vector sizes per level: 0, 8, 9 |
| /// * Input intervals per level: |
| /// * Level 0: [0,6], [1,7], [2,8], [2,8] |
| /// * Level 1: [0,7], [2,8] |
| /// * Level 2: [0,8] |
| /// |
| /// => |
| /// |
| /// %0 = vector.shuffle %arg2, %arg1 [2, 6, -1, -1, 7, 2, 0, 6] |
| /// : vector<5xf32>, vector<5xf32> |
| /// %1 = vector.shuffle %arg0, %arg0 [1, 1, -1, -1, -1, -1, 4, -1] |
| /// : vector<5xf32>, vector<5xf32> |
| /// %2 = vector.shuffle %0, %1 [0, 1, 8, 9, 4, 5, 6, 7, 14] |
| /// : vector<8xf32>, vector<8xf32> |
| /// |
| /// The code generation consists of combining pairs of vectors at each level of |
| /// the tree, using the pre-computed tree intervals and vector sizes. The |
| /// algorithm generates two kinds of shuffle masks: |
| /// * Permutation masks: computed for the first level of the tree and permute |
| /// the input vector elements to their relative position in the final |
| /// output. |
| /// * Propagation masks: computed for subsequent levels and propagate the |
| /// elements to the next level without permutation. |
| /// |
| /// For further details on the shuffle mask computation, please, take a look at |
| /// the corresponding `computePermutationShuffleMask` and |
| /// `computePropagationShuffleMask` functions. |
| /// |
| Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) { |
| LLVM_DEBUG(llvm::dbgs() << "VectorShuffleTreeBuilder Code Generation:\n"); |
| |
| // Initialize work list with the `vector.to_elements` sources. |
| SmallVector<Value> levelInputs; |
| llvm::transform(toElemsDefs, std::back_inserter(levelInputs), |
| [](ToElementsOp toElemsOp) { return toElemsOp.getSource(); }); |
| |
| // Build shuffle tree by combining pairs of vectors (represented by their |
| // corresponding intervals) in one level and producing a new vector with the |
| // next level's vector length. Skip the interval from the last tree level |
| // (actual shuffle tree output) as it doesn't have to be combined with |
| // anything else. |
| Location loc = fromElemsOp.getLoc(); |
| unsigned currentLevel = 0; |
| for (const auto &[nextLevelVectorSize, intervals] : |
| llvm::zip_equal(ArrayRef(vectorSizePerLevel).drop_front(), |
| ArrayRef(intervalsPerLevel).drop_back())) { |
| |
| duplicateLastIfOdd(levelInputs); |
| |
| LLVM_DEBUG(llvm::dbgs() << llvm::indent(1, kIndScale) |
| << "* Processing level " << currentLevel |
| << " (output vector size: " << nextLevelVectorSize |
| << ", # inputs: " << levelInputs.size() << ")\n"); |
| |
| // Process level input vectors in pairs. |
| SmallVector<Value> levelOutputs; |
| for (size_t i = 0, numLevelInputs = levelInputs.size(); i < numLevelInputs; |
| i += 2) { |
| Value lhsVector = levelInputs[i]; |
| Value rhsVector = levelInputs[i + 1]; |
| const Interval &lhsInterval = intervals[i]; |
| const Interval &rhsInterval = intervals[i + 1]; |
| |
| // For the first level of the tree, permute the vector elements to their |
| // relative position in the final output. For subsequent levels, we |
| // propagate the elements to the next level without permutation. |
| SmallVector<int64_t> shuffleMask; |
| if (currentLevel == 0) { |
| shuffleMask = computePermutationShuffleMask( |
| toElemsDefs[i], lhsInterval, toElemsDefs[i + 1], rhsInterval, |
| fromElemsOp, nextLevelVectorSize); |
| } else { |
| auto lhsShuffleOp = cast<ShuffleOp>(lhsVector.getDefiningOp()); |
| auto rhsShuffleOp = cast<ShuffleOp>(rhsVector.getDefiningOp()); |
| shuffleMask = computePropagationShuffleMask(lhsShuffleOp, lhsInterval, |
| rhsShuffleOp, rhsInterval, |
| nextLevelVectorSize); |
| } |
| |
| Value shuffleVal = vector::ShuffleOp::create(rewriter, loc, lhsVector, |
| rhsVector, shuffleMask); |
| levelOutputs.push_back(shuffleVal); |
| } |
| |
| levelInputs = std::move(levelOutputs); |
| ++currentLevel; |
| } |
| |
| assert(levelInputs.size() == 1 && "Should have exactly one result"); |
| return levelInputs.front(); |
| } |
| |
| /// Gather and unique all the `vector.to_elements` operations that feed the |
| /// `vector.from_elements` operation. The `vector.to_elements` operations are |
| /// returned in order of appearance in the `vector.from_elements`'s operand |
| /// list. |
| static LogicalResult |
| getToElementsDefiningOps(FromElementsOp fromElemsOp, |
| SmallVectorImpl<ToElementsOp> &toElemsDefs) { |
| SetVector<ToElementsOp> toElemsDefsSet; |
| for (Value element : fromElemsOp.getElements()) { |
| auto toElemsOp = element.getDefiningOp<ToElementsOp>(); |
| if (!toElemsOp) |
| return failure(); |
| toElemsDefsSet.insert(toElemsOp); |
| } |
| |
| toElemsDefs.assign(toElemsDefsSet.begin(), toElemsDefsSet.end()); |
| return success(); |
| } |
| |
| /// Pass to rewrite `vector.to_elements` + `vector.from_elements` sequences into |
| /// a tree of `vector.shuffle` operations. Only 1-D input vectors are supported |
| /// for now. |
| struct ToFromElementsToShuffleTreeRewrite final |
| : OpRewritePattern<vector::FromElementsOp> { |
| |
| using Base::Base; |
| |
| LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp, |
| PatternRewriter &rewriter) const override { |
| VectorType resultType = fromElemsOp.getType(); |
| if (resultType.getRank() != 1) |
| return rewriter.notifyMatchFailure( |
| fromElemsOp, |
| "multi-dimensional output vectors are not supported yet"); |
| if (resultType.isScalable()) |
| return rewriter.notifyMatchFailure( |
| fromElemsOp, |
| "'vector.from_elements' does not support scalable vectors"); |
| |
| // Gather all the `vector.to_elements` operations that feed the |
| // `vector.from_elements` operation. Other op definitions are not supported. |
| SmallVector<ToElementsOp> toElemsDefs; |
| if (failed(getToElementsDefiningOps(fromElemsOp, toElemsDefs))) |
| return rewriter.notifyMatchFailure(fromElemsOp, "unsupported sources"); |
| |
| if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) { |
| return toElemsOp.getSource().getType().getRank() != 1; |
| })) { |
| return rewriter.notifyMatchFailure( |
| fromElemsOp, "multi-dimensional input vectors are not supported yet"); |
| } |
| |
| if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) { |
| return !toElemsOp.getSource().getType().hasRank(); |
| })) { |
| return rewriter.notifyMatchFailure(fromElemsOp, |
| "0-D vectors are not supported"); |
| } |
| |
| // Avoid generating a shuffle tree for trivial `vector.to_elements` -> |
| // `vector.from_elements` forwarding cases that do not require shuffling. |
| if (toElemsDefs.size() == 1) { |
| ToElementsOp toElemsOp0 = toElemsDefs.front(); |
| if (llvm::equal(fromElemsOp.getElements(), toElemsOp0.getResults())) { |
| return rewriter.notifyMatchFailure( |
| fromElemsOp, "trivial forwarding case does not require shuffling"); |
| } |
| } |
| |
| VectorShuffleTreeBuilder shuffleTreeBuilder(fromElemsOp, toElemsDefs); |
| if (failed(shuffleTreeBuilder.computeShuffleTree())) |
| return rewriter.notifyMatchFailure(fromElemsOp, |
| "failed to compute shuffle tree"); |
| |
| Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter); |
| rewriter.replaceOp(fromElemsOp, finalShuffle); |
| return success(); |
| } |
| }; |
| |
| struct LowerVectorToFromElementsToShuffleTreePass |
| : public vector::impl::LowerVectorToFromElementsToShuffleTreeBase< |
| LowerVectorToFromElementsToShuffleTreePass> { |
| |
| void runOnOperation() override { |
| RewritePatternSet patterns(&getContext()); |
| populateVectorToFromElementsToShuffleTreePatterns(patterns); |
| |
| if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) |
| return signalPassFailure(); |
| } |
| }; |
| |
| } // namespace |
| |
| void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns( |
| RewritePatternSet &patterns, PatternBenefit benefit) { |
| patterns.add<ToFromElementsToShuffleTreeRewrite>(patterns.getContext(), |
| benefit); |
| } |