blob: e0d2e6d00b9f0104d817cec212d32c0a7ed8b8a3 [file] [edit]
//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/Debug.h"
#include <cstdint>
#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
using namespace mlir;
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
/// Returns true when each of the last `n` groups in `reassocs` is trivial,
/// i.e. contains exactly one dimension. A non-positive `n` is vacuously true;
/// an `n` larger than the number of groups yields false.
static bool
hasTrivialReassociationSuffix(ArrayRef<ReassociationIndices> reassocs,
                              int64_t n) {
  if (n <= 0)
    return true;
  const auto numGroups = static_cast<int64_t>(reassocs.size());
  if (numGroups < n)
    return false;
  for (const ReassociationIndices &group : reassocs.take_back(n)) {
    if (group.size() != 1)
      return false;
  }
  return true;
}
/// Returns true when the last `n` static strides of `subview` are all equal to
/// 1. A non-positive `n` is vacuously true; an `n` larger than the number of
/// strides yields false.
static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n) {
  if (n <= 0)
    return true;
  ArrayRef<int64_t> staticStrides = subview.getStaticStrides();
  const auto numStrides = static_cast<int64_t>(staticStrides.size());
  if (numStrides < n)
    return false;
  for (int64_t stride : staticStrides.take_back(n)) {
    if (stride != 1)
      return false;
  }
  return true;
}
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
namespace {
/// Rewrites a subview whose source is itself produced by a subview into a
/// single subview taken directly from the inner subview's source.
class SubViewOfSubViewFolder : public OpRewritePattern<memref::SubViewOp> {
public:
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    auto producer = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!producer)
      return failure();

    // Compose the offsets/sizes/strides of both subviews; bail out if the
    // helper reports that they cannot be merged.
    SmallVector<OpFoldResult> mergedOffsets;
    SmallVector<OpFoldResult> mergedSizes;
    SmallVector<OpFoldResult> mergedStrides;
    LogicalResult mergeStatus = affine::mergeOffsetsSizesAndStrides(
        rewriter, subView.getLoc(), producer, subView,
        producer.getDroppedDims(), mergedOffsets, mergedSizes, mergedStrides);
    if (failed(mergeStatus))
      return failure();

    // The result type is kept; only the source and the slice parameters
    // change.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), producer.getSource(), mergedOffsets,
        mergedSizes, mergedStrides);
    return success();
  }
};
/// Merges subview operations with load/store like operations unless such a
/// merger would cause the strides between dimensions accessed by that
/// operation to change.
struct AccessOpOfSubViewOpFolder final
: OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
using Base::Base;
LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
PatternRewriter &rewriter) const override;
};
/// Merges a memref.expand_shape operation with an operation that accesses a
/// memref by index unless that operation accesses more than one dimension of
/// memory and any dimension other than the outermost dimension accessed this
/// way would be merged. This prevents issues from arising with, say, a
/// vector.load of a 4x2 vector having the two trailing dimensions of the access
/// get merged.
struct AccessOpOfExpandShapeOpFolder final
: OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
using Base::Base;
LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
PatternRewriter &rewriter) const override;
};
/// Merges an operation that accesses a memref by index with a
/// memref.collapse_shape, unless this would break apart a dimension other than
/// the outermost one that the operation accesses. This prevents, for example,
/// transforming a load of a 3x8 vector from a 6x8 memref into a load
/// from a 3x4x2 memref (as this would require special handling and could lead
/// to invalid IR if that higher-dimensional memref comes from a subview) but
/// does permit turning a load of a length-8 vector from a 3x8 memref into a
/// load from a 3x2x8 one.
struct AccessOpOfCollapseShapeOpFolder final
: OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
using Base::Base;
LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
PatternRewriter &rewriter) const override;
};
/// Merges memref.subview operations present on the source or destination
/// operands of indexed memory copy operations (DMA operations) into those
/// operations. This is performed unconditionally, since folding in a subview
/// cannot change the starting position of the copy, which is what the
/// memref/index pair represents in DMA operations.
struct IndexedMemCopyOpOfSubViewOpFolder final
: OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
using Base::Base;
LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
PatternRewriter &rewriter) const override;
};
/// Merges memref.expand_shape operations that are present on the source or
/// destination of an indexed memory copy/DMA into the memref/index arguments of
/// that DMA. No legality conditions are checked beyond the presence of an
/// expand_shape on at least one of the two operands.
struct IndexedMemCopyOpOfExpandShapeOpFolder final
: OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
using Base::Base;
LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
PatternRewriter &rewriter) const override;
};
/// Merges memref.collapse_shape operations that are present on the source or
/// destination of an indexed memory copy/DMA into the memref/index arguments of
/// that DMA. No legality conditions are checked beyond the presence of a
/// collapse_shape on at least one of the two operands.
struct IndexedMemCopyOpOfCollapseShapeOpFolder final
: OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
using Base::Base;
LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
PatternRewriter &rewriter) const override;
};
/// Merges memref.subview ops on the base argument to vector transfer operations
/// into the base and indices of that transfer if:
/// - The subview has unit strides on transfer dimensions
/// - All the transfer dimensions are in-bounds
/// The transfer's permutation map is rewritten to account for dimensions
/// dropped by rank-reducing subviews.
struct TransferOpOfSubViewOpFolder final
: OpInterfaceRewritePattern<VectorTransferOpInterface> {
using Base::Base;
LogicalResult matchAndRewrite(VectorTransferOpInterface op,
PatternRewriter &rewriter) const override;
};
/// Merges memref.expand_shape ops that create the base of a vector transfer
/// operation into the base and indices of that transfer. Does not act when a
/// dimension is potentially out of bounds, if one of the transfer dimensions
/// would need to be strided after folding, or if it would merge two
/// dimensions that are both transfer dimensions.
/// TODO: become more sophisticated about length-1 dimensions that are the
/// result of an expansion becoming broadcasts.
struct TransferOpOfExpandShapeOpFolder final
: OpInterfaceRewritePattern<VectorTransferOpInterface> {
using Base::Base;
LogicalResult matchAndRewrite(VectorTransferOpInterface op,
PatternRewriter &rewriter) const override;
};
/// Merges memref.collapse_shape ops that create the base of a vector transfer
/// operation into the base and indices of that transfer. Does not act when the
/// permutation map is not a minor identity, when a dimension could be out of
/// bounds, or when folding would break apart a transfer dimension.
struct TransferOpOfCollapseShapeOpFolder final
: OpInterfaceRewritePattern<VectorTransferOpInterface> {
using Base::Base;
LogicalResult matchAndRewrite(VectorTransferOpInterface op,
PatternRewriter &rewriter) const override;
};
} // namespace
/// Folds a memref.subview that produces the accessed memref of `op` into the
/// access itself by rewriting the indices in terms of the subview's source.
/// Fails to match when the trailing strides covering the accessed shape are
/// not all 1, or when a dropped dimension falls among the inner accessed
/// dimensions.
LogicalResult
AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
                                           PatternRewriter &rewriter) const {
  TypedValue<MemRefType> target = op.getAccessedMemref();
  if (!target)
    return rewriter.notifyMatchFailure(op, "not accessing a memref");
  auto subview = target.getDefiningOp<memref::SubViewOp>();
  if (!subview)
    return rewriter.notifyMatchFailure(op, "not accessing a subview");

  // An empty accessed shape (a scalar access) imposes no stride requirement,
  // whereas an accessed shape of {1} still demands a unit stride on the final
  // dimension of the subview.
  const int64_t numAccessedDims = op.getAccessedShape().size();
  if (!hasTrailingUnitStrides(subview, numAccessedDims))
    return rewriter.notifyMatchFailure(
        op, "non-unit stride on accessed dimensions");

  llvm::SmallBitVector dropped = subview.getDroppedDims();
  const int64_t srcRank = subview.getSourceType().getRank();
  // The outermost accessed dimension is exempt: only dropped dimensions that
  // sit between the accessed dimensions could alter the access semantics, so
  // the scan starts at the second accessed dimension.
  for (int64_t dim = srcRank - (numAccessedDims - 1); dim < srcRank; ++dim) {
    if (!dropped.test(dim))
      continue;
    return rewriter.notifyMatchFailure(
        op, "reintroducing dropped dimension " + Twine(dim) +
                " would break access op semantics");
  }

  SmallVector<Value> resolvedIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, op.getLoc(), subview.getMixedOffsets(),
      subview.getMixedStrides(), dropped, op.getIndices(), resolvedIndices);

  // The interface may hand back replacement values; when it does, the
  // original op is replaced with them.
  if (std::optional<SmallVector<Value>> replacements =
          op.updateMemrefAndIndices(rewriter, subview.getSource(),
                                    resolvedIndices))
    rewriter.replaceOp(op, *replacements);
  return success();
}
/// Folds a memref.expand_shape that produces the accessed memref of `op` into
/// the access by rewriting the indices in terms of the expand_shape's source.
/// Fails to match when any accessed dimension other than the outermost one
/// belongs to a non-trivial reassociation group.
LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
  TypedValue<MemRefType> target = op.getAccessedMemref();
  if (!target)
    return rewriter.notifyMatchFailure(op, "not accessing a memref");
  auto expand = target.getDefiningOp<memref::ExpandShapeOp>();
  if (!expand)
    return rewriter.notifyMatchFailure(op, "not accessing an expand_shape");

  const int64_t numAccessedDims = op.getAccessedShape().size();
  if (expand.getSrcType().getRank() < numAccessedDims)
    return rewriter.notifyMatchFailure(
        op, "expand_shape source rank is too small for the accessed shape");

  // The leading accessed dimension is exempt from the check since changes to
  // its stride are not a concern.
  const int64_t numInnerDims = numAccessedDims > 0 ? numAccessedDims - 1 : 0;
  SmallVector<ReassociationIndices, 4> groups =
      expand.getReassociationIndices();
  if (!hasTrivialReassociationSuffix(groups, numInnerDims))
    return rewriter.notifyMatchFailure(
        op,
        "expand_shape folding would merge semantically important dimensions");

  SmallVector<Value> resolvedIndices;
  memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand,
                                          op.getIndices(), resolvedIndices,
                                          op.hasInboundsIndices());

  // The interface may hand back replacement values; when it does, the
  // original op is replaced with them.
  if (std::optional<SmallVector<Value>> replacements =
          op.updateMemrefAndIndices(rewriter, expand.getViewSource(),
                                    resolvedIndices))
    rewriter.replaceOp(op, *replacements);
  return success();
}
/// Folds a memref.collapse_shape that produces the accessed memref of `op`
/// into the access by rewriting the indices in terms of the collapse_shape's
/// source. Fails to match when any accessed dimension other than the outermost
/// one belongs to a non-trivial reassociation group.
LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
  TypedValue<MemRefType> target = op.getAccessedMemref();
  if (!target)
    return rewriter.notifyMatchFailure(op, "not accessing a memref");
  auto collapse = target.getDefiningOp<memref::CollapseShapeOp>();
  if (!collapse)
    return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape");

  const int64_t numAccessedDims = op.getAccessedShape().size();
  if (collapse.getSrcType().getRank() < numAccessedDims)
    return rewriter.notifyMatchFailure(
        op, "collapse_shape source rank is too small for the accessed shape");

  // The leading accessed dimension is exempt: changes to its stride are not a
  // concern, and the dimensions inside its reassociation group (when that
  // group is non-trivial) must be contiguous.
  const int64_t numInnerDims = numAccessedDims > 0 ? numAccessedDims - 1 : 0;
  SmallVector<ReassociationIndices, 4> groups =
      collapse.getReassociationIndices();
  if (!hasTrivialReassociationSuffix(groups, numInnerDims))
    return rewriter.notifyMatchFailure(op, "collapse_shape folding would merge "
                                           "semantically important dimensions");

  SmallVector<Value> resolvedIndices;
  memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse,
                                            op.getIndices(), resolvedIndices,
                                            op.hasInboundsIndices());

  // The interface may hand back replacement values; when it does, the
  // original op is replaced with them.
  if (std::optional<SmallVector<Value>> replacements =
          op.updateMemrefAndIndices(rewriter, collapse.getViewSource(),
                                    resolvedIndices))
    rewriter.replaceOp(op, *replacements);
  return success();
}
/// Folds memref.subview producers of the source and/or destination of an
/// indexed memory copy into that copy. Each side defined by a subview is
/// redirected to the subview's source with indices resolved through the
/// subview's offsets and strides; a side without a subview keeps its memref
/// and indices.
LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  TypedValue<MemRefType> src = op.getSrc();
  TypedValue<MemRefType> dst = op.getDst();
  memref::SubViewOp srcView, dstView;
  if (src)
    srcView = src.getDefiningOp<memref::SubViewOp>();
  if (dst)
    dstView = dst.getDefiningOp<memref::SubViewOp>();
  if (!srcView && !dstView)
    return rewriter.notifyMatchFailure(
        op, "no subviews found on indexed copy inputs");

  // Source side.
  Value foldedSrc = src;
  SmallVector<Value> foldedSrcIndices;
  if (srcView) {
    foldedSrc = srcView.getSource();
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, op.getLoc(), srcView.getMixedOffsets(),
        srcView.getMixedStrides(), srcView.getDroppedDims(),
        op.getSrcIndices(), foldedSrcIndices);
  } else {
    foldedSrcIndices = llvm::to_vector(op.getSrcIndices());
  }

  // Destination side.
  Value foldedDst = dst;
  SmallVector<Value> foldedDstIndices;
  if (dstView) {
    foldedDst = dstView.getSource();
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, op.getLoc(), dstView.getMixedOffsets(),
        dstView.getMixedStrides(), dstView.getDroppedDims(),
        op.getDstIndices(), foldedDstIndices);
  } else {
    foldedDstIndices = llvm::to_vector(op.getDstIndices());
  }

  op.setMemrefsAndIndices(rewriter, foldedSrc, foldedSrcIndices, foldedDst,
                          foldedDstIndices);
  return success();
}
/// Folds memref.expand_shape producers of the source and/or destination of an
/// indexed memory copy into that copy. Each side defined by an expand_shape is
/// redirected to the expand_shape's view source with indices resolved through
/// the reassociation; a side without an expand_shape keeps its memref and
/// indices.
LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  TypedValue<MemRefType> src = op.getSrc();
  TypedValue<MemRefType> dst = op.getDst();
  memref::ExpandShapeOp srcReshape, dstReshape;
  if (src)
    srcReshape = src.getDefiningOp<memref::ExpandShapeOp>();
  if (dst)
    dstReshape = dst.getDefiningOp<memref::ExpandShapeOp>();
  if (!srcReshape && !dstReshape)
    return rewriter.notifyMatchFailure(
        op, "no expand_shapes found on indexed copy inputs");

  // Source side.
  Value foldedSrc = src;
  SmallVector<Value> foldedSrcIndices;
  if (srcReshape) {
    foldedSrc = srcReshape.getViewSource();
    memref::resolveSourceIndicesExpandShape(
        op.getLoc(), rewriter, srcReshape, op.getSrcIndices(),
        foldedSrcIndices, op.hasInboundsSrcIndices());
  } else {
    foldedSrcIndices = llvm::to_vector(op.getSrcIndices());
  }

  // Destination side.
  Value foldedDst = dst;
  SmallVector<Value> foldedDstIndices;
  if (dstReshape) {
    foldedDst = dstReshape.getViewSource();
    memref::resolveSourceIndicesExpandShape(
        op.getLoc(), rewriter, dstReshape, op.getDstIndices(),
        foldedDstIndices, op.hasInboundsDstIndices());
  } else {
    foldedDstIndices = llvm::to_vector(op.getDstIndices());
  }

  op.setMemrefsAndIndices(rewriter, foldedSrc, foldedSrcIndices, foldedDst,
                          foldedDstIndices);
  return success();
}
/// Folds memref.collapse_shape producers of the source and/or destination of
/// an indexed memory copy into that copy. Each side defined by a
/// collapse_shape is redirected to the collapse_shape's view source with
/// indices resolved through the reassociation; a side without a collapse_shape
/// keeps its memref and indices.
LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  TypedValue<MemRefType> src = op.getSrc();
  TypedValue<MemRefType> dst = op.getDst();
  memref::CollapseShapeOp srcReshape, dstReshape;
  if (src)
    srcReshape = src.getDefiningOp<memref::CollapseShapeOp>();
  if (dst)
    dstReshape = dst.getDefiningOp<memref::CollapseShapeOp>();
  if (!srcReshape && !dstReshape)
    return rewriter.notifyMatchFailure(
        op, "no collapse_shapes found on indexed copy inputs");

  // Source side.
  Value foldedSrc = src;
  SmallVector<Value> foldedSrcIndices;
  if (srcReshape) {
    foldedSrc = srcReshape.getViewSource();
    memref::resolveSourceIndicesCollapseShape(
        op.getLoc(), rewriter, srcReshape, op.getSrcIndices(),
        foldedSrcIndices, op.hasInboundsSrcIndices());
  } else {
    foldedSrcIndices = llvm::to_vector(op.getSrcIndices());
  }

  // Destination side.
  Value foldedDst = dst;
  SmallVector<Value> foldedDstIndices;
  if (dstReshape) {
    foldedDst = dstReshape.getViewSource();
    memref::resolveSourceIndicesCollapseShape(
        op.getLoc(), rewriter, dstReshape, op.getDstIndices(),
        foldedDstIndices, op.hasInboundsDstIndices());
  } else {
    foldedDstIndices = llvm::to_vector(op.getDstIndices());
  }

  op.setMemrefsAndIndices(rewriter, foldedSrc, foldedSrcIndices, foldedDst,
                          foldedDstIndices);
  return success();
}
/// Folds a memref.subview that produces the base of a vector transfer into the
/// transfer's base, indices and permutation map.
///
/// Fails to match when:
///  - the base is not defined by a memref.subview,
///  - any transfer dimension may be out of bounds,
///  - any source dimension that the transfer actually reads/writes along has a
///    static stride other than 1 in the subview (dynamic strides never compare
///    equal to 1 and are therefore rejected as well), or
///  - the op-specific `mayUpdateStartingPosition` precondition fails.
///
/// On success the op is updated to access the subview's source with indices
/// resolved through the subview's offsets and strides.
LogicalResult
TransferOpOfSubViewOpFolder::matchAndRewrite(VectorTransferOpInterface op,
                                             PatternRewriter &rewriter) const {
  auto subview = op.getBase().getDefiningOp<memref::SubViewOp>();
  if (!subview)
    return rewriter.notifyMatchFailure(op, "not accessing a subview");
  // Note: no identity permutation check here, since subview folding can handle
  // complex permutations because it doesn't merge or split any individual
  // dimension.
  if (op.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(op, "out of bounds dimension");

  // Re-express the permutation map over the dimensions of the subview's
  // source so that dimensions dropped by a rank-reducing subview are accounted
  // for.
  AffineMap perm = op.getPermutationMap();
  AffineMap newPerm = expandDimsToRank(perm, subview.getSourceType().getRank(),
                                       subview.getDroppedDims());

  // Every source dimension the transfer moves along must have unit stride,
  // otherwise consecutive vector elements would land on different memory
  // locations after folding. The permutation map is not required to be a minor
  // identity and the subview may drop dimensions, so the transfer dimensions
  // are not necessarily the trailing source dimensions: check exactly the
  // dimensions named by the expanded permutation map. Non-dimension results
  // do not name a memory dimension and impose no stride requirement.
  ArrayRef<int64_t> staticStrides = subview.getStaticStrides();
  for (AffineExpr result : newPerm.getResults()) {
    auto transferDim = dyn_cast<AffineDimExpr>(result);
    if (!transferDim)
      continue;
    if (staticStrides[transferDim.getPosition()] != 1)
      return rewriter.notifyMatchFailure(
          subview, "non-unit stride on transfer dimension " +
                       Twine(transferDim.getPosition()));
  }

  if (failed(op.mayUpdateStartingPosition(subview.getSourceType(), newPerm)))
    return rewriter.notifyMatchFailure(subview,
                                       "failed op-specific preconditions");

  SmallVector<Value> newIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, op.getLoc(), subview.getMixedOffsets(),
      subview.getMixedStrides(), subview.getDroppedDims(), op.getIndices(),
      newIndices);
  op.updateStartingPosition(rewriter, subview.getSource(), newIndices,
                            AffineMapAttr::get(newPerm));
  return success();
}
/// Folds a memref.expand_shape that produces the base of a vector transfer
/// into the transfer's base, indices and permutation map. Only matches when
/// every permutation map result is a dimension that is the last entry of its
/// reassociation group, so that no transfer dimension would need a stride
/// after folding.
LogicalResult TransferOpOfExpandShapeOpFolder::matchAndRewrite(
    VectorTransferOpInterface op, PatternRewriter &rewriter) const {
  auto expand = op.getBase().getDefiningOp<memref::ExpandShapeOp>();
  if (!expand)
    return rewriter.notifyMatchFailure(op, "not accessing an expand_shape");
  if (op.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(op, "out of bounds dimension");

  auto srcType = expand.getSrc().getType();
  const int64_t srcRank = srcType.getRank();
  if (srcRank < op.getVectorType().getRank())
    return rewriter.notifyMatchFailure(op,
                                       "source rank is less than vector rank");

  // Map the last expanded dimension of each reassociation group to the source
  // dimension the group came from. Only these dimensions keep a unit step in
  // the source once the group is linearized.
  llvm::SmallDenseMap<int64_t, int64_t, 8> lastDimOfGroupToSrcDim;
  int64_t srcDim = 0;
  for (const ReassociationIndices &group : expand.getReassociationIndices())
    lastDimOfGroupToSrcDim.insert({group.back(), srcDim++});

  // Translate each permutation map result into source-dimension terms. A
  // result that is not the last entry of its group would require a strided
  // access, so the fold is abandoned. (Expanding length X to X by 1 is not
  // currently special-cased, but could be in the future.)
  SmallVector<AffineExpr> srcPermResults;
  AffineMap oldPerm = op.getPermutationMap();
  srcPermResults.reserve(oldPerm.getNumResults());
  for (AffineExpr expr : oldPerm.getResults()) {
    auto dimExpr = dyn_cast<AffineDimExpr>(expr);
    if (!dimExpr)
      return rewriter.notifyMatchFailure(
          op, "has non-dim entry in permutation map");
    auto it = lastDimOfGroupToSrcDim.find(dimExpr.getPosition());
    if (it == lastDimOfGroupToSrcDim.end())
      return rewriter.notifyMatchFailure(op,
                                         "permutation map result would be made "
                                         "strided by expand_shape folding");
    srcPermResults.push_back(rewriter.getAffineDimExpr(it->second));
  }
  AffineMap newPerm =
      AffineMap::get(srcRank, 0, srcPermResults, op.getContext());
  if (failed(op.mayUpdateStartingPosition(srcType, newPerm)))
    return rewriter.notifyMatchFailure(op, "failed op-specific preconditions");

  // Without a mask, the starting position is taken to be in bounds, which
  // permits a disjoint linearization of the indices.
  const bool startsInbounds = !op.getMask();
  SmallVector<Value> srcIndices;
  memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand,
                                          op.getIndices(), srcIndices,
                                          startsInbounds);
  op.updateStartingPosition(rewriter, expand.getViewSource(), srcIndices,
                            AffineMapAttr::get(newPerm));
  return success();
}
/// Folds a memref.collapse_shape that produces the base of a vector transfer
/// into the transfer's base, indices and permutation map. Only matches for
/// minor-identity permutation maps with no out-of-bounds dimensions, and only
/// when the trailing reassociation groups covering the vector rank are all
/// trivial.
LogicalResult TransferOpOfCollapseShapeOpFolder::matchAndRewrite(
    VectorTransferOpInterface op, PatternRewriter &rewriter) const {
  auto collapse = op.getBase().getDefiningOp<memref::CollapseShapeOp>();
  if (!collapse)
    return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape");
  if (!op.getPermutationMap().isMinorIdentity())
    return rewriter.notifyMatchFailure(op,
                                       "non-minor identity permutation map");
  if (op.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(op, "out of bounds dimension");

  auto srcType = collapse.getSrc().getType();
  const int64_t srcRank = srcType.getRank();
  const int64_t vecRank = op.getVectorType().getRank();
  if (srcRank < vecRank)
    return rewriter.notifyMatchFailure(op,
                                       "source rank is less than vector rank");

  // The full vector rank is checked (no "- 1"): treating a collapse such as
  // [1, 1, N] -> N as a special case is left as future work.
  SmallVector<ReassociationIndices> groups = collapse.getReassociationIndices();
  if (!hasTrivialReassociationSuffix(groups, vecRank))
    return rewriter.notifyMatchFailure(
        op, "collapse_shape folding would split a transfer dimension");

  AffineMap newPerm =
      AffineMap::getMinorIdentityMap(srcRank, vecRank, op.getContext());
  if (failed(op.mayUpdateStartingPosition(srcType, newPerm)))
    return rewriter.notifyMatchFailure(op, "failed op-specific preconditions");

  const bool startsInbounds = !op.getMask();
  SmallVector<Value> srcIndices;
  memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse,
                                            op.getIndices(), srcIndices,
                                            startsInbounds);
  op.updateStartingPosition(rewriter, collapse.getViewSource(), srcIndices,
                            AffineMapAttr::get(newPerm));
  return success();
}
/// Adds every alias-folding pattern defined in this file to `patterns`, using
/// the context owned by the pattern set.
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<
      // Indexed access folders.
      AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
      AccessOpOfCollapseShapeOpFolder,
      // Indexed memory copy folders.
      IndexedMemCopyOpOfSubViewOpFolder, IndexedMemCopyOpOfExpandShapeOpFolder,
      IndexedMemCopyOpOfCollapseShapeOpFolder,
      // Vector transfer folders.
      TransferOpOfSubViewOpFolder, TransferOpOfExpandShapeOpFolder,
      TransferOpOfCollapseShapeOpFolder,
      // Subview-of-subview composition.
      SubViewOfSubViewFolder>(context);
}
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
/// Pass that applies the patterns registered by
/// memref::populateFoldMemRefAliasOpPatterns to the operation it runs on.
struct FoldMemRefAliasOpsPass final
: public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
void runOnOperation() override;
};
} // namespace
void FoldMemRefAliasOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateFoldMemRefAliasOpPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}