blob: dafbd613d09332f9bf22f3858d9ebfc190c6285a [file]
//===- FoldMemRefAliasOps.cpp - Fold memref alias ops ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/IR/MemoryAccessOpInterfaces.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <cstdint>
#define DEBUG_TYPE "fold-memref-alias-ops"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FOLDMEMREFALIASOPSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
using namespace mlir;
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
/// Determine whether the last `n` groups of `reassocs` are trivial, that is,
/// whether each of them contains exactly one dimension to collapse/expand
/// into. A non-positive `n` is trivially satisfied.
static bool
hasTrivialReassociationSuffix(ArrayRef<ReassociationIndices> reassocs,
                              int64_t n) {
  if (n <= 0)
    return true;
  for (const ReassociationIndices &group : reassocs.take_back(n)) {
    if (group.size() != 1)
      return false;
  }
  return true;
}
/// Return true if each of the last `n` static strides of `subview` is equal
/// to 1. A non-positive `n` is trivially satisfied.
static bool hasTrailingUnitStrides(memref::SubViewOp subview, int64_t n) {
  if (n <= 0)
    return true;
  // Any trailing stride that is not literally 1 disqualifies the subview.
  return llvm::none_of(subview.getStaticStrides().take_back(n),
                       [](int64_t stride) { return stride != 1; });
}
/// Accessors returning the memref operand of a load/store-like op. The generic
/// template covers ops that expose the operand through `getMemref()`; the
/// overloads that follow cover ops that name the operand differently.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

// Vector dialect reads: the accessed memref is the `base` operand.
static Value getMemRefOperand(vector::LoadOp op) {
  return op.getBase();
}
static Value getMemRefOperand(vector::MaskedLoadOp op) {
  return op.getBase();
}
static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getBase();
}

// Vector dialect writes: the accessed memref is the `base` operand.
static Value getMemRefOperand(vector::StoreOp op) {
  return op.getBase();
}
static Value getMemRefOperand(vector::MaskedStoreOp op) {
  return op.getBase();
}
static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getBase();
}

// nvgpu.ldmatrix reads from its source memref.
static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
  return op.getSrcMemref();
}
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
namespace {
/// Folds a memref.subview producer into a load-like consumer of type `OpTy`
/// so that the consumer reads from the source of the subview directly.
template <typename OpTy>
struct LoadOpOfSubViewOpFolder final : OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};
/// Folds a memref.expand_shape producer into a load-like consumer of type
/// `OpTy` so that the consumer reads from the source of the expand_shape.
template <typename OpTy>
struct LoadOpOfExpandShapeOpFolder final : OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};
/// Folds a memref.collapse_shape producer into a load-like consumer of type
/// `OpTy` so that the consumer reads from the source of the collapse_shape.
template <typename OpTy>
struct LoadOpOfCollapseShapeOpFolder final : OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;
};
/// Folds a memref.subview producer into a store-like consumer of type `OpTy`
/// so that the consumer writes to the source of the subview directly.
template <typename OpTy>
struct StoreOpOfSubViewOpFolder final : OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
/// Folds a memref.expand_shape producer into a store-like consumer of type
/// `OpTy` so that the consumer writes to the source of the expand_shape.
template <typename OpTy>
struct StoreOpOfExpandShapeOpFolder final : OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
/// Folds a memref.collapse_shape producer into a store-like consumer of type
/// `OpTy` so that the consumer writes to the source of the collapse_shape.
template <typename OpTy>
struct StoreOpOfCollapseShapeOpFolder final : OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;
};
/// Rewrites a subview whose source is itself a subview, i.e.
/// subview(subview(x)), into a single subview of x.
struct SubViewOfSubViewFolder : OpRewritePattern<memref::SubViewOp> {
  using OpRewritePattern<memref::SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::SubViewOp subView,
                                PatternRewriter &rewriter) const override {
    // Only applies when the source of this subview is produced by a subview.
    auto producer = subView.getSource().getDefiningOp<memref::SubViewOp>();
    if (!producer)
      return failure();

    // Combine the offsets, sizes and strides of the two subviews. The
    // dimensions dropped by the producer are passed along to the helper.
    SmallVector<OpFoldResult> mergedOffsets;
    SmallVector<OpFoldResult> mergedSizes;
    SmallVector<OpFoldResult> mergedStrides;
    if (failed(affine::mergeOffsetsSizesAndStrides(
            rewriter, subView.getLoc(), producer, subView,
            producer.getDroppedDims(), mergedOffsets, mergedSizes,
            mergedStrides)))
      return failure();

    // The replacement keeps the result type of the outer subview and reads
    // from the source of the inner one.
    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
        subView, subView.getType(), producer.getSource(), mergedOffsets,
        mergedSizes, mergedStrides);
    return success();
  }
};
/// Folds memref.subview producers into nvgpu.device_async_copy. The pattern
/// folds subviews found on the source and/or the destination memref of the
/// copy.
struct NVGPUAsyncCopyOpSubViewOpFolder final
    : OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
                                PatternRewriter &rewriter) const override;
};
/// Merges subview operations with load/store-like operations unless such a
/// merger would cause the strides between dimensions accessed by that
/// operation to change.
class AccessOpOfSubViewOpFolder final
    : public OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};
/// Merges a memref.expand_shape operation with an operation that accesses a
/// memref by index, unless that operation accesses more than one dimension of
/// memory and any dimension other than the outermost dimension accessed this
/// way would be merged. This prevents issues from arising with, say, a
/// vector.load of a 4x2 vector having the two trailing dimensions of the
/// access get merged.
class AccessOpOfExpandShapeOpFolder final
    : public OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};
/// Merges an operation that accesses a memref by index with a
/// memref.collapse_shape, unless doing so would break apart a dimension other
/// than the outermost one that the operation accesses. For example, a load of
/// a 3x8 vector from a 6x8 memref is not turned into a load from a 3x4x2
/// memref (this would require special handling and could lead to invalid IR if
/// that higher-dimensional memref comes from a subview), whereas a load of a
/// length-8 vector from a 3x8 memref may be turned into a load from a 3x2x8
/// one.
class AccessOpOfCollapseShapeOpFolder final
    : public OpInterfaceRewritePattern<memref::IndexedAccessOpInterface> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedAccessOpInterface op,
                                PatternRewriter &rewriter) const override;
};
/// Merges memref.subview operations present on the source or destination
/// operands of indexed memory copy operations (DMA operations) into those
/// operations. This is performed unconditionally, since folding in a subview
/// cannot change the starting position of the copy, which is what the
/// memref/index pair represents in DMA operations.
class IndexedMemCopyOpOfSubViewOpFolder final
    : public OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};
/// Merges memref.expand_shape operations present on the source or destination
/// of an indexed memory copy/DMA into the memref/index arguments of that DMA.
/// As with subviews, this is done unconditionally.
class IndexedMemCopyOpOfExpandShapeOpFolder final
    : public OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};
/// Merges memref.collapse_shape operations present on the source or
/// destination of an indexed memory copy/DMA into the memref/index arguments
/// of that DMA. As with subviews, this is done unconditionally.
class IndexedMemCopyOpOfCollapseShapeOpFolder final
    : public OpInterfaceRewritePattern<memref::IndexedMemCopyOpInterface> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(memref::IndexedMemCopyOpInterface op,
                                PatternRewriter &rewriter) const override;
};
} // namespace
/// Shared precondition check for folding `subviewOp` into the vector transfer
/// op `xferOp`. `XferOp` must be vector::TransferReadOp or
/// vector::TransferWriteOp.
///
/// Returns a match failure (reported through `rewriter`) when the transfer has
/// an out-of-bounds dimension or when the subview does not have unit strides;
/// returns success otherwise.
template <typename XferOp>
static LogicalResult
preconditionsFoldSubViewOpImpl(RewriterBase &rewriter, XferOp xferOp,
                               memref::SubViewOp subviewOp) {
  // `is_one_of` tests its first type argument against the remaining ones, so
  // `XferOp` has to be listed first for this assertion to constrain anything.
  static_assert(llvm::is_one_of<XferOp, vector::TransferReadOp,
                                vector::TransferWriteOp>::value,
                "must be a vector transfer op");
  if (xferOp.hasOutOfBoundsDim())
    return rewriter.notifyMatchFailure(xferOp, "out of bounds transfer dim");
  if (!subviewOp.hasUnitStride()) {
    return rewriter.notifyMatchFailure(
        xferOp, "non-1 stride subview, need to track strides in folded memref");
  }
  return success();
}
/// Fallback overload: ops that are not vector transfers have no additional
/// preconditions for folding a subview, so this always succeeds.
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                Operation *op,
                                                memref::SubViewOp subviewOp) {
  return success();
}

/// vector.transfer_read overload: defers to the shared transfer-op check.
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferReadOp readOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl<vector::TransferReadOp>(
      rewriter, readOp, subviewOp);
}

/// vector.transfer_write overload: defers to the shared transfer-op check.
static LogicalResult preconditionsFoldSubViewOp(RewriterBase &rewriter,
                                                vector::TransferWriteOp writeOp,
                                                memref::SubViewOp subviewOp) {
  return preconditionsFoldSubViewOpImpl<vector::TransferWriteOp>(
      rewriter, writeOp, subviewOp);
}
/// Rewrites `loadOp`, whose memref operand is produced by a memref.subview,
/// into an equivalent op that reads from the source of the subview.
///
/// Fails to match when the memref operand is not defined by a subview or when
/// the op-specific preconditions (see `preconditionsFoldSubViewOp`) do not
/// hold. On success the original op is replaced.
template <typename OpTy>
LogicalResult LoadOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return rewriter.notifyMatchFailure(loadOp, "not a subview producer");

  // All checks that can fail run before any new IR is created.
  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, loadOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  // Express the indices of the load in terms of the source of the subview.
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, loadOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      loadOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, subViewOp.getSource(), sourceIndices, op.getNontemporal());
      })
      .Case([&](vector::LoadOp op) {
        // Forward the nontemporal flag, as the expand_shape/collapse_shape
        // folders do, so that folding does not silently discard it.
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getMask(), op.getPassThru());
      })
      .Case([&](vector::TransferReadOp op) {
        // The permutation map is widened to the rank of the subview source,
        // taking the dimensions dropped by the subview into account.
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getPadding(), op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](nvgpu::LdMatrixOp op) {
        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
            op, op.getType(), subViewOp.getSource(), sourceIndices,
            op.getTranspose(), op.getNumTiles());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
/// Rewrites `loadOp`, whose memref operand is produced by a
/// memref.expand_shape, into an equivalent op that reads from the source of
/// the expand_shape.
///
/// Fails to match when the memref operand is not defined by an expand_shape.
/// For vector.transfer_read it additionally fails when the rank of the source
/// is smaller than the vector rank, or when a permutation-map result is not
/// the last dimension of some reassociation group.
template <typename OpTy>
LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  auto expandShapeOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::ExpandShapeOp>();
  if (!expandShapeOp)
    return failure();

  // For vector::TransferReadOp, validate preconditions before creating any IR.
  // resolveSourceIndicesExpandShape creates new ops, so all checks that can
  // fail must happen before that call to avoid "pattern returned failure but
  // IR did change" errors (caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS).
  SmallVector<AffineExpr> transferReadNewResults;
  if (auto transferOp =
          dyn_cast<vector::TransferReadOp>(loadOp.getOperation())) {
    const int64_t vectorRank = transferOp.getVectorType().getRank();
    const int64_t sourceRank =
        cast<MemRefType>(expandShapeOp.getViewSource().getType()).getRank();
    if (sourceRank < vectorRank)
      return failure();

    // The reassociation does not depend on the permutation-map result being
    // examined, so materialize it once instead of once per result.
    const SmallVector<ReassociationIndices, 4> reassociations =
        expandShapeOp.getReassociationIndices();

    // We can only fold if the permutation map uses only the least significant
    // dimension from each expanded reassociation group.
    for (AffineExpr result : transferOp.getPermutationMap().getResults()) {
      bool foundExpr = false;
      for (const auto &indexedGroup : llvm::enumerate(reassociations)) {
        const ReassociationIndices &group = indexedGroup.value();
        AffineExpr lastDimOfGroup =
            getAffineDimExpr(group.back(), rewriter.getContext());
        if (lastDimOfGroup != result)
          continue;
        // In the folded map the result refers to the source dimension, which
        // is the position of the reassociation group.
        transferReadNewResults.push_back(
            getAffineDimExpr(indexedGroup.index(), rewriter.getContext()));
        foundExpr = true;
        break;
      }
      if (!foundExpr)
        return failure();
    }
  }

  SmallVector<Value> sourceIndices;
  // memref.load guarantees that indexes start inbounds while the vector
  // operations don't. This impacts if our linearization is `disjoint`
  resolveSourceIndicesExpandShape(loadOp.getLoc(), rewriter, expandShapeOp,
                                  loadOp.getIndices(), sourceIndices,
                                  isa<memref::LoadOp>(loadOp.getOperation()));

  return llvm::TypeSwitch<Operation *, LogicalResult>(loadOp)
      .Case([&](memref::LoadOp op) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            loadOp, expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::LoadOp op) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getNontemporal());
        return success();
      })
      .Case([&](vector::MaskedLoadOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            op, op.getType(), expandShapeOp.getViewSource(), sourceIndices,
            op.getMask(), op.getPassThru());
        return success();
      })
      .Case([&](vector::TransferReadOp op) {
        // Build the permutation map over the source memref from the results
        // computed during the precondition check above.
        const int64_t sourceRank = sourceIndices.size();
        auto newMap = AffineMap::get(sourceRank, 0, transferReadNewResults,
                                     op.getContext());
        rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
            op, op.getVectorType(), expandShapeOp.getViewSource(),
            sourceIndices, newMap, op.getPadding(), op.getMask(),
            op.getInBounds());
        return success();
      })
      .DefaultUnreachable("unexpected operation");
}
/// Rewrites `loadOp`, whose memref operand is produced by a
/// memref.collapse_shape, into an equivalent op that reads from the source of
/// the collapse_shape. Fails to match when the memref operand is not defined
/// by a collapse_shape.
template <typename OpTy>
LogicalResult LoadOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy loadOp, PatternRewriter &rewriter) const {
  Value accessedMemref = getMemRefOperand(loadOp);
  auto collapseShapeOp =
      accessedMemref.template getDefiningOp<memref::CollapseShapeOp>();
  if (!collapseShapeOp)
    return failure();

  // Express the indices of the load in terms of the un-collapsed source.
  SmallVector<Value> expandedIndices;
  resolveSourceIndicesCollapseShape(loadOp.getLoc(), rewriter, collapseShapeOp,
                                    loadOp.getIndices(), expandedIndices);

  // Re-create the load on the source memref, dispatching on the concrete op.
  llvm::TypeSwitch<Operation *, void>(loadOp)
      .Case([&](memref::LoadOp memrefLoad) {
        rewriter.replaceOpWithNewOp<memref::LoadOp>(
            memrefLoad, collapseShapeOp.getViewSource(), expandedIndices,
            memrefLoad.getNontemporal());
      })
      .Case([&](vector::LoadOp vectorLoad) {
        rewriter.replaceOpWithNewOp<vector::LoadOp>(
            vectorLoad, vectorLoad.getType(), collapseShapeOp.getViewSource(),
            expandedIndices, vectorLoad.getNontemporal());
      })
      .Case([&](vector::MaskedLoadOp maskedLoad) {
        rewriter.replaceOpWithNewOp<vector::MaskedLoadOp>(
            maskedLoad, maskedLoad.getType(), collapseShapeOp.getViewSource(),
            expandedIndices, maskedLoad.getMask(), maskedLoad.getPassThru());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
/// Rewrites `storeOp`, whose memref operand is produced by a memref.subview,
/// into an equivalent op that writes to the source of the subview.
///
/// Fails to match when the memref operand is not defined by a subview or when
/// the op-specific preconditions (see `preconditionsFoldSubViewOp`) do not
/// hold. On success the original op is replaced.
template <typename OpTy>
LogicalResult StoreOpOfSubViewOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return rewriter.notifyMatchFailure(storeOp, "not a subview producer");

  // All checks that can fail run before any new IR is created.
  LogicalResult preconditionResult =
      preconditionsFoldSubViewOp(rewriter, storeOp, subViewOp);
  if (failed(preconditionResult))
    return preconditionResult;

  // Express the indices of the store in terms of the source of the subview.
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, storeOp.getLoc(), subViewOp.getMixedOffsets(),
      subViewOp.getMixedStrides(), subViewOp.getDroppedDims(),
      storeOp.getIndices(), sourceIndices);

  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](memref::StoreOp op) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::TransferWriteOp op) {
        // The permutation map is widened to the rank of the subview source,
        // taking the dimensions dropped by the subview into account.
        rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
            op, op.getValue(), subViewOp.getSource(), sourceIndices,
            AffineMapAttr::get(expandDimsToRank(
                op.getPermutationMap(), subViewOp.getSourceType().getRank(),
                subViewOp.getDroppedDims())),
            op.getMask(), op.getInBoundsAttr());
      })
      .Case([&](vector::StoreOp op) {
        // Forward the nontemporal flag, as the expand_shape/collapse_shape
        // folders do, so that folding does not silently discard it.
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            op, op.getValueToStore(), subViewOp.getSource(), sourceIndices,
            op.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp op) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            op, subViewOp.getSource(), sourceIndices, op.getMask(),
            op.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
/// Rewrites `storeOp`, whose memref operand is produced by a
/// memref.expand_shape, into an equivalent op that writes to the source of
/// the expand_shape. Fails to match when the memref operand is not defined by
/// an expand_shape.
template <typename OpTy>
LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  Value accessedMemref = getMemRefOperand(storeOp);
  auto expandShapeOp =
      accessedMemref.template getDefiningOp<memref::ExpandShapeOp>();
  if (!expandShapeOp)
    return failure();

  // memref.store guarantees that indexes start inbounds while the vector
  // operations don't. This impacts if our linearization is `disjoint`
  const bool startsInbounds = isa<memref::StoreOp>(storeOp.getOperation());
  SmallVector<Value> collapsedIndices;
  resolveSourceIndicesExpandShape(storeOp.getLoc(), rewriter, expandShapeOp,
                                  storeOp.getIndices(), collapsedIndices,
                                  startsInbounds);

  // Re-create the store on the source memref, dispatching on the concrete op.
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](memref::StoreOp memrefStore) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            memrefStore, memrefStore.getValueToStore(),
            expandShapeOp.getViewSource(), collapsedIndices,
            memrefStore.getNontemporal());
      })
      .Case([&](vector::StoreOp vectorStore) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            vectorStore, vectorStore.getValueToStore(),
            expandShapeOp.getViewSource(), collapsedIndices,
            vectorStore.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp maskedStore) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            maskedStore, expandShapeOp.getViewSource(), collapsedIndices,
            maskedStore.getMask(), maskedStore.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
/// Rewrites `storeOp`, whose memref operand is produced by a
/// memref.collapse_shape, into an equivalent op that writes to the source of
/// the collapse_shape. Fails to match when the memref operand is not defined
/// by a collapse_shape.
template <typename OpTy>
LogicalResult StoreOpOfCollapseShapeOpFolder<OpTy>::matchAndRewrite(
    OpTy storeOp, PatternRewriter &rewriter) const {
  Value accessedMemref = getMemRefOperand(storeOp);
  auto collapseShapeOp =
      accessedMemref.template getDefiningOp<memref::CollapseShapeOp>();
  if (!collapseShapeOp)
    return failure();

  // Express the indices of the store in terms of the un-collapsed source.
  SmallVector<Value> expandedIndices;
  resolveSourceIndicesCollapseShape(storeOp.getLoc(), rewriter, collapseShapeOp,
                                    storeOp.getIndices(), expandedIndices);

  // Re-create the store on the source memref, dispatching on the concrete op.
  llvm::TypeSwitch<Operation *, void>(storeOp)
      .Case([&](memref::StoreOp memrefStore) {
        rewriter.replaceOpWithNewOp<memref::StoreOp>(
            memrefStore, memrefStore.getValueToStore(),
            collapseShapeOp.getViewSource(), expandedIndices,
            memrefStore.getNontemporal());
      })
      .Case([&](vector::StoreOp vectorStore) {
        rewriter.replaceOpWithNewOp<vector::StoreOp>(
            vectorStore, vectorStore.getValueToStore(),
            collapseShapeOp.getViewSource(), expandedIndices,
            vectorStore.getNontemporal());
      })
      .Case([&](vector::MaskedStoreOp maskedStore) {
        rewriter.replaceOpWithNewOp<vector::MaskedStoreOp>(
            maskedStore, collapseShapeOp.getViewSource(), expandedIndices,
            maskedStore.getMask(), maskedStore.getValueToStore());
      })
      .DefaultUnreachable("unexpected operation");
  return success();
}
/// Folds the memref.subview that produces the memref accessed by `op` into
/// `op` itself.
///
/// Fails to match when the accessed memref is not produced by a subview, when
/// any of the trailing strides covering the accessed dimensions is not a
/// static 1, or when a dimension dropped by the subview falls among the
/// trailing source dimensions after the outermost accessed one.
LogicalResult
AccessOpOfSubViewOpFolder::matchAndRewrite(memref::IndexedAccessOpInterface op,
                                           PatternRewriter &rewriter) const {
  auto subview = op.getAccessedMemref().getDefiningOp<memref::SubViewOp>();
  if (!subview)
    return rewriter.notifyMatchFailure(op, "not accessing a subview");

  // Note the subtle difference between an accessed shape of {1} and an
  // accessed shape of {} here. The former prevents us from folding in a
  // subview that doesn't have a unit stride on the final dimension, while the
  // latter does not (since it indicates a scalar access).
  SmallVector<int64_t> accessedShape = op.getAccessedShape();
  const int64_t numAccessedDims = accessedShape.size();
  if (!hasTrailingUnitStrides(subview, numAccessedDims))
    return rewriter.notifyMatchFailure(
        op, "non-unit stride on accessed dimensions");

  llvm::SmallBitVector droppedDims = subview.getDroppedDims();
  const int64_t sourceRank = subview.getSourceType().getRank();

  // Ignore the outermost accessed dimension - we only care about dropped
  // dimensions between the accessing op's dimensions, as those could break the
  // accessing op's semantics. The loop body never runs when at most one
  // dimension is accessed.
  for (int64_t dim = sourceRank - (numAccessedDims - 1); dim < sourceRank;
       ++dim) {
    if (!droppedDims.test(dim))
      continue;
    return rewriter.notifyMatchFailure(
        op, "reintroducing dropped dimension " + Twine(dim) +
                " would break access op semantics");
  }

  // Express the indices of the access in terms of the source of the subview.
  SmallVector<Value> sourceIndices;
  affine::resolveIndicesIntoOpWithOffsetsAndStrides(
      rewriter, op.getLoc(), subview.getMixedOffsets(),
      subview.getMixedStrides(), droppedDims, op.getIndices(), sourceIndices);

  // If the interface hands back replacement values, swap them in for `op`.
  // NOTE(review): an empty optional presumably means `op` was updated in
  // place - confirm against the interface documentation.
  std::optional<SmallVector<Value>> replacements =
      op.updateMemrefAndIndices(rewriter, subview.getSource(), sourceIndices);
  if (replacements.has_value())
    rewriter.replaceOp(op, *replacements);
  return success();
}
/// Folds the memref.expand_shape that produces the memref accessed by `op`
/// into `op` itself.
///
/// Fails to match when the accessed memref is not produced by an expand_shape
/// or when any of the trailing reassociation groups covering the accessed
/// dimensions (other than the outermost accessed one) is non-trivial.
LogicalResult AccessOpOfExpandShapeOpFolder::matchAndRewrite(
    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
  auto expand = op.getAccessedMemref().getDefiningOp<memref::ExpandShapeOp>();
  if (!expand)
    return rewriter.notifyMatchFailure(op, "not accessing an expand_shape");

  SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
  ArrayRef<int64_t> accessedShape = rawAccessedShape;
  // Cut off the leading dimension, since we don't care about modifying its
  // strides.
  if (!accessedShape.empty())
    accessedShape = accessedShape.drop_front();

  SmallVector<ReassociationIndices, 4> reassocs =
      expand.getReassociationIndices();
  if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
    return rewriter.notifyMatchFailure(
        op,
        "expand_shape folding would merge semantically important dimensions");

  // Express the indices of the access in terms of the expand_shape source.
  SmallVector<Value> sourceIndices;
  memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, expand,
                                          op.getIndices(), sourceIndices,
                                          op.hasInboundsIndices());

  // If the interface hands back replacement values, swap them in for `op`.
  // NOTE(review): an empty optional presumably means `op` was updated in
  // place - confirm against the interface documentation.
  std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
      rewriter, expand.getViewSource(), sourceIndices);
  if (newValues)
    rewriter.replaceOp(op, *newValues);
  return success();
}
/// Folds the memref.collapse_shape that produces the memref accessed by `op`
/// into `op` itself.
///
/// Fails to match when the accessed memref is not produced by a
/// collapse_shape or when any of the trailing reassociation groups covering
/// the accessed dimensions (other than the outermost accessed one) is
/// non-trivial, since folding would then split such a dimension.
LogicalResult AccessOpOfCollapseShapeOpFolder::matchAndRewrite(
    memref::IndexedAccessOpInterface op, PatternRewriter &rewriter) const {
  auto collapse =
      op.getAccessedMemref().getDefiningOp<memref::CollapseShapeOp>();
  if (!collapse)
    return rewriter.notifyMatchFailure(op, "not accessing a collapse_shape");

  SmallVector<int64_t> rawAccessedShape = op.getAccessedShape();
  ArrayRef<int64_t> accessedShape = rawAccessedShape;
  // Cut off the leading dimension, since we don't care about its strides being
  // modified and we know that the dimensions within its reassociation group, if
  // it's non-trivial, must be contiguous.
  if (!accessedShape.empty())
    accessedShape = accessedShape.drop_front();

  SmallVector<ReassociationIndices, 4> reassocs =
      collapse.getReassociationIndices();
  if (!hasTrivialReassociationSuffix(reassocs, accessedShape.size()))
    return rewriter.notifyMatchFailure(op,
                                       "collapse_shape folding would split "
                                       "semantically important dimensions");

  // Express the indices of the access in terms of the collapse_shape source.
  SmallVector<Value> sourceIndices;
  memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter, collapse,
                                            op.getIndices(), sourceIndices);

  // If the interface hands back replacement values, swap them in for `op`.
  // NOTE(review): an empty optional presumably means `op` was updated in
  // place - confirm against the interface documentation.
  std::optional<SmallVector<Value>> newValues = op.updateMemrefAndIndices(
      rewriter, collapse.getViewSource(), sourceIndices);
  if (newValues)
    rewriter.replaceOp(op, *newValues);
  return success();
}
/// Folds memref.subview producers of the source and/or destination of the
/// indexed copy `op` into the copy. Fails to match when neither operand is
/// produced by a subview; an operand that is not is passed through unchanged.
LogicalResult IndexedMemCopyOpOfSubViewOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcSubview = op.getSrc().getDefiningOp<memref::SubViewOp>();
  auto dstSubview = op.getDst().getDefiningOp<memref::SubViewOp>();
  if (!(srcSubview || dstSubview))
    return rewriter.notifyMatchFailure(
        op, "no subviews found on indexed copy inputs");

  // Start from the current operands; each side is overwritten below only when
  // it is produced by a subview.
  Value foldedSrc = op.getSrc();
  SmallVector<Value> foldedSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value foldedDst = op.getDst();
  SmallVector<Value> foldedDstIndices = llvm::to_vector(op.getDstIndices());

  if (srcSubview) {
    foldedSrc = srcSubview.getSource();
    foldedSrcIndices.clear();
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, op.getLoc(), srcSubview.getMixedOffsets(),
        srcSubview.getMixedStrides(), srcSubview.getDroppedDims(),
        op.getSrcIndices(), foldedSrcIndices);
  }

  if (dstSubview) {
    foldedDst = dstSubview.getSource();
    foldedDstIndices.clear();
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, op.getLoc(), dstSubview.getMixedOffsets(),
        dstSubview.getMixedStrides(), dstSubview.getDroppedDims(),
        op.getDstIndices(), foldedDstIndices);
  }

  op.setMemrefsAndIndices(rewriter, foldedSrc, foldedSrcIndices, foldedDst,
                          foldedDstIndices);
  return success();
}
/// Folds memref.expand_shape producers of the source and/or destination of
/// the indexed copy `op` into the copy. Fails to match when neither operand
/// is produced by an expand_shape; an operand that is not is passed through
/// unchanged.
LogicalResult IndexedMemCopyOpOfExpandShapeOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcExpand = op.getSrc().getDefiningOp<memref::ExpandShapeOp>();
  auto dstExpand = op.getDst().getDefiningOp<memref::ExpandShapeOp>();
  if (!(srcExpand || dstExpand))
    return rewriter.notifyMatchFailure(
        op, "no expand_shapes found on indexed copy inputs");

  // Start from the current operands; each side is overwritten below only when
  // it is produced by an expand_shape.
  Value foldedSrc = op.getSrc();
  SmallVector<Value> foldedSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value foldedDst = op.getDst();
  SmallVector<Value> foldedDstIndices = llvm::to_vector(op.getDstIndices());

  if (srcExpand) {
    foldedSrc = srcExpand.getViewSource();
    foldedSrcIndices.clear();
    memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, srcExpand,
                                            op.getSrcIndices(),
                                            foldedSrcIndices,
                                            /*startsInbounds=*/true);
  }

  if (dstExpand) {
    foldedDst = dstExpand.getViewSource();
    foldedDstIndices.clear();
    memref::resolveSourceIndicesExpandShape(op.getLoc(), rewriter, dstExpand,
                                            op.getDstIndices(),
                                            foldedDstIndices,
                                            /*startsInbounds=*/true);
  }

  op.setMemrefsAndIndices(rewriter, foldedSrc, foldedSrcIndices, foldedDst,
                          foldedDstIndices);
  return success();
}
/// Folds memref.collapse_shape producers of the source and/or destination of
/// the indexed copy `op` into the copy. Fails to match when neither operand
/// is produced by a collapse_shape; an operand that is not is passed through
/// unchanged.
LogicalResult IndexedMemCopyOpOfCollapseShapeOpFolder::matchAndRewrite(
    memref::IndexedMemCopyOpInterface op, PatternRewriter &rewriter) const {
  auto srcCollapse = op.getSrc().getDefiningOp<memref::CollapseShapeOp>();
  auto dstCollapse = op.getDst().getDefiningOp<memref::CollapseShapeOp>();
  if (!(srcCollapse || dstCollapse))
    return rewriter.notifyMatchFailure(
        op, "no collapse_shapes found on indexed copy inputs");

  // Start from the current operands; each side is overwritten below only when
  // it is produced by a collapse_shape.
  Value foldedSrc = op.getSrc();
  SmallVector<Value> foldedSrcIndices = llvm::to_vector(op.getSrcIndices());
  Value foldedDst = op.getDst();
  SmallVector<Value> foldedDstIndices = llvm::to_vector(op.getDstIndices());

  if (srcCollapse) {
    foldedSrc = srcCollapse.getViewSource();
    foldedSrcIndices.clear();
    memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter,
                                              srcCollapse, op.getSrcIndices(),
                                              foldedSrcIndices);
  }

  if (dstCollapse) {
    foldedDst = dstCollapse.getViewSource();
    foldedDstIndices.clear();
    memref::resolveSourceIndicesCollapseShape(op.getLoc(), rewriter,
                                              dstCollapse, op.getDstIndices(),
                                              foldedDstIndices);
  }

  op.setMemrefsAndIndices(rewriter, foldedSrc, foldedSrcIndices, foldedDst,
                          foldedDstIndices);
  return success();
}
/// Folds memref.subview producers of the source and/or destination of
/// `copyOp` into a new nvgpu.device_async_copy that operates on the subview
/// sources. Fails to match when neither operand is produced by a subview.
LogicalResult NVGPUAsyncCopyOpSubViewOpFolder::matchAndRewrite(
    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");

  auto srcSubViewOp = copyOp.getSrc().getDefiningOp<memref::SubViewOp>();
  auto dstSubViewOp = copyOp.getDst().getDefiningOp<memref::SubViewOp>();
  if (!srcSubViewOp && !dstSubViewOp)
    return rewriter.notifyMatchFailure(
        copyOp, "does not use subview ops for source or destination");

  // Source side: begin with the existing indices and resolve them through the
  // subview when there is one.
  SmallVector<Value> foldedSrcIndices =
      llvm::to_vector(copyOp.getSrcIndices());
  if (srcSubViewOp) {
    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), srcSubViewOp.getMixedOffsets(),
        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
        copyOp.getSrcIndices(), foldedSrcIndices);
  }

  // Destination side: same treatment as the source.
  SmallVector<Value> foldedDstIndices =
      llvm::to_vector(copyOp.getDstIndices());
  if (dstSubViewOp) {
    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
    affine::resolveIndicesIntoOpWithOffsetsAndStrides(
        rewriter, copyOp.getLoc(), dstSubViewOp.getMixedOffsets(),
        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
        copyOp.getDstIndices(), foldedDstIndices);
  }

  // Pick the subview source for each folded side, and the original operand
  // for a side that has no subview.
  Value foldedDst = dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst();
  Value foldedSrc = srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc();

  // Replace the copy op with a new copy op that uses the folded operands and
  // carries over the element counts and the bypassL1 attribute.
  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()), foldedDst,
      foldedDstIndices, foldedSrc, foldedSrcIndices, copyOp.getDstElements(),
      copyOp.getSrcElements(), copyOp.getBypassL1Attr());
  return success();
}
/// Adds all alias-op folding patterns defined in this file to `patterns`. The
/// registration order is: interface-based patterns first, then the per-op
/// legacy patterns.
void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();

  // Interface-based patterns to which we will be migrating.
  patterns.add<AccessOpOfSubViewOpFolder, AccessOpOfExpandShapeOpFolder,
               AccessOpOfCollapseShapeOpFolder,
               IndexedMemCopyOpOfSubViewOpFolder,
               IndexedMemCopyOpOfExpandShapeOpFolder,
               IndexedMemCopyOpOfCollapseShapeOpFolder>(context);

  // The old way of doing things. Don't add more of these.
  patterns.add<LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
               LoadOpOfSubViewOpFolder<vector::LoadOp>,
               LoadOpOfSubViewOpFolder<vector::MaskedLoadOp>,
               LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
               StoreOpOfSubViewOpFolder<vector::TransferWriteOp>,
               StoreOpOfSubViewOpFolder<vector::StoreOp>,
               StoreOpOfSubViewOpFolder<vector::MaskedStoreOp>>(context);
  patterns.add<LoadOpOfExpandShapeOpFolder<vector::LoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::MaskedLoadOp>,
               LoadOpOfExpandShapeOpFolder<vector::TransferReadOp>,
               StoreOpOfExpandShapeOpFolder<vector::StoreOp>,
               StoreOpOfExpandShapeOpFolder<vector::MaskedStoreOp>>(context);
  patterns.add<LoadOpOfCollapseShapeOpFolder<vector::LoadOp>,
               LoadOpOfCollapseShapeOpFolder<vector::MaskedLoadOp>,
               StoreOpOfCollapseShapeOpFolder<vector::StoreOp>,
               StoreOpOfCollapseShapeOpFolder<vector::MaskedStoreOp>>(context);
  patterns.add<SubViewOfSubViewFolder, NVGPUAsyncCopyOpSubViewOpFolder>(
      context);
}
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
/// Pass that applies the alias-op folding patterns of this file to the
/// operation it runs on.
class FoldMemRefAliasOpsPass final
    : public memref::impl::FoldMemRefAliasOpsPassBase<FoldMemRefAliasOpsPass> {
public:
  void runOnOperation() override;
};
} // namespace
void FoldMemRefAliasOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateFoldMemRefAliasOpPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}