//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with optimizing transfer_read and
// transfer_write ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "vector-transfer-opt"
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
using namespace mlir;
/// Return the ancestor of `op` whose parent region is `region`, or nullptr if
/// `region` is not an ancestor of `op`.
static Operation *findAncestorOpInRegion(Region *region, Operation *op) {
for (; op != nullptr && op->getParentRegion() != region;
op = op->getParentOp())
;
return op;
}
namespace {
class TransferOptimization {
public:
TransferOptimization(RewriterBase &rewriter, Operation *op)
: rewriter(rewriter), dominators(op), postDominators(op) {}
void deadStoreOp(vector::TransferWriteOp);
void storeToLoadForwarding(vector::TransferReadOp);
void removeDeadOp() {
for (Operation *op : opToErase)
rewriter.eraseOp(op);
opToErase.clear();
}
private:
RewriterBase &rewriter;
bool isReachable(Operation *start, Operation *dest);
DominanceInfo dominators;
PostDominanceInfo postDominators;
std::vector<Operation *> opToErase;
};
} // namespace
/// Return true if there is a path from start operation to dest operation,
/// otherwise return false. The operations have to be in the same region.
bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
assert(start->getParentRegion() == dest->getParentRegion() &&
"This function only works for ops in the same region");
// Simple case where the start op dominates the destination.
if (dominators.dominates(start, dest))
return true;
return start->getBlock()->isReachable(dest->getBlock());
}
/// For a transfer_write to fully overwrite another transfer_write, it must:
/// 1. Access the same memref with the same indices and vector type.
/// 2. Post-dominate the other transfer_write operation.
/// If several candidates are available, one must be post-dominated by all the
/// others since they all post-dominate the same transfer_write. We only
/// consider the candidate post-dominated by all the others, as it is the first
/// transfer_write executed after the potentially dead transfer_write.
/// If we find such an overwriting transfer_write, the original transfer_write
/// is dead provided every read reachable from the potentially dead
/// transfer_write is dominated by the overwriting transfer_write.
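///
/// Illustrative sketch (not taken from the tests; SSA names are made up): the
/// first write below is dead because the second write post-dominates it,
/// accesses the same indices with the same vector type, and no read of %m is
/// reachable in between:
///
///   vector.transfer_write %v0, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x4xf32>, memref<4x4xf32>
///   vector.transfer_write %v1, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x4xf32>, memref<4x4xf32>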
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
<< "\n");
llvm::SmallVector<Operation *, 8> blockingAccesses;
Operation *firstOverwriteCandidate = nullptr;
Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
llvm::SmallDenseSet<Operation *, 32> processed;
while (!users.empty()) {
Operation *user = users.pop_back_val();
// If the user has already been processed, skip it.
if (!processed.insert(user).second)
continue;
if (isa<ViewLikeOpInterface>(user)) {
users.append(user->getUsers().begin(), user->getUsers().end());
continue;
}
if (isMemoryEffectFree(user))
continue;
if (user == write.getOperation())
continue;
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
// Check whether this candidate can overwrite the store.
if (memref::isSameViewOrTrivialAlias(
cast<MemrefValue>(nextWrite.getBase()),
cast<MemrefValue>(write.getBase())) &&
checkSameValueWAW(nextWrite, write) &&
postDominators.postDominates(nextWrite, write)) {
if (firstOverwriteCandidate == nullptr ||
postDominators.postDominates(firstOverwriteCandidate, nextWrite))
firstOverwriteCandidate = nextWrite;
else
assert(
postDominators.postDominates(nextWrite, firstOverwriteCandidate));
continue;
}
}
if (auto transferOp = dyn_cast<VectorTransferOpInterface>(user)) {
// Don't need to consider disjoint accesses.
if (vector::isDisjointTransferSet(
cast<VectorTransferOpInterface>(write.getOperation()),
cast<VectorTransferOpInterface>(transferOp.getOperation()),
/*testDynamicValueUsingBounds=*/true))
continue;
}
blockingAccesses.push_back(user);
}
if (firstOverwriteCandidate == nullptr)
return;
Region *topRegion = firstOverwriteCandidate->getParentRegion();
Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
assert(writeAncestor &&
"write op should be recursively part of the top region");
for (Operation *access : blockingAccesses) {
Operation *accessAncestor = findAncestorOpInRegion(topRegion, access);
// TODO: if the access and write have the same ancestor we could recurse in
// the region to know if the access is reachable with more precision.
if (accessAncestor == nullptr ||
!isReachable(writeAncestor, accessAncestor))
continue;
if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
<< *accessAncestor << "\n");
return;
}
}
LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
<< " overwritten by: " << *firstOverwriteCandidate << "\n");
opToErase.push_back(write.getOperation());
}
/// A transfer_write candidate for store-to-load forwarding must:
/// 1. Access the same memref with the same indices and vector type as the
/// transfer_read.
/// 2. Dominate the transfer_read operation.
/// If several candidates are available, one must be dominated by all the
/// others since they all dominate the same transfer_read. We only consider the
/// candidate dominated by all the others, as it is the last transfer_write
/// executed before the transfer_read.
/// If we find such a candidate, forwarding is legal if every other potentially
/// aliasing op that may reach the transfer_read is post-dominated by that
/// transfer_write.
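///
/// Illustrative sketch (not taken from the tests; SSA names are made up): the
/// read below can be replaced by %v since the dominating write stores %v to
/// the same indices with the same vector type and no other write to %m can
/// reach the read:
///
///   vector.transfer_write %v, %m[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x4xf32>, memref<4x4xf32>
///   %r = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x4xf32>, vector<4x4xf32>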
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
if (read.hasOutOfBoundsDim())
return;
LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
<< "\n");
SmallVector<Operation *, 8> blockingWrites;
vector::TransferWriteOp lastwrite = nullptr;
Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
source.getUsers().end());
llvm::SmallDenseSet<Operation *, 32> processed;
while (!users.empty()) {
Operation *user = users.pop_back_val();
// If the user has already been processed, skip it.
if (!processed.insert(user).second)
continue;
if (isa<ViewLikeOpInterface>(user)) {
users.append(user->getUsers().begin(), user->getUsers().end());
continue;
}
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
continue;
if (auto write = dyn_cast<vector::TransferWriteOp>(user)) {
// If there is a write but we can prove that it is disjoint, we can ignore
// the write.
if (vector::isDisjointTransferSet(
cast<VectorTransferOpInterface>(write.getOperation()),
cast<VectorTransferOpInterface>(read.getOperation()),
/*testDynamicValueUsingBounds=*/true))
continue;
if (memref::isSameViewOrTrivialAlias(
cast<MemrefValue>(read.getBase()),
cast<MemrefValue>(write.getBase())) &&
dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
lastwrite = write;
else
assert(dominators.dominates(write, lastwrite));
continue;
}
}
blockingWrites.push_back(user);
}
if (lastwrite == nullptr)
return;
Region *topRegion = lastwrite->getParentRegion();
Operation *readAncestor = findAncestorOpInRegion(topRegion, read);
assert(readAncestor &&
"read op should be recursively part of the top region");
for (Operation *write : blockingWrites) {
Operation *writeAncestor = findAncestorOpInRegion(topRegion, write);
// TODO: if the store and read have the same ancestor we could recurse in
// the region to know if the read is reachable with more precision.
if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
continue;
if (!postDominators.postDominates(lastwrite, write)) {
LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
<< *write << "\n");
return;
}
}
LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
<< " to: " << *read.getOperation() << "\n");
read.replaceAllUsesWith(lastwrite.getVector());
opToErase.push_back(read.getOperation());
}
/// Converts OpFoldResult sizes to an int64_t shape, dropping unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
SmallVector<int64_t> reducedShape;
for (const auto size : mixedSizes) {
if (llvm::dyn_cast_if_present<Value>(size)) {
reducedShape.push_back(ShapedType::kDynamic);
continue;
}
auto value = cast<IntegerAttr>(cast<Attribute>(size)).getValue();
if (value == 1)
continue;
reducedShape.push_back(value.getSExtValue());
}
return reducedShape;
}
/// Drops unit dimensions from the input MemRefType.
static MemRefType dropUnitDims(MemRefType inputType,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
auto targetShape = getReducedShape(sizes);
MemRefType rankReducedType = memref::SubViewOp::inferRankReducedResultType(
targetShape, inputType, offsets, sizes, strides);
return rankReducedType.canonicalizeStridedLayout();
}
/// Creates a rank-reducing memref.subview op that drops unit dims from its
/// input, or returns the input unchanged if it has no unit dims.
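///
/// Illustrative sketch (hypothetical input, not from the tests): for a
/// `memref<1x4x1x8xf32>` input this would create roughly
///
///   memref.subview %input[0, 0, 0, 0] [1, 4, 1, 8] [1, 1, 1, 1]
///       : memref<1x4x1x8xf32> to memref<4x8xf32>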
static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
mlir::Location loc,
Value input) {
MemRefType inputType = cast<MemRefType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(),
rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
SmallVector<OpFoldResult> strides(inputType.getRank(),
rewriter.getIndexAttr(1));
MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
if (resultType.canonicalizeStridedLayout() ==
inputType.canonicalizeStridedLayout())
return input;
return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
sizes, strides);
}
/// Returns the number of dims that aren't unit dims.
static int getReducedRank(ArrayRef<int64_t> shape) {
return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
}
/// Trims non-scalable unit dimensions from `oldType` and returns the
/// resulting type.
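///
/// For example (illustrative), `vector<1x[1]x4x1xf32>` becomes
/// `vector<[1]x4xf32>`: the static unit dims are dropped while the scalable
/// unit dim `[1]` is preserved.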
static VectorType trimNonScalableUnitDims(VectorType oldType) {
SmallVector<int64_t> newShape;
SmallVector<bool> newScalableDims;
for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
continue;
newShape.push_back(dimSize);
newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
}
return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
}
/// Rewrites vector.create_mask `op` to drop non-scalable unit dimensions.
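///
/// Illustrative sketch (hypothetical values, not from the tests): assuming
/// %c1 = arith.constant 1 : index,
///
///   %mask = vector.create_mask %c1, %dim : vector<1x[4]xi1>
///
/// is rewritten to roughly
///
///   %mask = vector.create_mask %dim : vector<[4]xi1>
///
/// The rewrite fails if the operand of a dropped unit dim is not the
/// constant 1.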
static FailureOr<Value>
createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc,
vector::CreateMaskOp op) {
auto type = op.getType();
VectorType reducedType = trimNonScalableUnitDims(type);
if (reducedType.getRank() == type.getRank())
return failure();
SmallVector<Value> reducedOperands;
for (auto [dim, dimIsScalable, operand] : llvm::zip_equal(
type.getShape(), type.getScalableDims(), op.getOperands())) {
if (dim == 1 && !dimIsScalable) {
// If the mask for the unit dim is not a constant of 1, do nothing.
auto constant = operand.getDefiningOp<arith::ConstantIndexOp>();
if (!constant || (constant.value() != 1))
return failure();
continue;
}
reducedOperands.push_back(operand);
}
return rewriter
.create<vector::CreateMaskOp>(loc, reducedType, reducedOperands)
.getResult();
}
namespace {
/// Rewrites `vector.transfer_read` ops where the source has unit dims, by
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
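///
/// Illustrative sketch (not taken from the tests; SSA names are made up):
///
///   %v = vector.transfer_read %m[%c0, %c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true, true]}
///       : memref<1x4x1x8xf32>, vector<1x4x1x8xf32>
///
/// is rewritten to roughly:
///
///   %sv = memref.subview %m[0, 0, 0, 0] [1, 4, 1, 8] [1, 1, 1, 1]
///       : memref<1x4x1x8xf32> to memref<4x8xf32>
///   %r = vector.transfer_read %sv[%c0, %c0], %pad {in_bounds = [true, true]}
///       : memref<4x8xf32>, vector<4x8xf32>
///   %v = vector.shape_cast %r : vector<4x8xf32> to vector<1x4x1x8xf32>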
class TransferReadDropUnitDimsPattern
: public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;
FailureOr<Value>
matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
vector::MaskingOpInterface maskingOp,
PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferReadOp.getBase();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// TODO: support tensor types.
if (!sourceType)
return failure();
// TODO: generalize this pattern, relax the requirements here.
if (transferReadOp.hasOutOfBoundsDim())
return failure();
if (!transferReadOp.getPermutationMap().isMinorIdentity())
return failure();
// Check if the source shape can be further reduced.
int reducedRank = getReducedRank(sourceType.getShape());
if (reducedRank == sourceType.getRank())
return failure();
// TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
// out.
if (reducedRank == 0 && maskingOp)
return failure();
// Check if the reduced vector shape matches the reduced source shape.
// Otherwise, this case is not supported yet.
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
if (reducedRank != reducedVectorType.getRank())
return failure();
if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
return getConstantIntValue(v) != static_cast<int64_t>(0);
}))
return failure();
Value maskOp = transferReadOp.getMask();
if (maskOp) {
auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(
transferReadOp, "unsupported mask op, only 'vector.create_mask' is "
"currently supported");
FailureOr<Value> rankReducedCreateMask =
createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
if (failed(rankReducedCreateMask))
return failure();
maskOp = *rankReducedCreateMask;
}
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
transferReadOp.getPadding(), maskOp,
rewriter.getBoolArrayAttr(inBounds));
if (maskingOp) {
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
maskingOp.getMask());
newTransferReadOp = mlir::vector::maskOperation(
rewriter, newTransferReadOp, shapeCastMask);
}
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
loc, vectorType, newTransferReadOp->getResults()[0]);
return shapeCast;
}
};
/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination)
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
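///
/// Illustrative sketch (not taken from the tests; SSA names are made up):
///
///   vector.transfer_write %v, %m[%c0, %c0, %c0, %c0]
///       {in_bounds = [true, true, true, true]}
///       : vector<1x4x1x8xf32>, memref<1x4x1x8xf32>
///
/// is rewritten to roughly:
///
///   %sv = memref.subview %m[0, 0, 0, 0] [1, 4, 1, 8] [1, 1, 1, 1]
///       : memref<1x4x1x8xf32> to memref<4x8xf32>
///   %vc = vector.shape_cast %v : vector<1x4x1x8xf32> to vector<4x8xf32>
///   vector.transfer_write %vc, %sv[%c0, %c0] {in_bounds = [true, true]}
///       : vector<4x8xf32>, memref<4x8xf32>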
class TransferWriteDropUnitDimsPattern
: public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
using MaskableOpRewritePattern::MaskableOpRewritePattern;
FailureOr<Value>
matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
vector::MaskingOpInterface maskingOp,
PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
Value vector = transferWriteOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferWriteOp.getBase();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// TODO: support tensor type.
if (!sourceType)
return failure();
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
return failure();
if (!transferWriteOp.getPermutationMap().isMinorIdentity())
return failure();
// Check if the destination shape can be further reduced.
int reducedRank = getReducedRank(sourceType.getShape());
if (reducedRank == sourceType.getRank())
return failure();
// TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
// out.
if (reducedRank == 0 && maskingOp)
return failure();
// Check if the reduced vector shape matches the reduced destination shape.
// Otherwise, this case is not supported yet.
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
if (reducedRank != reducedVectorType.getRank())
return failure();
if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
return getConstantIntValue(v) != static_cast<int64_t>(0);
}))
return failure();
Value maskOp = transferWriteOp.getMask();
if (maskOp) {
auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
if (!createMaskOp)
return rewriter.notifyMatchFailure(
transferWriteOp,
"unsupported mask op, only 'vector.create_mask' is "
"currently supported");
FailureOr<Value> rankReducedCreateMask =
createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
if (failed(rankReducedCreateMask))
return failure();
maskOp = *rankReducedCreateMask;
}
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType, vector);
Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
maskOp, rewriter.getBoolArrayAttr(inBounds));
if (maskingOp) {
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
maskingOp.getMask());
newXferWrite =
mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
}
if (transferWriteOp.hasPureTensorSemantics())
return newXferWrite->getResults()[0];
// With memref semantics there is no return value; return an empty Value to
// signal success.
return Value();
}
};
} // namespace
/// Creates a memref.collapse_shape collapsing all inner dimensions of the
/// input starting at `firstDimToCollapse`.
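///
/// Illustrative sketch (hypothetical input): with `firstDimToCollapse = 1`,
/// a `memref<2x4x8xf32>` input yields roughly
///
///   memref.collapse_shape %input [[0], [1, 2]]
///       : memref<2x4x8xf32> into memref<2x32xf32>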
static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
Value input, int64_t firstDimToCollapse) {
ShapedType inputType = cast<ShapedType>(input.getType());
if (inputType.getRank() == 1)
return input;
SmallVector<ReassociationIndices> reassociation;
for (int64_t i = 0; i < firstDimToCollapse; ++i)
reassociation.push_back(ReassociationIndices{i});
ReassociationIndices collapsedIndices;
for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
collapsedIndices.push_back(i);
reassociation.push_back(collapsedIndices);
return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
}
/// Returns the new indices resulting from collapsing the inner dimensions,
/// starting from the `firstDimToCollapse` dimension.
static SmallVector<Value> getCollapsedIndices(RewriterBase &rewriter,
Location loc,
ArrayRef<int64_t> shape,
ValueRange indices,
int64_t firstDimToCollapse) {
assert(firstDimToCollapse < static_cast<int64_t>(indices.size()));
// If all the collapsed indices are zero then no extra logic is needed.
// Otherwise, a new offset/index has to be computed.
SmallVector<Value> indicesAfterCollapsing(
indices.begin(), indices.begin() + firstDimToCollapse);
SmallVector<Value> indicesToCollapse(indices.begin() + firstDimToCollapse,
indices.end());
if (llvm::all_of(indicesToCollapse, isZeroIndex)) {
indicesAfterCollapsing.push_back(indicesToCollapse[0]);
return indicesAfterCollapsing;
}
// Compute the remaining trailing index/offset required for reading from
// the collapsed memref:
//
// offset = 0
// for (i = firstDimToCollapse; i < outputRank; ++i)
// offset += sourceType.getDimSize(i) * transferReadOp.indices[i]
//
// For this example:
// %2 = vector.transfer_read/write %arg4[%c0, %arg0, %c0] (...) :
// memref<1x43x2xi32>, vector<1x2xi32>
// which would be collapsed to:
// %1 = vector.transfer_read/write %collapse_shape[%c0, %offset] (...) :
// memref<1x86xi32>, vector<2xi32>
// one would get the following offset:
// %offset = %arg0 * 43
OpFoldResult collapsedOffset =
rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
auto collapsedStrides = computeSuffixProduct(
ArrayRef<int64_t>(shape.begin() + firstDimToCollapse, shape.end()));
// Compute the collapsed offset.
auto &&[collapsedExpr, collapsedVals] =
computeLinearIndex(collapsedOffset, collapsedStrides, indicesToCollapse);
collapsedOffset = affine::makeComposedFoldedAffineApply(
rewriter, loc, collapsedExpr, collapsedVals);
if (auto value = dyn_cast<Value>(collapsedOffset)) {
indicesAfterCollapsing.push_back(value);
} else {
indicesAfterCollapsing.push_back(rewriter.create<arith::ConstantIndexOp>(
loc, *getConstantIntValue(collapsedOffset)));
}
return indicesAfterCollapsing;
}
namespace {
/// Rewrites contiguous row-major vector.transfer_read ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening only happens if the
/// bitwidth of the trailing vector dimension is smaller than the provided
/// bitwidth.
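///
/// Illustrative sketch (not taken from the tests; SSA names are made up):
///
///   %v = vector.transfer_read %m[%c0, %c0, %c0], %pad
///       {in_bounds = [true, true, true]}
///       : memref<2x4x8xf32>, vector<2x4x8xf32>
///
/// is rewritten to roughly:
///
///   %cs = memref.collapse_shape %m [[0, 1, 2]]
///       : memref<2x4x8xf32> into memref<64xf32>
///   %r = vector.transfer_read %cs[%c0], %pad {in_bounds = [true]}
///       : memref<64xf32>, vector<64xf32>
///   %v = vector.shape_cast %r : vector<64xf32> to vector<2x4x8xf32>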
class FlattenContiguousRowMajorTransferReadPattern
: public OpRewritePattern<vector::TransferReadOp> {
public:
FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
unsigned vectorBitwidth,
PatternBenefit benefit)
: OpRewritePattern<vector::TransferReadOp>(context, benefit),
targetVectorBitwidth(vectorBitwidth) {}
LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
auto source = transferReadOp.getBase();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// 0. Check pre-conditions
// The contiguity check below only applies to memrefs; bail out on tensors.
if (!sourceType)
return failure();
// If this is already 0D/1D, there's nothing to do.
if (vectorType.getRank() <= 1)
return failure();
if (!vectorType.getElementType().isSignlessIntOrFloat())
return failure();
unsigned trailingVectorDimBitwidth =
vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
if (trailingVectorDimBitwidth >= targetVectorBitwidth)
return failure();
if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
// TODO: generalize this pattern, relax the requirements here.
if (transferReadOp.hasOutOfBoundsDim())
return failure();
if (!transferReadOp.getPermutationMap().isMinorIdentity())
return failure();
if (transferReadOp.getMask())
return failure();
int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// 1. Collapse the source memref
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstDimToCollapse + 1);
// 2. Generate input args for a new vector.transfer_read that will read
// from the collapsed memref.
// 2.1. New dim exprs + affine map
SmallVector<AffineExpr, 1> dimExprs{
getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
auto collapsedMap =
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
// 2.2 New indices
SmallVector<Value> collapsedIndices =
getCollapsedIndices(rewriter, loc, sourceType.getShape(),
transferReadOp.getIndices(), firstDimToCollapse);
// 3. Create new vector.transfer_read that reads from the collapsed memref
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
// 4. Replace the old transfer_read with the new one reading from the
// collapsed shape
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
transferReadOp, cast<VectorType>(vector.getType()), flatRead);
return success();
}
private:
// Flattening is only applied when the bitwidth of the trailing vector
// dimension is below this threshold.
unsigned targetVectorBitwidth;
};
/// Rewrites contiguous row-major vector.transfer_write ops by inserting
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_write has a 1D source. Requires the source shape to be
/// already reduced, i.e. without unit dims.
///
/// If `targetVectorBitwidth` is provided, the flattening only happens if the
/// bitwidth of the trailing dimension of the vector being written is smaller
/// than the provided bitwidth.
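///
/// Illustrative sketch (not taken from the tests; SSA names are made up):
///
///   vector.transfer_write %v, %m[%c0, %c0, %c0]
///       {in_bounds = [true, true, true]}
///       : vector<2x4x8xf32>, memref<2x4x8xf32>
///
/// is rewritten to roughly:
///
///   %cs = memref.collapse_shape %m [[0, 1, 2]]
///       : memref<2x4x8xf32> into memref<64xf32>
///   %fv = vector.shape_cast %v : vector<2x4x8xf32> to vector<64xf32>
///   vector.transfer_write %fv, %cs[%c0] {in_bounds = [true]}
///       : vector<64xf32>, memref<64xf32>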
class FlattenContiguousRowMajorTransferWritePattern
: public OpRewritePattern<vector::TransferWriteOp> {
public:
FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
unsigned vectorBitwidth,
PatternBenefit benefit)
: OpRewritePattern<vector::TransferWriteOp>(context, benefit),
targetVectorBitwidth(vectorBitwidth) {}
LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
Value vector = transferWriteOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
Value source = transferWriteOp.getBase();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// 0. Check pre-conditions
// The contiguity check below only applies to memrefs; bail out on tensors.
if (!sourceType)
return failure();
// If this is already 0D/1D, there's nothing to do.
if (vectorType.getRank() <= 1)
return failure();
if (!vectorType.getElementType().isSignlessIntOrFloat())
return failure();
unsigned trailingVectorDimBitwidth =
vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
if (trailingVectorDimBitwidth >= targetVectorBitwidth)
return failure();
if (!vector::isContiguousSlice(sourceType, vectorType))
return failure();
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
return failure();
if (!transferWriteOp.getPermutationMap().isMinorIdentity())
return failure();
if (transferWriteOp.getMask())
return failure();
int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
// 1. Collapse the source memref
Value collapsedSource =
collapseInnerDims(rewriter, loc, source, firstDimToCollapse);
MemRefType collapsedSourceType =
cast<MemRefType>(collapsedSource.getType());
int64_t collapsedRank = collapsedSourceType.getRank();
assert(collapsedRank == firstDimToCollapse + 1);
// 2. Generate input args for a new vector.transfer_write that will write
// to the collapsed memref.
// 2.1. New dim exprs + affine map
SmallVector<AffineExpr, 1> dimExprs{
getAffineDimExpr(firstDimToCollapse, rewriter.getContext())};
auto collapsedMap =
AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
// 2.2 New indices
SmallVector<Value> collapsedIndices =
getCollapsedIndices(rewriter, loc, sourceType.getShape(),
transferWriteOp.getIndices(), firstDimToCollapse);
// 3. Create new vector.transfer_write that writes to the collapsed memref
VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
vectorType.getElementType());
Value flatVector =
rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
vector::TransferWriteOp flatWrite =
rewriter.create<vector::TransferWriteOp>(
loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
// 4. Replace the old transfer_write with the new one writing the
// collapsed shape
rewriter.eraseOp(transferWriteOp);
return success();
}
private:
// Flattening is only applied when the bitwidth of the trailing vector
// dimension is below this threshold.
unsigned targetVectorBitwidth;
};
/// Base class for `vector.extract/vector.extractelement(vector.transfer_read)`
/// to `memref.load` patterns. The `match` method is shared by both
/// `vector.extract` and `vector.extractelement`.
template <class VectorExtractOp>
class RewriteScalarExtractOfTransferReadBase
: public OpRewritePattern<VectorExtractOp> {
using Base = OpRewritePattern<VectorExtractOp>;
public:
RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
PatternBenefit benefit,
bool allowMultipleUses)
: Base(context, benefit), allowMultipleUses(allowMultipleUses) {}
LogicalResult match(VectorExtractOp extractOp) const {
auto xferOp =
extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
if (!xferOp)
return failure();
// Check that we are extracting a scalar and not a sub-vector.
if (isa<VectorType>(extractOp.getResult().getType()))
return failure();
// If multiple uses are not allowed, check if xfer has a single use.
if (!allowMultipleUses && !xferOp.getResult().hasOneUse())
return failure();
// If multiple uses are allowed, check if all the xfer uses are extract ops.
if (allowMultipleUses &&
!llvm::all_of(xferOp->getUses(), [](OpOperand &use) {
return isa<vector::ExtractOp, vector::ExtractElementOp>(
use.getOwner());
}))
return failure();
// Mask not supported.
if (xferOp.getMask())
return failure();
// Map not supported.
if (!xferOp.getPermutationMap().isMinorIdentity())
return failure();
// Cannot rewrite if the indices may be out of bounds.
if (xferOp.hasOutOfBoundsDim())
return failure();
return success();
}
private:
bool allowMultipleUses;
};
/// Rewrite `vector.extractelement(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
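///
/// Illustrative sketch (not taken from the tests; SSA names are made up):
///
///   %v = vector.transfer_read %m[%i], %pad {in_bounds = [true]}
///       : memref<?xf32>, vector<4xf32>
///   %e = vector.extractelement %v[%j : index] : vector<4xf32>
///
/// is rewritten to roughly:
///
///   %idx = affine.apply affine_map<()[s0, s1] -> (s0 + s1)>()[%i, %j]
///   %e = memref.load %m[%idx] : memref<?xf32>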
class RewriteScalarExtractElementOfTransferRead
: public RewriteScalarExtractOfTransferReadBase<vector::ExtractElementOp> {
using RewriteScalarExtractOfTransferReadBase::
RewriteScalarExtractOfTransferReadBase;
LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
PatternRewriter &rewriter) const override {
if (failed(match(extractOp)))
return failure();
// Construct scalar load.
auto loc = extractOp.getLoc();
auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
if (extractOp.getPosition()) {
AffineExpr sym0, sym1;
bindSymbols(extractOp.getContext(), sym0, sym1);
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, loc, sym0 + sym1,
{newIndices[newIndices.size() - 1], extractOp.getPosition()});
if (auto value = dyn_cast<Value>(ofr)) {
newIndices[newIndices.size() - 1] = value;
} else {
newIndices[newIndices.size() - 1] =
rewriter.create<arith::ConstantIndexOp>(loc,
*getConstantIntValue(ofr));
}
}
if (isa<MemRefType>(xferOp.getBase().getType())) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
newIndices);
} else {
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
extractOp, xferOp.getBase(), newIndices);
}
return success();
}
};
/// Rewrite `vector.extract(vector.transfer_read)` to `memref.load`.
///
/// All the users of the transfer op must be either `vector.extractelement` or
/// `vector.extract` ops. If `allowMultipleUses` is set to true, rewrite
/// transfer ops with any number of users. Otherwise, rewrite only if the
/// extract op is the single user of the transfer op. Rewriting a single
/// vector load with multiple scalar loads may negatively affect performance.
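///
/// Illustrative sketch (not taken from the tests; SSA names are made up):
///
///   %v = vector.transfer_read %m[%i, %j], %pad {in_bounds = [true, true]}
///       : memref<?x?xf32>, vector<2x4xf32>
///   %e = vector.extract %v[0, 1] : f32 from vector<2x4xf32>
///
/// is rewritten to roughly:
///
///   %j1 = affine.apply affine_map<()[s0] -> (s0 + 1)>()[%j]
///   %e = memref.load %m[%i, %j1] : memref<?x?xf32>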
class RewriteScalarExtractOfTransferRead
: public RewriteScalarExtractOfTransferReadBase<vector::ExtractOp> {
using RewriteScalarExtractOfTransferReadBase::
RewriteScalarExtractOfTransferReadBase;
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
if (failed(match(extractOp)))
return failure();
// Construct scalar load.
auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
SmallVector<Value> newIndices(xferOp.getIndices().begin(),
xferOp.getIndices().end());
for (auto [i, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
assert(isa<Attribute>(pos) && "Unexpected non-constant index");
int64_t offset = cast<IntegerAttr>(cast<Attribute>(pos)).getInt();
int64_t idx = newIndices.size() - extractOp.getNumIndices() + i;
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, extractOp.getLoc(),
rewriter.getAffineSymbolExpr(0) + offset, {newIndices[idx]});
if (auto value = dyn_cast<Value>(ofr)) {
newIndices[idx] = value;
} else {
newIndices[idx] = rewriter.create<arith::ConstantIndexOp>(
extractOp.getLoc(), *getConstantIntValue(ofr));
}
}
if (isa<MemRefType>(xferOp.getBase().getType())) {
rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
newIndices);
} else {
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
extractOp, xferOp.getBase(), newIndices);
}
return success();
}
};
/// Rewrites transfer_write ops of vectors with all-unit shape (e.g.,
/// vector<1x1xf32>) to memref.store (or tensor.insert for tensor
/// destinations).
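///
/// Illustrative sketch (not taken from the tests; SSA names are made up):
///
///   vector.transfer_write %v, %m[%i, %j] {in_bounds = [true, true]}
///       : vector<1x1xf32>, memref<?x?xf32>
///
/// is rewritten to roughly:
///
///   %s = vector.extract %v[0, 0] : f32 from vector<1x1xf32>
///   memref.store %s, %m[%i, %j] : memref<?x?xf32>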
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
// Must be a scalar write.
auto vecType = xferOp.getVectorType();
if (!llvm::all_of(vecType.getShape(), [](int64_t sz) { return sz == 1; }))
return failure();
// Mask not supported.
if (xferOp.getMask())
return failure();
// Map not supported.
if (!xferOp.getPermutationMap().isMinorIdentity())
return failure();
// Only float and integer element types are supported.
Value scalar =
rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
// Construct a scalar store.
if (isa<MemRefType>(xferOp.getBase().getType())) {
rewriter.replaceOpWithNewOp<memref::StoreOp>(
xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
} else {
rewriter.replaceOpWithNewOp<tensor::InsertOp>(
xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
}
return success();
}
};
} // namespace
void mlir::vector::transferOpflowOpt(RewriterBase &rewriter,
Operation *rootOp) {
TransferOptimization opt(rewriter, rootOp);
// Run store-to-load forwarding first since it can expose more dead store
// opportunities.
rootOp->walk([&](vector::TransferReadOp read) {
if (isa<MemRefType>(read.getShapedType()))
opt.storeToLoadForwarding(read);
});
opt.removeDeadOp();
rootOp->walk([&](vector::TransferWriteOp write) {
if (isa<MemRefType>(write.getShapedType()))
opt.deadStoreOp(write);
});
opt.removeDeadOp();
}
void mlir::vector::populateScalarVectorTransferLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit,
bool allowMultipleUses) {
patterns.add<RewriteScalarExtractElementOfTransferRead,
RewriteScalarExtractOfTransferRead>(patterns.getContext(),
benefit, allowMultipleUses);
patterns.add<RewriteScalarWrite>(patterns.getContext(), benefit);
}
void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns
.add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern>(
patterns.getContext(), benefit);
}
void mlir::vector::populateFlattenVectorTransferPatterns(
RewritePatternSet &patterns, unsigned targetVectorBitwidth,
PatternBenefit benefit) {
patterns.add<FlattenContiguousRowMajorTransferReadPattern,
FlattenContiguousRowMajorTransferWritePattern>(
patterns.getContext(), targetVectorBitwidth, benefit);
populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}
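// Usage sketch (illustrative, not part of this file): the populate functions
// above are typically combined with the greedy pattern driver, assuming the
// declared default arguments and the standard entry point
// `applyPatternsAndFoldGreedily`:
//
//   RewritePatternSet patterns(ctx);
//   vector::populateVectorTransferDropUnitDimsPatterns(patterns);
//   vector::populateFlattenVectorTransferPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));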