blob: b801692e9ecc993f4829a433feba5b61a91f5532 [file] [log] [blame]
//===- VectorTransferSplitRewritePatterns.cpp - Transfer Split Rewrites ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent patterns to rewrite a vector.transfer
// op into a fully in-bounds part and a partial part.
//
//===----------------------------------------------------------------------===//
#include <optional>
#include <type_traits>
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "vector-transfer-split"
using namespace mlir;
using namespace mlir::vector;
/// Build the condition under which `xferOp` accesses its source entirely
/// in-bounds.
///
/// Every transfer dimension that is not already marked in-bounds contributes
/// the check `index + vector_size <= memref_size`. A check whose two sides are
/// both constant and which holds is dropped; any other check is materialized
/// as an `arith.cmpi sle`. The materialized checks are and-ed together.
///
/// Returns a null Value when no check had to be materialized.
/// Asserts that the permutation map of `xferOp` is a minor identity.
static Value createInBoundsCond(RewriterBase &b,
                                VectorTransferOpInterface xferOp) {
  assert(xferOp.getPermutationMap().isMinorIdentity() &&
         "Expected minor identity map");
  Value inBoundsCond;
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    // Zip over the resulting vector shape and memref indices.
    // If the dimension is known to be in-bounds, it does not participate in
    // the construction of `inBoundsCond`.
    if (xferOp.isDimInBounds(resultIdx))
      return;
    // Fold or create the check that `index + vector_size` <= `memref_size`.
    Location loc = xferOp.getLoc();
    int64_t vectorSize = xferOp.getVectorType().getDimSize(resultIdx);
    OpFoldResult sum = affine::makeComposedFoldedAffineApply(
        b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
        {xferOp.getIndices()[indicesIdx]});
    OpFoldResult dimSz =
        memref::getMixedSize(b, loc, xferOp.getSource(), indicesIdx);
    auto maybeCstSum = getConstantIntValue(sum);
    auto maybeCstDimSz = getConstantIntValue(dimSz);
    if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
      return;
    // Materialize both operands in separate statements. C++ leaves the
    // evaluation order of function arguments unspecified, so creating them
    // inline in the `create` call would make the order of the emitted
    // constants depend on the host compiler.
    Value lhs = getValueOrCreateConstantIndexOp(b, loc, sum);
    Value rhs = getValueOrCreateConstantIndexOp(b, loc, dimSz);
    Value cond =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle, lhs, rhs);
    // Conjunction over all dims for which we are in-bounds.
    if (inBoundsCond)
      inBoundsCond = b.create<arith::AndIOp>(loc, inBoundsCond, cond);
    else
      inBoundsCond = cond;
  });
  return inBoundsCond;
}
/// Decide whether `xferOp` is a candidate for being split into an in-bounds
/// fast path and a slow path.
///
/// Returns success only when all of the following hold:
///   1. the transfer rank is non-zero (the 0-d case is not supported yet);
///   2. `xferOp.getPermutationMap()` is a minor identity map;
///   3. at least one transfer dimension is not marked in-bounds;
///   4. the direct parent of `xferOp` is not an `scf.if`, which keeps the
///      rewrite from being applied recursively to the ops it produces.
///
/// Note: the requirement that the source rank equals the vector rank is not
/// verified by this function.
static LogicalResult
splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) {
  // TODO: support 0-d corner case.
  const bool isZeroD = xferOp.getTransferRank() == 0;
  // TODO: expand support to non-minor-identity permutation maps.
  // A transfer without any out-of-bounds dimension has nothing to split.
  if (isZeroD || !xferOp.getPermutationMap().isMinorIdentity() ||
      !xferOp.hasOutOfBoundsDim())
    return failure();
  // Don't split transfer operations directly under IfOp, this avoids applying
  // the pattern recursively.
  // TODO: improve the filtering condition to make it more applicable.
  return success(!isa<scf::IfOp>(xferOp->getParentOp()));
}
/// Compute a MemRefType that both `aT` and `bT` can be cast to.
///
/// Returns a null type when the ranks differ or when strides and offset cannot
/// be extracted from either type. Otherwise:
///   1. if `aT` and `bT` are already cast-compatible, `aT` is returned as is;
///   2. else a new strided MemRefType with the element type of `aT` is built
///      in which every size, stride and the offset is
///        a. kept when it is equal in `aT` and `bT`,
///        b. made dynamic when the two types disagree.
static MemRefType getCastCompatibleMemRefType(MemRefType aT, MemRefType bT) {
  if (memref::CastOp::areCastCompatible(aT, bT))
    return aT;

  const int64_t rank = aT.getRank();
  if (rank != bT.getRank())
    return MemRefType();

  int64_t offsetA, offsetB;
  SmallVector<int64_t, 4> stridesA, stridesB;
  if (failed(aT.getStridesAndOffset(stridesA, offsetA)) ||
      failed(bT.getStridesAndOffset(stridesB, offsetB)) ||
      stridesA.size() != stridesB.size())
    return MemRefType();

  // Keep a value only when both types agree on it, else fall back to dynamic.
  auto keepIfEqual = [](int64_t lhs, int64_t rhs) -> int64_t {
    return lhs == rhs ? lhs : ShapedType::kDynamic;
  };

  ArrayRef<int64_t> shapeA = aT.getShape();
  ArrayRef<int64_t> shapeB = bT.getShape();
  SmallVector<int64_t, 4> mergedShape, mergedStrides;
  mergedShape.reserve(rank);
  mergedStrides.reserve(rank);
  for (int64_t dim = 0; dim < rank; ++dim) {
    mergedShape.push_back(keepIfEqual(shapeA[dim], shapeB[dim]));
    mergedStrides.push_back(keepIfEqual(stridesA[dim], stridesB[dim]));
  }
  const int64_t mergedOffset = keepIfEqual(offsetA, offsetB);

  auto layout =
      StridedLayoutAttr::get(aT.getContext(), mergedOffset, mergedStrides);
  return MemRefType::get(mergedShape, aT.getElementType(), layout);
}
/// Cast `memref` to `compatibleMemRefType`.
///
/// When the memory space of `memref` differs from the one of the target type,
/// a `memref.memory_space_cast` that only changes the memory space is emitted
/// first. A `memref.cast` follows unless the type already matches the target
/// at that point. Returns the resulting value, which is `memref` itself when
/// no cast was needed.
static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
                                        MemRefType compatibleMemRefType) {
  Location loc = memref.getLoc();
  auto currentType = cast<MemRefType>(memref.getType());
  Value result = memref;

  // Step 1: align the memory space, keeping shape, element type and layout.
  if (currentType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
    currentType = MemRefType::get(
        currentType.getShape(), currentType.getElementType(),
        currentType.getLayout(), compatibleMemRefType.getMemorySpace());
    result = b.create<memref::MemorySpaceCastOp>(loc, currentType, result);
  }

  // Step 2: align everything else, if anything is left to align.
  if (currentType != compatibleMemRefType)
    result = b.create<memref::CastOp>(loc, compatibleMemRefType, result);
  return result;
}
/// Build the intersection between the view `xferOp.getSource()` @
/// `xferOp.getIndices()` and the view `alloc`, as a pair of `memref.subview`
/// ops suitable as operands of a copy.
///
/// Per transfer dimension the size of both sub-views is
///   `affine.min(dim(memref) - index, dim(alloc))`.
/// The sub-view of the original memref always starts at the transfer indices
/// and the sub-view of `alloc` always starts at `[0...0]`.
///
/// Returns `(copySrc, copyDest)`:
///   * for a vector.transfer_write: (sub-view of `alloc`, sub-view of memref);
///   * otherwise (read):            (sub-view of memref, sub-view of `alloc`).
///
/// Asserts that the rank of the source equals the rank of `alloc`.
// TODO: view intersection/union/differences should be a proper std op.
static std::pair<Value, Value>
createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
                          Value alloc) {
  Location loc = xferOp.getLoc();
  int64_t memrefRank = xferOp.getShapedType().getRank();
  // TODO: relax this precondition, will require rank-reducing subviews.
  assert(memrefRank == cast<MemRefType>(alloc.getType()).getRank() &&
         "Expected memref rank to match the alloc rank");
  // NOTE(review): the leading indices are appended to `sizes`; presumably this
  // range is empty whenever the rank assertion above holds - confirm before
  // relaxing the precondition.
  ValueRange leadingIndices =
      xferOp.getIndices().take_front(xferOp.getLeadingShapedRank());
  SmallVector<OpFoldResult, 4> sizes;
  sizes.append(leadingIndices.begin(), leadingIndices.end());
  xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    Value dimMemRef =
        b.create<memref::DimOp>(loc, xferOp.getSource(), indicesIdx);
    Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
    Value index = xferOp.getIndices()[indicesIdx];
    AffineExpr i, j, k;
    bindDims(xferOp.getContext(), i, j, k);
    SmallVector<AffineMap, 4> maps =
        AffineMap::inferFromExprList(MapList{{i - j, k}}, b.getContext());
    // affine_min(%dimMemRef - %index, %dimAlloc)
    Value affineMin = b.create<affine::AffineMinOp>(
        loc, index.getType(), maps[0], ValueRange{dimMemRef, index, dimAlloc});
    sizes.push_back(affineMin);
  });

  SmallVector<OpFoldResult> xferOffsets = llvm::to_vector<4>(llvm::map_range(
      xferOp.getIndices(), [](Value idx) -> OpFoldResult { return idx; }));
  SmallVector<OpFoldResult> zeroOffsets(memrefRank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));

  // The transfer indices address the original memref, while the temporary
  // buffer holds the vector starting at the origin. This holds for reads and
  // writes alike; only the direction of the copy differs. Using the transfer
  // indices on `alloc` (or zeros on the memref) for a write would copy from
  // and to the wrong region.
  auto subViewOfMemRef = [&]() -> Value {
    return b.create<memref::SubViewOp>(loc, xferOp.getSource(), xferOffsets,
                                       sizes, strides);
  };
  auto subViewOfAlloc = [&]() -> Value {
    return b.create<memref::SubViewOp>(loc, alloc, zeroOffsets, sizes, strides);
  };

  // The source sub-view is always created before the destination sub-view.
  if (isa<vector::TransferWriteOp>(xferOp)) {
    Value copySrc = subViewOfAlloc();
    Value copyDest = subViewOfMemRef();
    return std::make_pair(copySrc, copyDest);
  }
  Value copySrc = subViewOfMemRef();
  Value copyDest = subViewOfAlloc();
  return std::make_pair(copySrc, copyDest);
}
/// Read flavor of the fill + copy based split. Expects that:
///   1. `inBoundsCond` and `compatibleMemRefType` have been computed;
///   2. `alloc` is a memref holding a single vector.
/// Produces IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = linalg.fill(%pad, %alloc)
///      %3 = subview %view [...][...][...]
///      %4 = subview %alloc [0, 0] [...] [...]
///      memref.copy(%3, %4)
///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %5, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Returns the created scf::IfOp. `returnTypes` is currently unused.
static scf::IfOp
createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                            TypeRange returnTypes, Value inBoundsCond,
                            MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();

  // Fast path: yield the original memref together with the original indices.
  auto buildFastPath = [&](OpBuilder &nested, Location nestedLoc) {
    scf::ValueVector yielded{
        castToCompatibleMemRefType(nested, memref, compatibleMemRefType)};
    llvm::append_range(yielded, xferOp.getIndices());
    nested.create<scf::YieldOp>(nestedLoc, yielded);
  };

  // Slow path: pad the whole buffer, copy the part of the memref that exists
  // into it, then yield the buffer addressed from zero.
  auto buildSlowPath = [&](OpBuilder &nested, Location nestedLoc) {
    nested.create<linalg::FillOp>(nestedLoc, ValueRange{xferOp.getPadding()},
                                  ValueRange{alloc});
    // Take partial subview of memref which guarantees no dimension
    // overflows.
    IRRewriter rewriter(nested);
    auto asInterface = cast<VectorTransferOpInterface>(xferOp.getOperation());
    std::pair<Value, Value> srcAndDest =
        createSubViewIntersection(rewriter, asInterface, alloc);
    nested.create<memref::CopyOp>(nestedLoc, srcAndDest.first,
                                  srcAndDest.second);
    Value castedAlloc =
        castToCompatibleMemRefType(nested, alloc, compatibleMemRefType);
    scf::ValueVector yielded{castedAlloc};
    yielded.insert(yielded.end(), xferOp.getTransferRank(), zero);
    nested.create<scf::YieldOp>(nestedLoc, yielded);
  };

  return b.create<scf::IfOp>(loc, inBoundsCond, buildFastPath, buildSlowPath);
}
/// Given an `xferOp` for which:
///   1. `inBoundsCond` and a `compatibleMemRefType` have been computed.
///   2. a memref of single vector `alloc` has been allocated.
/// Produce IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      (memref.memory_space_cast %A: memref<A..., addr_space> to memref<A...>)
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %2 = vector.transfer_read %view[...], %pad : memref<A...>, vector<...>
///      %3 = vector.type_cast %alloc :
///        memref<...> to memref<vector<...>>
///      store %2, %3[] : memref<vector<...>>
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// In the slow path the original transfer is cloned unchanged and its result
/// is stored into `alloc`, which is then yielded with all-zero indices.
/// Return the produced scf::IfOp. `returnTypes` is currently unused.
static scf::IfOp createFullPartialVectorTransferRead(
    RewriterBase &b, vector::TransferReadOp xferOp, TypeRange returnTypes,
    Value inBoundsCond, MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();
  return b.create<scf::IfOp>(
      loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
        // Fast path: original memref, original indices.
        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
        scf::ValueVector viewAndIndices{res};
        llvm::append_range(viewAndIndices, xferOp.getIndices());
        b.create<scf::YieldOp>(loc, viewAndIndices);
      },
      [&](OpBuilder &b, Location loc) {
        // Slow path: perform the original (possibly out-of-bounds) read and
        // park the resulting vector in the temporary buffer.
        Operation *newXfer = b.clone(*xferOp.getOperation());
        Value vector = cast<VectorTransferOpInterface>(newXfer).getVector();
        b.create<memref::StoreOp>(
            loc, vector,
            b.create<vector::TypeCastOp>(
                loc, MemRefType::get({}, vector.getType()), alloc));
        Value casted =
            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
        scf::ValueVector viewAndIndices{casted};
        viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                              zero);
        b.create<scf::YieldOp>(loc, viewAndIndices);
      });
}
/// Select where the full vector of a transfer_write is to be written. Expects
/// that:
///   1. `inBoundsCond` and `compatibleMemRefType` have been computed;
///   2. `alloc` is a memref holding a single vector.
/// Produces IR resembling:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view, ... : compatibleMemRefType, index, index
///    } else {
///      %4 = memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4, ... : compatibleMemRefType, index, index
///    }
/// ```
/// Returns the results of the created scf.if: the memref to write to followed
/// by the indices to write at. `returnTypes` is currently unused.
static ValueRange
getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                          TypeRange returnTypes, Value inBoundsCond,
                          MemRefType compatibleMemRefType, Value alloc) {
  Location loc = xferOp.getLoc();
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
  Value memref = xferOp.getSource();

  // In-bounds: write straight to the original memref at the original indices.
  auto yieldOriginal = [&](OpBuilder &nested, Location nestedLoc) {
    scf::ValueVector yielded{
        castToCompatibleMemRefType(nested, memref, compatibleMemRefType)};
    llvm::append_range(yielded, xferOp.getIndices());
    nested.create<scf::YieldOp>(nestedLoc, yielded);
  };

  // Otherwise: write to the temporary buffer, addressed from zero.
  auto yieldBuffer = [&](OpBuilder &nested, Location nestedLoc) {
    scf::ValueVector yielded{
        castToCompatibleMemRefType(nested, alloc, compatibleMemRefType)};
    yielded.insert(yielded.end(), xferOp.getTransferRank(), zero);
    nested.create<scf::YieldOp>(nestedLoc, yielded);
  };

  auto selectOp =
      b.create<scf::IfOp>(loc, inBoundsCond, yieldOriginal, yieldBuffer);
  return selectOp->getResults();
}
/// Write flavor of the copy based split. Expects that:
///   1. `inBoundsCond` has been computed;
///   2. `alloc` is a memref holding a single vector;
///   3. `xferOp` originally wrote to %view.
/// Produces IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %3 = subview %alloc [...][...][...]
///      %4 = subview %view [...][...][...]
///      memref.copy(%3, %4)
///    }
/// ```
/// The two sub-views are the (source, destination) pair computed by
/// `createSubViewIntersection`.
static void createFullPartialLinalgCopy(RewriterBase &b,
                                        vector::TransferWriteOp xferOp,
                                        Value inBoundsCond, Value alloc) {
  Location loc = xferOp.getLoc();
  // Negate the condition: the copy only runs on the slow path.
  Value allOnes = b.create<arith::ConstantIntOp>(loc, true, 1);
  Value notInBounds = b.create<arith::XOrIOp>(loc, inBoundsCond, allOnes);

  auto buildCopy = [&](OpBuilder &nested, Location nestedLoc) {
    IRRewriter rewriter(nested);
    auto asInterface = cast<VectorTransferOpInterface>(xferOp.getOperation());
    std::pair<Value, Value> srcAndDest =
        createSubViewIntersection(rewriter, asInterface, alloc);
    nested.create<memref::CopyOp>(nestedLoc, srcAndDest.first,
                                  srcAndDest.second);
    nested.create<scf::YieldOp>(nestedLoc, ValueRange{});
  };
  b.create<scf::IfOp>(loc, notInBounds, buildCopy);
}
/// Write flavor of the vector.transfer based split. Expects that:
///   1. `inBoundsCond` has been computed;
///   2. `alloc` is a memref holding a single vector;
///   3. `xferOp` originally wrote to %view.
/// Produces IR resembling:
/// ```
///    %notInBounds = arith.xori %inBounds, %true
///    scf.if (%notInBounds) {
///      %2 = load %alloc : memref<vector<...>>
///      vector.transfer_write %2, %view[...] : memref<A...>, vector<...>
///    }
/// ```
/// The transfer_write inside the conditional is a clone of `xferOp` whose
/// vector operand is replaced by the vector loaded from `alloc`.
static void createFullPartialVectorTransferWrite(RewriterBase &b,
                                                 vector::TransferWriteOp xferOp,
                                                 Value inBoundsCond,
                                                 Value alloc) {
  Location loc = xferOp.getLoc();
  // Negate the condition: the extra write only runs on the slow path.
  Value allOnes = b.create<arith::ConstantIntOp>(loc, true, 1);
  Value notInBounds = b.create<arith::XOrIOp>(loc, inBoundsCond, allOnes);

  auto buildWrite = [&](OpBuilder &nested, Location nestedLoc) {
    // View the buffer as a 0-d memref of one vector and load that vector.
    auto vectorMemRefType =
        MemRefType::get({}, xferOp.getVector().getType());
    Value vectorView =
        nested.create<vector::TypeCastOp>(nestedLoc, vectorMemRefType, alloc);
    Value loaded =
        nested.create<memref::LoadOp>(nestedLoc, vectorView, ValueRange());
    // Re-issue the original write with the loaded vector as its payload.
    IRMapping remap;
    remap.map(xferOp.getVector(), loaded);
    nested.clone(*xferOp.getOperation(), remap);
    nested.create<scf::YieldOp>(nestedLoc, ValueRange{});
  };
  b.create<scf::IfOp>(loc, notInBounds, buildWrite);
}
/// Find the operation in whose region the temporary buffer for `op` is to be
/// allocated.
///
/// Walks the ancestors of `op` from the inside out, remembering the last
/// visited ancestor that has the `AutomaticAllocationScope` trait. The walk
/// continues only through `scf.for` and `affine.for` ancestors and stops at
/// the first ancestor of any other kind (after inspecting it). Asserts that a
/// scope was found.
// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  // Find the closest surrounding allocation scope that is not a known looping
  // construct (putting alloca's in loops doesn't always lower to deallocation
  // until the end of the loop).
  Operation *scope = nullptr;
  Operation *ancestor = op->getParentOp();
  while (ancestor) {
    if (ancestor->hasTrait<OpTrait::AutomaticAllocationScope>())
      scope = ancestor;
    const bool isKnownLoop = isa<scf::ForOp, affine::AffineForOp>(ancestor);
    if (!isKnownLoop)
      break;
    ancestor = ancestor->getParentOp();
  }
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}
/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
/// masking) fastpath and a slowpath.
///
/// Depending on `options.vectorTransferSplit`:
///   * `None`: nothing is done and failure is returned;
///   * `ForceInBounds`: the op is marked fully in-bounds in place and success
///     is returned, no split is performed;
///   * otherwise the split described below is performed, the slow path using
///     either vector.transfer ops (`VectorTransfer`) or fill/copy ops.
///
/// Failure is returned, without modifying the IR, when the op is neither a
/// transfer_read nor a transfer_write, when it is masked, when its source is
/// not a memref, when no memref type compatible with both the source and the
/// temporary buffer exists, or when the op turns out to be statically
/// in-bounds.
///
/// For vector.transfer_read:
/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the
/// newly created conditional upon function return.
/// To accommodate for the fact that the original vector.transfer indexing may
/// be arbitrary and the slow path indexes @[0...0] in the temporary buffer,
/// the scf.if op returns a view and values of type index.
///
/// Example (a 2-D vector.transfer_read):
/// ```
///    %1 = vector.transfer_read %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      // fastpath, direct cast
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      // slowpath, not in-bounds vector.transfer or fill + copy.
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]}
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the start of the
/// enclosing automatic allocation scope.
///
/// For vector.transfer_write:
/// There are 2 conditional blocks. First a block to decide which memref and
/// indices to use for an unmasked, inbounds write. Then a conditional block to
/// further copy a partial buffer into the final result in the slow path case.
///
/// Example (a 2-D vector.transfer_write):
/// ```
///    vector.transfer_write %arg, %0[...], %pad : memref<A...>, vector<...>
/// ```
/// is transformed into:
/// ```
///    %1:3 = scf.if (%inBounds) {
///      memref.cast %A: memref<A...> to compatibleMemRefType
///      scf.yield %view : compatibleMemRefType, index, index
///    } else {
///      memref.cast %alloc: memref<B...> to compatibleMemRefType
///      scf.yield %4 : compatibleMemRefType, index, index
///    }
///    %0 = vector.transfer_write %arg, %1#0[%1#1, %1#2] {in_bounds = [true ...
///                                                                    true]}
///    scf.if (%notInBounds) {
///      // slowpath: not in-bounds vector.transfer or copy.
///    }
/// ```
/// where `alloc` is a buffer of one vector, alloca'ed at the start of the
/// enclosing automatic allocation scope.
///
/// Preconditions:
///  1. `xferOp.getPermutationMap()` must be a minor identity map
///  2. the rank of the `xferOp.getSource()` and the rank of the
///     `xferOp.getVector()` must be equal. This will be relaxed in the future
///     but requires rank-reducing subviews.
LogicalResult mlir::vector::splitFullAndPartialTransfer(
    RewriterBase &b, VectorTransferOpInterface xferOp,
    VectorTransformsOptions options, scf::IfOp *ifOp) {
  if (options.vectorTransferSplit == VectorTransferSplit::None)
    return failure();

  SmallVector<bool, 4> bools(xferOp.getTransferRank(), true);
  auto inBoundsAttr = b.getBoolArrayAttr(bools);
  if (options.vectorTransferSplit == VectorTransferSplit::ForceInBounds) {
    b.modifyOpInPlace(xferOp, [&]() {
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });
    return success();
  }

  // Assert preconditions. Additionally, keep the variables in an inner scope to
  // ensure they aren't used in the wrong scopes further down.
  {
    assert(succeeded(splitFullAndPartialTransferPrecondition(xferOp)) &&
           "Expected splitFullAndPartialTransferPrecondition to hold");

    auto xferReadOp = dyn_cast<vector::TransferReadOp>(xferOp.getOperation());
    auto xferWriteOp = dyn_cast<vector::TransferWriteOp>(xferOp.getOperation());

    if (!(xferReadOp || xferWriteOp))
      return failure();
    if (xferWriteOp && xferWriteOp.getMask())
      return failure();
    if (xferReadOp && xferReadOp.getMask())
      return failure();
  }

  // Run every check that only depends on types before creating any IR, so
  // that a failure below never leaves stray operations behind. The temporary
  // buffer has the shape and element type of the transferred vector.
  auto sourceType = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!sourceType)
    return failure();
  MemRefType allocType =
      MemRefType::get(xferOp.getVectorType().getShape(),
                      xferOp.getVectorType().getElementType());
  MemRefType compatibleMemRefType =
      getCastCompatibleMemRefType(sourceType, allocType);
  if (!compatibleMemRefType)
    return failure();

  RewriterBase::InsertionGuard guard(b);
  b.setInsertionPoint(xferOp);
  // A null condition means that nothing had to be materialized, i.e. the
  // transfer is statically in-bounds and there is nothing to split.
  Value inBoundsCond = createInBoundsCond(b, xferOp);
  if (!inBoundsCond)
    return failure();

  // Top of the function `alloc` for transient storage.
  Value alloc;
  {
    RewriterBase::InsertionGuard allocGuard(b);
    Operation *scope = getAutomaticAllocationScope(xferOp);
    assert(scope->getNumRegions() == 1 &&
           "AutomaticAllocationScope with >1 regions");
    b.setInsertionPointToStart(&scope->getRegion(0).front());
    alloc = b.create<memref::AllocaOp>(scope->getLoc(), allocType,
                                       ValueRange{}, b.getI64IntegerAttr(32));
  }

  SmallVector<Type, 4> returnTypes(1 + xferOp.getTransferRank(),
                                   b.getIndexType());
  returnTypes[0] = compatibleMemRefType;

  if (auto xferReadOp =
          dyn_cast<vector::TransferReadOp>(xferOp.getOperation())) {
    // Read case: full fill + partial copy -> in-bounds vector.xfer_read.
    scf::IfOp fullPartialIfOp =
        options.vectorTransferSplit == VectorTransferSplit::VectorTransfer
            ? createFullPartialVectorTransferRead(b, xferReadOp, returnTypes,
                                                  inBoundsCond,
                                                  compatibleMemRefType, alloc)
            : createFullPartialLinalgCopy(b, xferReadOp, returnTypes,
                                          inBoundsCond, compatibleMemRefType,
                                          alloc);
    if (ifOp)
      *ifOp = fullPartialIfOp;

    // Redirect the existing read to the memref and indices selected by the
    // conditional and mark it in-bounds: it always reads from a full buffer.
    // Both mutations go through the rewriter so that listeners are notified
    // of the operand changes as well as of the attribute change.
    b.modifyOpInPlace(xferOp, [&]() {
      for (unsigned i = 0, e = returnTypes.size(); i != e; ++i)
        xferReadOp.setOperand(i, fullPartialIfOp.getResult(i));
      xferOp->setAttr(xferOp.getInBoundsAttrName(), inBoundsAttr);
    });

    return success();
  }

  auto xferWriteOp = cast<vector::TransferWriteOp>(xferOp.getOperation());

  // Decide which location to write the entire vector to.
  auto memrefAndIndices = getLocationToWriteFullVec(
      b, xferWriteOp, returnTypes, inBoundsCond, compatibleMemRefType, alloc);

  // Do an in bounds write to either the output or the extra allocated buffer.
  // The operation is cloned to prevent deleting information needed for the
  // later IR creation.
  IRMapping mapping;
  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
  mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
  auto *clone = b.clone(*xferWriteOp, mapping);
  clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

  // Create a potential copy from the allocated buffer to the final output in
  // the slow path case.
  if (options.vectorTransferSplit == VectorTransferSplit::VectorTransfer)
    createFullPartialVectorTransferWrite(b, xferWriteOp, inBoundsCond, alloc);
  else
    createFullPartialLinalgCopy(b, xferWriteOp, inBoundsCond, alloc);

  b.eraseOp(xferOp);

  return success();
}
namespace {
/// Pattern wrapper around `splitFullAndPartialTransfer`. It matches any
/// operation type; an optional filter callback allows selecting the transfer
/// ops to rewrite at a finer granularity.
struct VectorTransferFullPartialRewriter : public RewritePattern {
  /// Callback deciding whether a given transfer op may be rewritten.
  using FilterConstraintType =
      std::function<LogicalResult(VectorTransferOpInterface op)>;

  /// By default all options take their default value, the filter accepts
  /// every op and the pattern benefit is 1.
  explicit VectorTransferFullPartialRewriter(
      MLIRContext *ctx,
      VectorTransformsOptions splitOptions = VectorTransformsOptions(),
      FilterConstraintType constraint =
          [](VectorTransferOpInterface) -> LogicalResult { return success(); },
      PatternBenefit patternBenefit = 1)
      : RewritePattern(MatchAnyOpTypeTag(), patternBenefit, ctx),
        options(splitOptions), filter(std::move(constraint)) {}

  /// Performs the rewrite.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override;

private:
  /// Options forwarded to `splitFullAndPartialTransfer`.
  VectorTransformsOptions options;
  /// User-provided selection callback.
  FilterConstraintType filter;
};
} // namespace
/// Apply the split to `op` when it is a vector transfer op that satisfies the
/// split precondition and is accepted by the filter; fail to match otherwise.
LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  // Only ops implementing the vector transfer interface are of interest.
  auto transfer = dyn_cast<VectorTransferOpInterface>(op);
  if (!transfer)
    return failure();
  // Structural requirements first, the user-provided filter second.
  if (failed(splitFullAndPartialTransferPrecondition(transfer)))
    return failure();
  if (failed(filter(transfer)))
    return failure();
  return splitFullAndPartialTransfer(rewriter, transfer, options);
}
/// Add the full/partial transfer split pattern, configured with `options`, to
/// `patterns`.
void mlir::vector::populateVectorTransferFullPartialPatterns(
    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
}