//===- VectorTransferPermutationMapRewritePatterns.cpp - Xfer map rewrite -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrite patterns for the permutation_map attribute of
// vector.transfer operations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
#include "mlir/Interfaces/VectorInterfaces.h"
using namespace mlir;
using namespace mlir::vector;
/// Transpose a vector transfer op's `in_bounds` attribute according to the
/// given permutation.
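/// E.g., with permutation = [1, 0], an `in_bounds` of [true, false] becomes
/// [false, true].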
static ArrayAttr
transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
const SmallVector<unsigned> &permutation) {
SmallVector<bool> newInBoundsValues;
for (unsigned pos : permutation)
newInBoundsValues.push_back(
attr.getValue()[pos].cast<BoolAttr>().getValue());
return builder.getBoolArrayAttr(newInBoundsValues);
}
/// Lower transfer_read op with permutation into a transfer_read with a
/// permutation map composed of leading zeros followed by a minor identity +
/// vector.transpose op.
/// Ex:
/// vector.transfer_read ...
/// permutation_map: (d0, d1, d2) -> (0, d1)
/// into:
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2) -> (d1, 0)
/// vector.transpose %v, [1, 0]
///
/// vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3)
/// into:
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3)
/// vector.transpose %v, [0, 1, 3, 2, 4]
/// Note that an alternative is to transform it to linalg.transpose +
/// vector.transfer_read to do the transpose in memory instead.
struct TransferReadPermutationLowering
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return failure();
SmallVector<unsigned> permutation;
AffineMap map = op.permutation_map();
if (map.getNumResults() == 0)
return failure();
if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
return failure();
    AffineMap permutationMap =
        map.getPermutationMap(permutation, op.getContext());
    if (permutationMap.isIdentity())
      return failure();
    // Calculate the map of the new read by applying the inverse permutation.
    permutationMap = inversePermutation(permutationMap);
    AffineMap newMap = permutationMap.compose(map);
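    // E.g.: for map = (d0, d1, d2) -> (0, d1) and permutation = [1, 0],
    //       permutationMap = (d0, d1) -> (d1, d0) and
    //       newMap = (d0, d1, d2) -> (d1, 0).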
// Apply the reverse transpose to deduce the type of the transfer_read.
ArrayRef<int64_t> originalShape = op.getVectorType().getShape();
SmallVector<int64_t> newVectorShape(originalShape.size());
for (auto pos : llvm::enumerate(permutation)) {
newVectorShape[pos.value()] = originalShape[pos.index()];
}
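    // E.g., permutation = [1, 0] with originalShape = [4, 8] yields
    // newVectorShape = [8, 4], the shape of the read before the transpose.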
// Transpose mask operand.
Value newMask;
if (op.mask()) {
      // Remove unused dims from the permutation map. E.g.:
      //   (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2)
      //   comp = (d0, d1, d2) -> (d2, 0, d1, 0, d0)
auto comp = compressUnusedDims(map);
      // Get positions of remaining result dims. E.g.:
      //   (d0, d1, d2) -> (d2, 0, d1, 0, d0)
      //   maskTransposeIndices = [2, 1, 0]
SmallVector<int64_t> maskTransposeIndices;
for (unsigned i = 0; i < comp.getNumResults(); ++i) {
if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>())
maskTransposeIndices.push_back(expr.getPosition());
}
newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.mask(),
maskTransposeIndices);
}
// Transpose in_bounds attribute.
ArrayAttr newInBoundsAttr =
op.in_bounds() ? transposeInBoundsAttr(
rewriter, op.in_bounds().getValue(), permutation)
: ArrayAttr();
// Generate new transfer_read operation.
VectorType newReadType =
VectorType::get(newVectorShape, op.getVectorType().getElementType());
Value newRead = rewriter.create<vector::TransferReadOp>(
op.getLoc(), newReadType, op.source(), op.indices(),
AffineMapAttr::get(newMap), op.padding(), newMask, newInBoundsAttr);
// Transpose result of transfer_read.
SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
transposePerm);
return success();
}
};
/// Lower transfer_write op with permutation into a transfer_write with a
/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
/// Ex:
/// vector.transfer_write %v ...
/// permutation_map: (d0, d1, d2) -> (d2, d0, d1)
/// into:
/// %tmp = vector.transpose %v, [2, 0, 1]
/// vector.transfer_write %tmp ...
/// permutation_map: (d0, d1, d2) -> (d0, d1, d2)
///
/// vector.transfer_write %v ...
/// permutation_map: (d0, d1, d2, d3) -> (d3, d2)
/// into:
/// %tmp = vector.transpose %v, [1, 0]
/// vector.transfer_write %tmp ...
/// permutation_map: (d0, d1, d2, d3) -> (d2, d3)
struct TransferWritePermutationLowering
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferWriteOp op,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return failure();
SmallVector<unsigned> permutation;
AffineMap map = op.permutation_map();
if (map.isMinorIdentity())
return failure();
if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
return failure();
    // Remove unused dims from the permutation map. E.g.:
    //   (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
    //   comp = (d0, d1, d2) -> (d2, d0, d1)
auto comp = compressUnusedDims(map);
// Get positions of remaining result dims.
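    // E.g.: comp = (d0, d1, d2) -> (d2, d0, d1) gives indices = [2, 0, 1].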
SmallVector<int64_t> indices;
llvm::transform(comp.getResults(), std::back_inserter(indices),
[](AffineExpr expr) {
return expr.dyn_cast<AffineDimExpr>().getPosition();
});
// Transpose mask operand.
Value newMask = op.mask() ? rewriter.create<vector::TransposeOp>(
op.getLoc(), op.mask(), indices)
: Value();
// Transpose in_bounds attribute.
ArrayAttr newInBoundsAttr =
op.in_bounds() ? transposeInBoundsAttr(
rewriter, op.in_bounds().getValue(), permutation)
: ArrayAttr();
// Generate new transfer_write operation.
Value newVec =
rewriter.create<vector::TransposeOp>(op.getLoc(), op.vector(), indices);
auto newMap = AffineMap::getMinorIdentityMap(
map.getNumDims(), map.getNumResults(), rewriter.getContext());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
op, Type(), newVec, op.source(), op.indices(),
AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
return success();
}
};
/// Lower transfer_read op with broadcast in the leading dimensions into
/// transfer_read of lower rank + vector.broadcast.
/// Ex: vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3)
/// into:
/// %v = vector.transfer_read ...
/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3)
/// vector.broadcast %v
struct TransferOpReduceRank : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp op,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (op.getTransferRank() == 0)
return failure();
AffineMap map = op.permutation_map();
unsigned numLeadingBroadcast = 0;
for (auto expr : map.getResults()) {
auto dimExpr = expr.dyn_cast<AffineConstantExpr>();
if (!dimExpr || dimExpr.getValue() != 0)
break;
numLeadingBroadcast++;
}
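    // E.g., for (d0, d1, d2, d3) -> (0, d1, 0, d3) only the first zero is
    // leading, so numLeadingBroadcast is 1.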
// If there are no leading zeros in the map there is nothing to do.
if (numLeadingBroadcast == 0)
return failure();
VectorType originalVecType = op.getVectorType();
unsigned reducedShapeRank = originalVecType.getRank() - numLeadingBroadcast;
// Calculate new map, vector type and masks without the leading zeros.
AffineMap newMap = AffineMap::get(
map.getNumDims(), 0, map.getResults().take_back(reducedShapeRank),
op.getContext());
    // Only remove the leading zeros if the rest of the map is a minor identity
    // with broadcasting. Otherwise we first want to permute the map.
if (!newMap.isMinorIdentityWithBroadcasting())
return failure();
// TODO: support zero-dimension vectors natively. See:
// https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
// In the meantime, lower these to a scalar load when they pop up.
if (reducedShapeRank == 0) {
Value newRead;
if (op.getShapedType().isa<TensorType>()) {
newRead = rewriter.create<tensor::ExtractOp>(op.getLoc(), op.source(),
op.indices());
} else {
newRead = rewriter.create<memref::LoadOp>(
op.getLoc(), originalVecType.getElementType(), op.source(),
op.indices());
}
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
newRead);
return success();
}
SmallVector<int64_t> newShape = llvm::to_vector<4>(
originalVecType.getShape().take_back(reducedShapeRank));
// Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
if (newShape.empty())
return failure();
VectorType newReadType =
VectorType::get(newShape, originalVecType.getElementType());
ArrayAttr newInBoundsAttr =
op.in_bounds()
? rewriter.getArrayAttr(
op.in_boundsAttr().getValue().take_back(reducedShapeRank))
: ArrayAttr();
Value newRead = rewriter.create<vector::TransferReadOp>(
op.getLoc(), newReadType, op.source(), op.indices(),
AffineMapAttr::get(newMap), op.padding(), op.mask(), newInBoundsAttr);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, originalVecType,
newRead);
return success();
}
};
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
RewritePatternSet &patterns) {
patterns.add<TransferReadPermutationLowering,
TransferWritePermutationLowering, TransferOpReduceRank>(
patterns.getContext());
}
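// Example driver usage (a minimal sketch, assuming the greedy rewrite driver
// declared in mlir/Transforms/GreedyPatternRewriteDriver.h):
//   RewritePatternSet patterns(funcOp.getContext());
//   vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));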