| //===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements rewrite patterns for the permutation_map attribute of |
| // vector.transfer operations. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Vector/VectorTransforms.h" |
| #include "mlir/Interfaces/VectorInterfaces.h" |
| |
| using namespace mlir; |
| using namespace mlir::vector; |
| |
| /// Transpose a vector transfer op's `in_bounds` attribute according to given |
| /// indices. |
| static ArrayAttr |
| transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr, |
| const SmallVector<unsigned> &permutation) { |
| SmallVector<bool> newInBoundsValues; |
| for (unsigned pos : permutation) |
| newInBoundsValues.push_back( |
| attr.getValue()[pos].cast<BoolAttr>().getValue()); |
| return builder.getBoolArrayAttr(newInBoundsValues); |
| } |
| /// Lower transfer_read op with permutation into a transfer_read with a |
| /// permutation map composed of leading zeros followed by a minor identiy + |
| /// vector.transpose op. |
| /// Ex: |
| /// vector.transfer_read ... |
| /// permutation_map: (d0, d1, d2) -> (0, d1) |
| /// into: |
| /// %v = vector.transfer_read ... |
| /// permutation_map: (d0, d1, d2) -> (d1, 0) |
| /// vector.transpose %v, [1, 0] |
| /// |
| /// vector.transfer_read ... |
| /// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3) |
| /// into: |
| /// %v = vector.transfer_read ... |
| /// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3) |
| /// vector.transpose %v, [0, 1, 3, 2, 4] |
| /// Note that an alternative is to transform it to linalg.transpose + |
| /// vector.transfer_read to do the transpose in memory instead. |
| struct TransferReadPermutationLowering |
| : public OpRewritePattern<vector::TransferReadOp> { |
| using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::TransferReadOp op, |
| PatternRewriter &rewriter) const override { |
| SmallVector<unsigned> permutation; |
| AffineMap map = op.permutation_map(); |
| if (map.getNumResults() == 0) |
| return failure(); |
| if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) |
| return failure(); |
| AffineMap permutationMap = |
| map.getPermutationMap(permutation, op.getContext()); |
| if (permutationMap.isIdentity()) |
| return failure(); |
| |
| permutationMap = map.getPermutationMap(permutation, op.getContext()); |
| // Caluclate the map of the new read by applying the inverse permutation. |
| permutationMap = inversePermutation(permutationMap); |
| AffineMap newMap = permutationMap.compose(map); |
| // Apply the reverse transpose to deduce the type of the transfer_read. |
| ArrayRef<int64_t> originalShape = op.getVectorType().getShape(); |
| SmallVector<int64_t> newVectorShape(originalShape.size()); |
| for (auto pos : llvm::enumerate(permutation)) { |
| newVectorShape[pos.value()] = originalShape[pos.index()]; |
| } |
| |
| // Transpose mask operand. |
| Value newMask; |
| if (op.mask()) { |
| // Remove unused dims from the permutation map. E.g.: |
| // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2) |
| // comp = (d0, d1, d2) -> (d2, 0, d1, 0 d0) |
| auto comp = compressUnusedDims(map); |
| // Get positions of remaining result dims. |
| // E.g.: (d0, d1, d2) -> (d2, 0, d1, 0 d0) |
| // maskTransposeIndices = [ 2, 1, 0] |
| SmallVector<int64_t> maskTransposeIndices; |
| for (unsigned i = 0; i < comp.getNumResults(); ++i) { |
| if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>()) |
| maskTransposeIndices.push_back(expr.getPosition()); |
| } |
| |
| newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.mask(), |
| maskTransposeIndices); |
| } |
| |
| // Transpose in_bounds attribute. |
| ArrayAttr newInBounds = |
| op.in_bounds() ? transposeInBoundsAttr( |
| rewriter, op.in_bounds().getValue(), permutation) |
| : ArrayAttr(); |
| |
| // Generate new transfer_read operation. |
| VectorType newReadType = |
| VectorType::get(newVectorShape, op.getVectorType().getElementType()); |
| Value newRead = rewriter.create<vector::TransferReadOp>( |
| op.getLoc(), newReadType, op.source(), op.indices(), newMap, |
| op.padding(), newMask, newInBounds); |
| |
| // Transpose result of transfer_read. |
| SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end()); |
| rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead, |
| transposePerm); |
| return success(); |
| } |
| }; |
| |
| /// Lower transfer_write op with permutation into a transfer_write with a |
| /// minor identity permutation map. (transfer_write ops cannot have broadcasts.) |
| /// Ex: |
| /// vector.transfer_write %v ... |
| /// permutation_map: (d0, d1, d2) -> (d2, d0, d1) |
| /// into: |
| /// %tmp = vector.transpose %v, [2, 0, 1] |
| /// vector.transfer_write %tmp ... |
| /// permutation_map: (d0, d1, d2) -> (d0, d1, d2) |
| /// |
| /// vector.transfer_write %v ... |
| /// permutation_map: (d0, d1, d2, d3) -> (d3, d2) |
| /// into: |
| /// %tmp = vector.transpose %v, [1, 0] |
| /// %v = vector.transfer_write %tmp ... |
| /// permutation_map: (d0, d1, d2, d3) -> (d2, d3) |
| struct TransferWritePermutationLowering |
| : public OpRewritePattern<vector::TransferWriteOp> { |
| using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::TransferWriteOp op, |
| PatternRewriter &rewriter) const override { |
| if (op.isZeroD()) |
| return failure(); |
| |
| SmallVector<unsigned> permutation; |
| AffineMap map = op.permutation_map(); |
| if (map.isMinorIdentity()) |
| return failure(); |
| if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) |
| return failure(); |
| |
| // Remove unused dims from the permutation map. E.g.: |
| // E.g.: (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4) |
| // comp = (d0, d1, d2) -> (d2, d0, d1) |
| auto comp = compressUnusedDims(map); |
| // Get positions of remaining result dims. |
| SmallVector<int64_t> indices; |
| llvm::transform(comp.getResults(), std::back_inserter(indices), |
| [](AffineExpr expr) { |
| return expr.dyn_cast<AffineDimExpr>().getPosition(); |
| }); |
| |
| // Transpose mask operand. |
| Value newMask = op.mask() ? rewriter.create<vector::TransposeOp>( |
| op.getLoc(), op.mask(), indices) |
| : Value(); |
| |
| // Transpose in_bounds attribute. |
| ArrayAttr newInBounds = |
| op.in_bounds() ? transposeInBoundsAttr( |
| rewriter, op.in_bounds().getValue(), permutation) |
| : ArrayAttr(); |
| |
| // Generate new transfer_write operation. |
| Value newVec = |
| rewriter.create<vector::TransposeOp>(op.getLoc(), op.vector(), indices); |
| auto newMap = AffineMap::getMinorIdentityMap( |
| map.getNumDims(), map.getNumResults(), rewriter.getContext()); |
| rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( |
| op, Type(), newVec, op.source(), op.indices(), newMap, newMask, |
| newInBounds); |
| |
| return success(); |
| } |
| }; |
| |
| /// Lower transfer_read op with broadcast in the leading dimensions into |
| /// transfer_read of lower rank + vector.broadcast. |
| /// Ex: vector.transfer_read ... |
| /// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3) |
| /// into: |
| /// %v = vector.transfer_read ... |
| /// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) |
| /// vector.broadcast %v |
| struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> { |
| using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(vector::TransferReadOp op, |
| PatternRewriter &rewriter) const override { |
| AffineMap map = op.permutation_map(); |
| unsigned numLeadingBroadcast = 0; |
| for (auto expr : map.getResults()) { |
| auto dimExpr = expr.dyn_cast<AffineConstantExpr>(); |
| if (!dimExpr || dimExpr.getValue() != 0) |
| break; |
| numLeadingBroadcast++; |
| } |
| // If there are no leading zeros in the map there is nothing to do. |
| if (numLeadingBroadcast == 0) |
| return failure(); |
| VectorType originalVecType = op.getVectorType(); |
| unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast; |
| // Calculate new map, vector type and masks without the leading zeros. |
| AffineMap newMap = AffineMap::get( |
| map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank), |
| op.getContext()); |
| // Only remove the leading zeros if the rest of the map is a minor identity |
| // with broadasting. Otherwise we first want to permute the map. |
| if (!newMap.isMinorIdentityWithBroadcasting()) |
| return failure(); |
| |
| // TODO: support zero-dimension vectors natively. See: |
| // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097. |
| // In the meantime, lower these to a scalar load when they pop up. |
| if (reducedShapeRank == 0) { |
| Value newRead; |
| if (op.getShapedType().isa<TensorType>()) { |
| newRead = rewriter.create<tensor::ExtractOp>(op.getLoc(), op.source(), |
| op.indices()); |
| } else { |
| newRead = rewriter.create<memref::LoadOp>( |
| op.getLoc(), originalVecType.getElementType(), op.source(), |
| op.indices()); |
| } |
| rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType, |
| newRead); |
| return success(); |
| } |
| SmallVector<int64_t> newShape = llvm::to_vector<4>( |
| originalVecType.getShape().take_back(reducedShapeRank)); |
| // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering. |
| if (newShape.empty()) |
| return failure(); |
| VectorType newReadType = |
| VectorType::get(newShape, originalVecType.getElementType()); |
| ArrayAttr newInBounds = |
| op.in_bounds() |
| ? rewriter.getArrayAttr( |
| op.in_boundsAttr().getValue().take_back(reducedShapeRank)) |
| : ArrayAttr(); |
| Value newRead = rewriter.create<vector::TransferReadOp>( |
| op.getLoc(), newReadType, op.source(), op.indices(), newMap, |
| op.padding(), op.mask(), newInBounds); |
| rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType, |
| newRead); |
| return success(); |
| } |
| }; |
| |
| void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( |
| RewritePatternSet &patterns) { |
| patterns.add<TransferReadPermutationLowering, |
| TransferWritePermutationLowering, TransferOpReduceRank>( |
| patterns.getContext()); |
| } |