blob: 01632c6ea1579d7c126a2e21219a68f20ff50a10 [file]
//===-ElideReinterpretCast.cpp - Expansion patterns for MemRef operations-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Repeated.h"
#include <cassert>
namespace mlir {
namespace memref {
#define GEN_PASS_DEF_ELIDEREINTERPRETCASTPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
using namespace mlir;
namespace {
/// Returns true if `rc` represents a scalar view (all sizes == 1)
/// into a memref that has exactly one non-unit dimension located at
/// either the first or last position (i.e. a "row" or "column").
///
/// Examples that return true:
///
/// // Row-major slice (last dim is non-unit)
/// memref.reinterpret_cast %buff to offset: [%off],
/// sizes: [1, 1, 1], strides: [1, 1, 1]
/// : memref<1x1x8xi32> to memref<1x1x1xi32>
///
/// // Column-major slice (first dim is non-unit)
/// memref.reinterpret_cast %buff to offset: [%off],
/// sizes: [1, 1], strides: [1, 1]
/// : memref<2x1xf32> to memref<1x1xf32>
///
/// // Random strides
/// memref.reinterpret_cast %buff to offset: [%off],
/// sizes: [1, 1], strides: [10, 100]
/// : memref<2x1xf32, strided<[10, 100]>>
/// to memref<1x1xf32>
///
/// // Rank-1 case
/// memref.reinterpret_cast %buf to offset: [%off],
/// sizes: [1], strides: [1]
/// : memref<8xi32> to memref<1xi32>
///
/// Examples that return false:
///
/// // More non-unit dims
/// memref.reinterpret_cast %buff to offset: [%off],
/// sizes: [1, 1, 1], strides: [1, 1, 1]
/// : memref<1x2x8xi32> to memref<1x1x1xi32>
///
/// // View is not scalar (size != 1)
/// memref.reinterpret_cast %buff to offset: [%off],
/// sizes: [2, 1], strides: [1, 1]
/// : memref<1x2xf32> to memref<2x1xf32>
///
/// // Base has non-identity layout
/// %buff = memref.alloc() : memref<1x2xf32, strided<[1, 3]>>
/// memref.reinterpret_cast %buff to offset: [%off],
/// sizes: [1, 1], strides: [1, 1]
/// : memref<1x2xf32, strided<[1, 3]>> to memref<1x1xf32>
static bool isScalarSlice(memref::ReinterpretCastOp rc) {
auto rcInputTy = dyn_cast<MemRefType>(rc.getSource().getType());
auto rcOutputTy = dyn_cast<MemRefType>(rc.getType());
// Reject strided base - logic for computing linear idx is TODO
if (!rcInputTy.getLayout().isIdentity())
return false;
// Reject non-matching ranks
unsigned srcRank = rcInputTy.getRank();
if (srcRank != rcOutputTy.getRank())
return false;
ArrayRef<int64_t> sizes = rc.getStaticSizes();
// View must be scalar: memref<1x...x1>
if (!llvm::all_of(rcOutputTy.getShape(),
[](int64_t dim) { return dim == 1; }))
return false;
// Sizes must all be statically 1
if (!llvm::all_of(sizes, [](int64_t size) {
return !ShapedType::isDynamic(size) && size == 1;
}))
return false;
// Rank-1 special case
if (srcRank == 1) {
// Reject non-scalar output
if (rcOutputTy.getDimSize(0) > 1)
return false;
}
int nonUnitCount =
std::count_if(rcInputTy.getShape().begin(), rcInputTy.getShape().end(),
[](int dim) { return dim != 1; });
return nonUnitCount == 1;
}
/// Rewrites `memref.copy` of a 1-element MemRef as a scalar load-store pair
///
/// The pattern matches a reinterpret_cast that creates a scalar view
/// (`sizes = [1, ..., 1]`) into a memref with a single non-unit dimension.
/// Since the view contains only one element, the accessed address is
/// determined solely by the base pointer and the offset.
///
/// Two layouts are supported:
/// * row-major slice (stride pattern [N, ..., 1])
/// * column-major slice (stride pattern [1, ..., N])
///
/// BEFORE (row-major slice)
/// %view = memref.reinterpret_cast %base
/// to offset: [%off], sizes: [1, ..., 1], strides: [N, ..., 1]
/// : memref<1x...xNxf32>
/// to memref<1x...x1xf32, strided<[N, ..., 1], offset: ?>>
/// memref.copy %src, %view
/// : memref<1x...x1xf32>
/// to memref<1x...x1xf32, strided<[N, ..., 1], offset: ?>>
///
/// AFTER
/// %c0 = arith.constant 0 : index
/// %v = memref.load %src[%c0, ..., %c0] : memref<1x...x1xf32>
/// memref.store %v, %base[%c0, ..., %off] : memref<1x...xNxf32>
///
/// BEFORE (column-major slice)
/// %view = memref.reinterpret_cast %base
/// to offset: [%off], sizes: [1, ..., 1], strides: [1, ..., N]
/// : memref<Nx...x1xf32>
/// to memref<1x...x1xf32, strided<[1, ..., N], offset: ?>>
/// memref.copy %src, %view
/// : memref<1x...x1xf32>
/// to memref<1x...x1xf32, strided<[1, ..., N], offset: ?>>
///
/// AFTER
/// %c0 = arith.constant 0 : index
/// %v = memref.load %src[%c0, ..., %c0] : memref<1x...x1xf32>
/// memref.store %v, %base[%off, ..., %c0] : memref<Nx...x1xf32>
struct CopyToScalarLoadAndStore : public OpRewritePattern<memref::CopyOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(memref::CopyOp op,
PatternRewriter &rewriter) const final {
Value rcOutput = op.getTarget();
auto rc = rcOutput.getDefiningOp<memref::ReinterpretCastOp>();
if (!rc)
return rewriter.notifyMatchFailure(
op, "target is not a memref.reinterpret_cast");
if (!isScalarSlice(rc))
return rewriter.notifyMatchFailure(
op, "reinterpret_cast does not match scalar slice");
Location loc = op.getLoc();
Value src = op.getSource();
Value dst = rc.getSource();
auto dstType = cast<MemRefType>(dst.getType());
unsigned dstRank = dstType.getRank();
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto srcType = cast<MemRefType>(src.getType());
Repeated<Value> loadIndices(srcType.getRank(), zero);
auto offsets = rc.getMixedOffsets();
assert(offsets.size() == 1 && "Expecting single offset");
OpFoldResult offset = offsets[0];
Value storeOffset = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
unsigned offsetDim = dstType.getDimSize(0) == 1 ? dstRank - 1 : 0;
SmallVector<Value> storeIndices(dstRank, zero);
storeIndices[offsetDim] = storeOffset;
// If the only user of `rc` is the current Op (which is about to be erased),
// we can safely erase it.
if (rcOutput.hasOneUse())
rewriter.eraseOp(rc);
Value val = memref::LoadOp::create(rewriter, loc, src, loadIndices);
memref::StoreOp::create(rewriter, loc, val, dst, storeIndices);
rewriter.eraseOp(op);
return success();
}
};
struct ElideReinterpretCastPass
: public memref::impl::ElideReinterpretCastPassBase<
ElideReinterpretCastPass> {
void runOnOperation() override {
MLIRContext &ctx = getContext();
RewritePatternSet patterns(&ctx);
memref::populateElideReinterpretCastPatterns(patterns);
ConversionTarget target(ctx);
target.addDynamicallyLegalOp<memref::CopyOp>([](memref::CopyOp op) {
auto rc = op.getTarget().getDefiningOp<memref::ReinterpretCastOp>();
if (!rc)
return true;
return !isScalarSlice(rc);
});
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
} // namespace
void mlir::memref::populateElideReinterpretCastPatterns(
RewritePatternSet &patterns) {
patterns.add<CopyToScalarLoadAndStore>(patterns.getContext());
}