// blob: 166f09eaa8762bcddbb8785a2f4924de149a6a71 [file] [log] [blame]
//===- VectorInsertExtractStridedSliceRewritePatterns.cpp - Rewrites ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;
using namespace mlir::vector;
// Helper that inserts `from` into `into` at position `offset`, choosing the op
// from the rank of `into`: InsertOp when the rank is greater than one,
// otherwise InsertElementOp with `offset` materialized as a constant index.
static Value insertOne(PatternRewriter &rewriter, Location loc, Value from,
                       Value into, int64_t offset) {
  auto destType = into.getType().cast<VectorType>();
  // Rank <= 1: element-wise insertion driven by an index constant.
  if (destType.getRank() <= 1) {
    Value position = rewriter.create<arith::ConstantIndexOp>(loc, offset);
    return rewriter.create<vector::InsertElementOp>(loc, destType, from, into,
                                                    position);
  }
  return rewriter.create<InsertOp>(loc, from, into, offset);
}
// Helper that extracts what sits at position `offset` of `vector`, choosing
// the op from the rank of `vector`: ExtractOp when the rank is greater than
// one, otherwise ExtractElementOp with `offset` materialized as a constant
// index (yielding a value of the vector's element type).
static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector,
                        int64_t offset) {
  auto sourceType = vector.getType().cast<VectorType>();
  // Rank <= 1: element-wise extraction driven by an index constant.
  if (sourceType.getRank() <= 1) {
    Value position = rewriter.create<arith::ConstantIndexOp>(loc, offset);
    return rewriter.create<vector::ExtractElementOp>(
        loc, sourceType.getElementType(), vector, position);
  }
  return rewriter.create<ExtractOp>(loc, vector, offset);
}
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have different ranks.
///
/// When ranks are different, InsertStridedSlice needs to extract a properly
/// ranked vector from the destination vector into which to insert. This pattern
/// only takes care of this extraction part and forwards the rest to
/// [VectorInsertStridedSliceOpSameRankRewritePattern].
///
/// For a k-D source and n-D destination vector (k < n), we emit:
///   1. ExtractOp, positioned by the leading (n-k) offsets, to extract the
///      (unique) k-D subvector of the destination into which to insert the
///      k-D source.
///   2. k-D -> k-D InsertStridedSlice op using the remaining k offsets and all
///      the strides.
///   3. InsertOp that is the reverse of 1.
///
/// The pattern does not apply (returns failure) when the op has no offsets or
/// when both ranks already match.
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    if (op.offsets().getValue().empty())
      return failure();

    int64_t srcRank = op.getSourceVectorType().getRank();
    int64_t dstRank = op.getDestVectorType().getRank();
    int64_t rankDiff = dstRank - srcRank;
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    auto loc = op.getLoc();

    // Number of trailing offsets that belong to the matching-rank insertion;
    // dropping them leaves the position of the subvector in the destination.
    int64_t rankRest = dstRank - rankDiff;
    auto leadingOffsets = getI64SubArray(op.offsets(), /*dropFront=*/0,
                                         /*dropBack=*/rankRest);

    // Pull out the destination subvector whose rank matches the source.
    Value destSlice = rewriter.create<ExtractOp>(loc, op.dest(), leadingOffsets);

    // A different pattern will kick in for this InsertStridedSlice since its
    // operands have matching ranks.
    auto sameRankInsert = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), destSlice,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));

    // Put the updated subvector back where it was extracted from.
    rewriter.replaceOpWithNewOp<InsertOp>(op, sameRankInsert.getResult(),
                                          op.dest(), leadingOffsets);
    return success();
  }
};
/// RewritePattern for InsertStridedSliceOp where source and destination vectors
/// have the same rank.
///
/// If source and destination have the same type, the op is simply replaced by
/// its source.
///
/// 1-D case: two ShuffleOps are emitted.
///   1. The source is shuffled with itself into a vector that has as many
///      elements as the destination (extra positions repeat element 0).
///   2. That vector is shuffled with the destination, taking the source element
///      at every position covered by the slice and the destination element
///      everywhere else.
///
/// n-D case (n > 1): for each outermost index in the slice
///       begin    end             stride
///   [offset : offset+size*stride : stride]
///   1. Extract one (n-1)-D source subvector and one (n-1)-D dest subvector.
///   2. InsertStridedSlice the (n-1)-D source subvector into the (n-1)-D dest
///      subvector.
///   3. Insert the result back in the proper place of the destination (the
///      reverse of 1.).
///
/// The pattern does not apply (returns failure) when the op has no offsets or
/// when the ranks differ.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  void initialize() {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion is
    // bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t srcRank = srcType.getRank();
    int64_t dstRank = dstType.getRank();
    assert(dstRank >= srcRank);
    // Only equal ranks are handled here.
    if (dstRank != srcRank)
      return failure();

    // Identical types: the op folds to its source.
    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    // Slice description along the outermost dimension.
    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();
    int64_t end = offset + size * stride;

    auto loc = op.getLoc();

    if (srcRank == 1) {
      // Keep the extents as int64_t (the type of the shape entries) so no
      // implicit narrowing happens.
      int64_t nDest = dstType.getShape().front();

      // 1. Scale source to destType so we can shufflevector them together.
      SmallVector<int64_t> widenMask(nDest, 0);
      for (int64_t i = 0; i < size; ++i)
        widenMask[i] = i;
      Value scaledSource =
          rewriter.create<ShuffleOp>(loc, op.source(), op.source(), widenMask);

      // 2. Create a mask where we take the value from scaledSource or dest
      // depending on the offset. Mask values below nDest pick from
      // scaledSource, values from nDest upwards pick from dest.
      SmallVector<int64_t> selectMask;
      selectMask.reserve(nDest);
      for (int64_t i = 0; i < nDest; ++i) {
        bool fromSource =
            i >= offset && i < end && (i - offset) % stride == 0;
        selectMask.push_back(fromSource ? (i - offset) / stride : nDest + i);
      }

      // 3. Replace with a ShuffleOp.
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, scaledSource, op.dest(),
                                             selectMask);
      return success();
    }

    // For each slice of the source vector along the most major dimension.
    Value res = op.dest();
    int64_t idx = 0;
    for (int64_t off = offset; off < end; off += stride, ++idx) {
      // 1. Extract the proper subvector (or element) from source.
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from
        // destination. Otherwise we are at the element level and there is no
        // need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        // smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};
/// Progressive lowering of ExtractStridedSliceOp to either:
///   1. a direct vector::ShuffleOp when the op carries a single offset; the
///      shuffle mask lists the positions
///      offset, offset+stride, ..., below offset+size*stride.
///   2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp +
///      InsertOp/InsertElementOp for the n-D case, accumulating into a
///      zero-splat vector of the result type.
class VectorExtractStridedSliceOpRewritePattern
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;

  void initialize() {
    // This pattern creates recursive ExtractStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto resultType = op.getType();
    assert(!op.offsets().getValue().empty() && "Unexpected empty offsets");

    // Slice description along the outermost dimension.
    int64_t firstOffset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t firstSize =
        op.sizes().getValue().front().cast<IntegerAttr>().getInt();
    int64_t firstStride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();
    int64_t end = firstOffset + firstSize * firstStride;

    auto loc = op.getLoc();
    auto elemType = resultType.getElementType();
    assert(elemType.isSignlessIntOrIndexOrFloat());

    // Single offset can be more efficiently shuffled.
    if (op.offsets().getValue().size() == 1) {
      SmallVector<int64_t, 4> mask;
      mask.reserve(firstSize);
      for (int64_t pos = firstOffset; pos < end; pos += firstStride)
        mask.push_back(pos);
      rewriter.replaceOpWithNewOp<ShuffleOp>(op, resultType, op.vector(),
                                             op.vector(),
                                             rewriter.getI64ArrayAttr(mask));
      return success();
    }

    // Extract/insert on a lower ranked extract strided slice op.
    // The offsets/sizes/strides of the inner dimensions are the same for every
    // outer index.
    auto innerOffsets = getI64SubArray(op.offsets(), /*dropFront=*/1);
    auto innerSizes = getI64SubArray(op.sizes(), /*dropFront=*/1);
    auto innerStrides = getI64SubArray(op.strides(), /*dropFront=*/1);

    // Start from an all-zero vector of the result type and fill it slice by
    // slice.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, elemType, rewriter.getZeroAttr(elemType));
    Value res = rewriter.create<SplatOp>(loc, resultType, zero);
    int64_t idx = 0;
    for (int64_t off = firstOffset; off < end; off += firstStride, ++idx) {
      Value outerSlice = extractOne(rewriter, loc, op.vector(), off);
      Value extracted = rewriter.create<ExtractStridedSliceOp>(
          loc, outerSlice, innerOffsets, innerSizes, innerStrides);
      res = insertOne(rewriter, loc, extracted, res, idx);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};
/// Populate the given list with the rewrite patterns defined in this file:
/// the two InsertStridedSliceOp patterns (different-rank and same-rank) and
/// the ExtractStridedSliceOp pattern.
void mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<VectorInsertStridedSliceOpDifferentRankRewritePattern>(context);
  patterns.add<VectorInsertStridedSliceOpSameRankRewritePattern>(context);
  patterns.add<VectorExtractStridedSliceOpRewritePattern>(context);
}