//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
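//
// For example (a schematic sketch; the SSA names and the layout-attribute
// printing are illustrative, not produced verbatim by the pass):
//
//   %0 = memref.subview %arg0[%o0, %o1] [4, 4] [1, 1]
//       : memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[42, 1]>
//   %v = memref.load %0[%i, %j] : memref<4x4xf32, offset=?, strides=[42, 1]>
//
// becomes a load from the original memref, with the subview offsets folded in
// through affine.apply:
//
//   %idx0 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%o0]
//   %idx1 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%j)[%o1]
//   %v = memref.load %arg0[%idx0, %idx1] : memref<12x42xf32>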
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
/// Given the `indices` of a load/store operation where the memref is a result
/// of a subview op, returns the indices w.r.t. the source memref of the
/// subview op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = memref.subview %0[%arg0, %arg1][4, 4][%stride1, %stride2] :
///          memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
/// %2 = memref.load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
///
/// could be folded into
///
/// %2 = memref.load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
///          memref<12x42xf32>
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
memref::SubViewOp subViewOp, ValueRange indices,
SmallVectorImpl<Value> &sourceIndices) {
SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
SmallVector<Value> useIndices;
  // Check if this is a rank-reducing case; if so, use a constant zero index
  // for every unit dimension that the subview drops from the source.
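  // For instance (illustrative shapes): for a rank-reducing subview from
  // memref<4x1x8xf32> to memref<4x8xf32>, the dropped unit dimension gets a
  // constant 0, so load indices (%i, %j) become (%i, 0, %j) w.r.t. the source
  // before the offsets and strides are applied in the loop below.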
unsigned resultDim = 0;
llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
if (unusedDims.count(dim))
useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
else
useIndices.push_back(indices[resultDim++]);
}
if (useIndices.size() != mixedOffsets.size())
return failure();
sourceIndices.resize(useIndices.size());
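  // Each source index is computed as `useIndex * stride + offset`, emitted as
  // an affine.apply. Static strides/offsets are folded into the affine
  // expression; dynamic ones are passed in as symbols. For instance, a static
  // stride of 2 with a dynamic offset yields the (schematic) map
  // `affine_map<(d0)[s0] -> (d0 * 2 + s0)>`.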
for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
SmallVector<Value> dynamicOperands;
AffineExpr expr = rewriter.getAffineDimExpr(0);
unsigned numSymbols = 0;
dynamicOperands.push_back(useIndices[index]);
    // Multiply by the stride.
if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
expr = expr * attr.cast<IntegerAttr>().getInt();
} else {
dynamicOperands.push_back(mixedStrides[index].get<Value>());
expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
}
// Add the offset.
if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
expr = expr + attr.cast<IntegerAttr>().getInt();
} else {
dynamicOperands.push_back(mixedOffsets[index].get<Value>());
expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
}
    sourceIndices[index] = rewriter.create<AffineApplyOp>(
        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
}
return success();
}
/// Helpers to access the memref operand for each op.
static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }
static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); }
static Value getMemRefOperand(vector::TransferWriteOp op) {
return op.source();
}
/// Given the permutation map of the original
/// `vector.transfer_read`/`vector.transfer_write` operation, compute the
/// permutation map to use after the subview is folded with it.
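///
/// For example (an illustrative rank-reducing case): if the subview drops the
/// middle dimension of a rank-3 source and the permutation map on the rank-2
/// subview is (d0, d1) -> (d1, d0), the retained source dimensions give
/// (d0, d1, d2) -> (d0, d2), and the composed map over the source memref is
/// (d0, d1, d2) -> (d2, d0).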
static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
memref::SubViewOp subViewOp,
AffineMap currPermutationMap) {
llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
SmallVector<AffineExpr> exprs;
int64_t sourceRank = subViewOp.getSourceType().getRank();
for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
if (unusedDims.count(dim))
continue;
exprs.push_back(getAffineDimExpr(dim, context));
}
auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
return AffineMapAttr::get(
currPermutationMap.compose(resultDimToSourceDimMap));
}
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
namespace {
/// Merges a subview operation with a load/transfer_read operation.
template <typename OpTy>
class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy loadOp,
PatternRewriter &rewriter) const override;
private:
  // Returns failure if the rewrite is not supported (e.g. 0-d transfers), so
  // that the driver is not told the IR changed when it did not.
  LogicalResult replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
                          ArrayRef<Value> sourceIndices,
                          PatternRewriter &rewriter) const;
};
/// Merges a subview operation with a store/transfer_write operation.
template <typename OpTy>
class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy storeOp,
PatternRewriter &rewriter) const override;
private:
  LogicalResult replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
                          ArrayRef<Value> sourceIndices,
                          PatternRewriter &rewriter) const;
};
template <>
LogicalResult LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
    memref::LoadOp loadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(),
                                              sourceIndices);
  return success();
}
template <>
LogicalResult LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
    vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  // TODO: support 0-d corner case.
  if (transferReadOp.getTransferRank() == 0)
    return failure();
  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
      transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
      sourceIndices,
      getPermutationMapAttr(rewriter.getContext(), subViewOp,
                            transferReadOp.permutation_map()),
      transferReadOp.padding(),
      /*mask=*/Value(), transferReadOp.in_boundsAttr());
  return success();
}
template <>
LogicalResult StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
    memref::StoreOp storeOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<memref::StoreOp>(
      storeOp, storeOp.value(), subViewOp.source(), sourceIndices);
  return success();
}
template <>
LogicalResult StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
    vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  // TODO: support 0-d corner case.
  if (transferWriteOp.getTransferRank() == 0)
    return failure();
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
      sourceIndices,
      getPermutationMapAttr(rewriter.getContext(), subViewOp,
                            transferWriteOp.permutation_map()),
      transferWriteOp.in_boundsAttr());
  return success();
}
} // namespace
template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
PatternRewriter &rewriter) const {
auto subViewOp =
getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
if (!subViewOp)
return failure();
SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
loadOp.indices(), sourceIndices)))
return failure();
  return replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
}
template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
PatternRewriter &rewriter) const {
auto subViewOp =
getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
if (!subViewOp)
return failure();
SmallVector<Value, 4> sourceIndices;
if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
storeOp.indices(), sourceIndices)))
return failure();
  return replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
}
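/// Collects the patterns above, which fold memref.load/memref.store and
/// vector.transfer_read/vector.transfer_write on a subview into the equivalent
/// operations on the source memref, into `patterns`.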
void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
LoadOpOfSubViewFolder<vector::TransferReadOp>,
StoreOpOfSubViewFolder<memref::StoreOp>,
StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
patterns.getContext());
}
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
struct FoldSubViewOpsPass final
: public FoldSubViewOpsBase<FoldSubViewOpsPass> {
void runOnOperation() override;
};
} // namespace
void FoldSubViewOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateFoldSubViewOpPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns));
}
std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
return std::make_unique<FoldSubViewOpsPass>();
}
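// Note: a downstream pipeline would typically pick these rewrites up either by
// adding the pass, e.g. `pm.addPass(memref::createFoldSubViewOpsPass())` for
// an assumed OpPassManager `pm`, or by mixing populateFoldSubViewOpPatterns
// into its own RewritePatternSet, as runOnOperation above does.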