blob: ae9824f728da4dd2e0d9a91cdb4234c6d0f0caf3 [file] [log] [blame]
//===- ReshapeOpsUtils.h - Utilities used by reshape ops --*- C++ -*------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines utilities and common canonicalization patterns for
// reshape operations.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
#define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include <optional>
namespace mlir {
using ReassociationIndices = SmallVector<int64_t, 2>;
using ReassociationIndicesRef = ArrayRef<int64_t>;
using ReassociationExprs = SmallVector<AffineExpr, 2>;
/// Attribute name for the ArrayAttr which encodes reassociation indices.
/// Returns the literal attribute name "reassociation".
constexpr StringRef getReassociationAttrName() { return "reassociation"; }
/// Compose reassociation maps that are used in pair of reshape ops where one
/// is a producer and other is the consumer. Only valid to use this method when
/// both the producer and consumer are collapsing dimensions or both are
/// expanding dimensions.
///
/// For example,
/// producerReassociation = [[0, 1], [2], [3, 4]]
/// consumerReassociation = [[0, 1], [2]]
///
/// is folded into
///
/// result = [[0, 1, 2], [3, 4]].
std::optional<SmallVector<ReassociationIndices>> composeReassociationIndices(
ArrayRef<ReassociationIndices> producerReassociations,
ArrayRef<ReassociationIndices> consumerReassociations,
MLIRContext *context);
/// Convert reassociation indices to affine expressions.
SmallVector<SmallVector<AffineExpr, 2>, 2> convertReassociationIndicesToExprs(
MLIRContext *context, ArrayRef<ReassociationIndices> reassociationIndices);
/// Constructs affine maps out of Array<Array<AffineExpr>>.
SmallVector<AffineMap, 4>
getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation);
/// Wraps a list of reassociations in an ArrayAttr.
ArrayAttr
getReassociationIndicesAttribute(OpBuilder &b,
ArrayRef<ReassociationIndices> reassociation);
/// Convert Array<Array<AffineExpr>> to Array<Array<int64_t>>.
SmallVector<ReassociationIndices, 2> convertReassociationMapsToIndices(
OpBuilder &b, ArrayRef<ReassociationExprs> reassociationExprs);
/// Return the reassociations maps to use to reshape given the source type and
/// the target type when possible. Return std::nullopt when this computation
/// failed.
std::optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForReshape(ShapedType sourceType, ShapedType targetType);
/// Returns the reassociation maps to collapse `sourceShape` to `targetShape` if
/// possible.
std::optional<SmallVector<ReassociationIndices>>
getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> targetShape);
/// Return true if the reassociation specification is valid, false otherwise.
/// When false, the `invalidIndex` integer pointer is optionally filled with the
/// index of the offending reassociation map.
bool isReassociationValid(ArrayRef<AffineMap> reassociation,
int *invalidIndex = nullptr);
template <typename ReshapeOpTy, typename InverseReshapeOpTy>
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                  ArrayRef<Attribute> operands) {
  // Identity reshape: the source already has the result type.
  if (reshapeOp.getSrcType() == reshapeOp.getType())
    return reshapeOp.getSrc();

  // A reshape that exactly undoes its producer (an inverse reshape whose
  // operand type equals this op's result type) folds to that producer's
  // operand.
  auto inverseOp =
      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
  if (inverseOp && inverseOp.getSrcType() == reshapeOp.getResultType())
    return inverseOp.getSrc();

  // Reshaping a constant is itself a constant with the new shape.
  if (auto denseOperand = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
    return denseOperand.reshape(
        cast<ShapedType>(reshapeOp.getResult().getType()));

  return nullptr;
}
/// Common verifier for reshape-like types. Fills `expandedType` and
///`collapsedType` with the proper `src` or `result` type.
template <typename Op, typename T>
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType,
                                            T collapsedType, bool isExpansion) {
  unsigned expandedRank = expandedType.getRank();
  unsigned collapsedRank = collapsedType.getRank();
  // The expanded side must have at least as many dimensions.
  if (expandedRank < collapsedRank)
    return op.emitOpError("expected the expanded type, ")
           << expandedType << " to have a higher (or same) rank "
           << "than the collapsed type, " << collapsedType << '.';

  // Exactly one reassociation map per collapsed dimension.
  if (collapsedRank != op.getReassociation().size())
    return op.emitOpError("expected collapsed rank (")
           << collapsedRank << ") to equal the number of reassociation maps ("
           << op.getReassociation().size() << ").";

  // Every reassociation map must range over the full expanded dim space.
  auto maps = op.getReassociationMaps();
  for (unsigned i = 0, e = maps.size(); i < e; ++i)
    if (maps[i].getNumDims() != expandedRank)
      return op.emitOpError("expected reassociation map #")
             << i << " to have size equal to the expanded rank ("
             << expandedRank << "), but it is " << maps[i].getNumDims()
             << '.';

  // The maps together must be a valid, contiguous reassociation.
  int invalidIdx = 0;
  if (!isReassociationValid(maps, &invalidIdx))
    return op.emitOpError("expected reassociation map #")
           << invalidIdx << " to be valid and contiguous.";

  // Finally, check static-shape compatibility group by group.
  return reshapeLikeShapesAreCompatible(
      [&](const Twine &msg) { return op->emitOpError(msg); },
      collapsedType.getShape(), expandedType.getShape(),
      op.getReassociationIndices(), isExpansion);
}
/// Verifies the shapes of the reshaped types using the following rules:
/// 1) if a dimension in the collapsed type is static, then the corresponding
///    dimensions in the expanded shape should be
///    a) static
///    b) the product should be same as the collapsed shape.
/// 2) if a dimension in the collapsed type is dynamic, one and only one of the
///    corresponding dimensions in the expanded type should be dynamic. This
///    rule is only needed with reshape operations that are expanding.
LogicalResult reshapeLikeShapesAreCompatible(
function_ref<LogicalResult(const Twine &)> emitError,
ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
/// Returns true iff the type is a MemRefType and has a non-identity layout.
bool hasNonIdentityLayout(Type type);
/// Pattern to collapse producer/consumer reshape ops that are both collapsing
/// dimensions or are both expanding dimensions.
template <typename ReshapeOpTy>
struct ComposeReassociativeReshapeOps : public OpRewritePattern<ReshapeOpTy> {
  using OpRewritePattern<ReshapeOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp,
                                PatternRewriter &rewriter) const override {
    // The operand must be produced by a reshape op of the same kind.
    auto producer = reshapeOp.getSrc().template getDefiningOp<ReshapeOpTy>();
    if (!producer)
      return failure();

    // Bail out when any involved memref carries a non-identity layout.
    if (hasNonIdentityLayout(producer.getSrc().getType()) ||
        hasNonIdentityLayout(reshapeOp.getSrc().getType()) ||
        hasNonIdentityLayout(reshapeOp.getResult().getType()))
      return failure();

    // Compose the two reassociations; bail out when that is not possible.
    std::optional<SmallVector<ReassociationIndices>> composed =
        composeReassociationIndices(producer.getReassociationIndices(),
                                    reshapeOp.getReassociationIndices(),
                                    rewriter.getContext());
    if (!composed)
      return failure();

    ShapedType resultType = reshapeOp.getResultType();
    rewriter.replaceOpWithNewOp<ReshapeOpTy>(reshapeOp, resultType,
                                             producer.getSrc(), *composed);
    return success();
  }
};
/// Pattern to compose
/// `collapse_shape(expand_shape(%src, reassociation_1), reassociation_2)`.
/// In that case both `srcType` and `resultType` can be expressed as a function
/// of `intermediateType`.
/// In order to demonstrate the approach, let's assume that `rank(srcType) >
/// `rank(resultType)`, i.e. the resulting operation should be `collapse_shape`.
/// In that case, we can iterate over every set of indices in `reassociation_2`
/// and try to find ids of sets of indices in `reassociation_1` that cover it
/// completely.
///
/// Example:
///
/// %0 = tensor.expand_shape %arg [[0], [1], [2, 3]]
/// : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
/// %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
/// : tensor<?x?x?x1xi64> into tensor<?x?xi64>
///
/// can be canonicalized into
///
/// %0 = tensor.collapse_shape %arg [[0, 1], [2]]
/// : tensor<?x?x?xi64> into tensor<?x?xi64>
///
/// because [0] and [1] from `expand_shape` reassociation cover completely
/// `[0, 1]` from `collapse_shape`. If it is impossible to find such a union of
/// indices, then we fail.
///
/// When `rank(srcType) < rank(resultType)`, then we just swap `reassociation_1`
/// and `reassociation_2` and produce `expand_shape`.
template <typename CollapseOpTy, typename ExpandOpTy, typename CastOpTy>
struct ComposeCollapseOfExpandOp : public OpRewritePattern<CollapseOpTy> {
  using OpRewritePattern<CollapseOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(CollapseOpTy collapseOp,
                                PatternRewriter &rewriter) const override {
    // Only applies when the collapse's operand is produced by an expand.
    auto expandOp = collapseOp.getSrc().template getDefiningOp<ExpandOpTy>();
    if (!expandOp)
      return failure();
    ShapedType srcType = expandOp.getSrcType();
    ShapedType resultType = collapseOp.getResultType();
    // Bail out on memrefs with non-identity layouts (hasNonIdentityLayout is
    // true only for MemRefType with a non-identity layout).
    if (hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
        hasNonIdentityLayout(expandOp.getSrc().getType()) ||
        hasNonIdentityLayout(expandOp.getResult().getType()))
      return failure();
    int64_t srcRank = srcType.getRank();
    int64_t resultRank = resultType.getRank();
    // Type-preserving round trips are left to folds; nothing to rewrite here.
    if (srcType == resultType)
      return failure();
    // Orient the two reassociations: the side with the higher overall rank
    // partitions more dimensions per group than the lower-rank side.
    SmallVector<ReassociationIndices, 4> higherRankReassociation,
        lowerRankReassociation;
    if (srcRank > resultRank) {
      higherRankReassociation = expandOp.getReassociationIndices();
      lowerRankReassociation = collapseOp.getReassociationIndices();
    } else {
      higherRankReassociation = collapseOp.getReassociationIndices();
      lowerRankReassociation = expandOp.getReassociationIndices();
    }
    // For each lower-rank group, greedily absorb consecutive higher-rank
    // groups until their rightmost covered dimensions line up. A higher-rank
    // group that straddles a lower-rank group boundary (rightmostIndex
    // overshoots) means the two reassociations cannot be composed.
    size_t higherRankIndicesID = 0;
    SmallVector<ReassociationIndices, 4> composedReassociation;
    for (const auto &lowerRankIndices : lowerRankReassociation) {
      ReassociationIndices composedIndices;
      while (higherRankIndicesID < higherRankReassociation.size()) {
        auto rightmostIndex =
            higherRankReassociation[higherRankIndicesID].back();
        if (rightmostIndex > lowerRankIndices.back())
          return failure();
        composedIndices.push_back(higherRankIndicesID++);
        if (rightmostIndex == lowerRankIndices.back())
          break;
      }
      composedReassociation.push_back(composedIndices);
    }
    // Emit the op matching the net rank change of the composed reshape.
    if (srcRank > resultRank) {
      rewriter.replaceOpWithNewOp<CollapseOpTy>(
          collapseOp, resultType, expandOp.getSrc(), composedReassociation);
    } else if (srcRank < resultRank) {
      rewriter.replaceOpWithNewOp<ExpandOpTy>(
          collapseOp, resultType, expandOp.getSrc(), composedReassociation);
    } else {
      // Collapses/expansions that do not change the rank are not allowed. Use
      // a cast instead.
      assert(llvm::equal(srcType.getShape(), resultType.getShape()) &&
             "expected same shape");
      rewriter.replaceOpWithNewOp<CastOpTy>(collapseOp, resultType,
                                            expandOp.getSrc());
    }
    return success();
  }
};
/// Pattern to compose
/// `expand_shape(collapse_shape(%src, reassociation_1), reassociation_2)`
/// into a single collapse_shape (when the source rank is higher than the
/// result rank) or a single expand_shape (otherwise), when the two
/// reassociations can be composed (see `findCollapsingReassociation`).
template <typename ExpandOpTy, typename CollapseOpTy>
struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
  using OpRewritePattern<ExpandOpTy>::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandOpTy expandOp,
                                PatternRewriter &rewriter) const override {
    // Only applies when the expand's operand is produced by a collapse.
    auto collapseOp = expandOp.getSrc().template getDefiningOp<CollapseOpTy>();
    if (!collapseOp)
      return failure();
    ShapedType srcType = collapseOp.getSrcType();
    ShapedType resultType = expandOp.getResultType();

    // Bail out on memrefs with non-identity layouts.
    if (hasNonIdentityLayout(expandOp.getSrc().getType()) ||
        hasNonIdentityLayout(collapseOp.getSrc().getType()) ||
        hasNonIdentityLayout(collapseOp.getResult().getType()))
      return failure();

    int64_t srcRank = srcType.getRank();
    int64_t resultRank = resultType.getRank();
    // Type-preserving round trips are left to folds; nothing to rewrite here.
    if (srcType == resultType)
      return failure();

    auto srcReassociation = collapseOp.getReassociationIndices();
    auto resultReassociation = expandOp.getReassociationIndices();
    if (srcRank > resultRank) {
      // Net collapse: find how the original source groups map onto the
      // result dimensions.
      auto composedReassociation = findCollapsingReassociation(
          srcReassociation, resultReassociation, srcType.getShape(),
          resultType.getShape());
      if (!composedReassociation)
        return failure();
      rewriter.replaceOpWithNewOp<CollapseOpTy>(
          expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
      return success();
    }

    // Net expand: same computation with the src/result roles swapped.
    auto composedReassociation =
        findCollapsingReassociation(resultReassociation, srcReassociation,
                                    resultType.getShape(), srcType.getShape());
    if (!composedReassociation)
      return failure();
    rewriter.replaceOpWithNewOp<ExpandOpTy>(
        expandOp, resultType, collapseOp.getSrc(), *composedReassociation);
    return success();
  }

private:
  // Attempts to find a way to collapse `srcShape` to `resultShape` by
  // collapsing subshapes defined by the reassociation indices. Returns
  // std::nullopt when no such collapse exists.
  std::optional<SmallVector<ReassociationIndices>> findCollapsingReassociation(
      ArrayRef<ReassociationIndices> srcReassociation,
      ArrayRef<ReassociationIndices> resultReassociation,
      ArrayRef<int64_t> srcShape, ArrayRef<int64_t> resultShape) const {
    SmallVector<ReassociationIndices, 4> composedReassociation;
    if (srcReassociation.empty())
      return {getReassociationIndicesForCollapse(srcShape, resultShape)};
    for (auto item : llvm::zip(srcReassociation, resultReassociation)) {
      auto &srcIndices = std::get<0>(item);
      auto &resultIndices = std::get<1>(item);
      auto srcSubShape = srcShape.slice(srcIndices.front(), srcIndices.size());
      auto resultSubShape =
          resultShape.slice(resultIndices.front(), resultIndices.size());
      if (srcSubShape.size() == resultSubShape.size()) {
        if (srcSubShape != resultSubShape)
          return std::nullopt;
        // Equal sub-shapes compose trivially; keep the source group as-is.
        composedReassociation.push_back(srcIndices);
        // Bug fix: without this `continue`, control fell through and also
        // appended the identity sub-reassociation computed below, duplicating
        // the same source dimensions in `composedReassociation` and yielding
        // an invalid reassociation.
        continue;
      }
      // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
      auto subShapeReassociation =
          getReassociationIndicesForCollapse(srcSubShape, resultSubShape);
      if (!subShapeReassociation)
        return std::nullopt;
      // Remap the subshape indices back to the original srcShape.
      for (auto &subShapeIndices : *subShapeReassociation) {
        ReassociationIndices shapeIndices;
        for (int64_t index : subShapeIndices)
          shapeIndices.push_back(srcIndices.front() + index);
        composedReassociation.push_back(std::move(shapeIndices));
      }
    }
    return {std::move(composedReassociation)};
  }
};
/// The input parameters `offsets`, `sizes`, `strides` specify a rectangular
/// non rank-reducing slice of the collapse_shape output. Try to find which
/// dimensions have been sliced and which dimensions are not sliced (offset = 0,
/// size = dim, stride = 1). Note that this is conservative, as it cannot detect
/// whether a dynamic size corresponds to the full tensor dimension or not.
llvm::SmallBitVector getSlicedDimensions(ArrayRef<OpFoldResult> sliceInputShape,
ArrayRef<Range> sliceParams);
/// Determine which dimensions are linearized by a `tensor.collapse_shape` op by
/// inspecting its reassociation indices.
llvm::SmallBitVector
getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
/// Given the parameters for both operations in a `CollapseShape->ExtractSlice`
/// chain and reified source and result shapes of the CollapseShapeOp, this
/// class provides two functions that assist with directly forming the result
/// of the extract slice by "tiling the CollapseShapeOp by 1".
//// Example:
// clang-format off
/// ```
/// %0 = linalg.generic ... -> tensor<3x7x11x10xf32>
/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<341x10xf32>
/// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32>
/// ```
/// This class helps build the below IR to replace %2:
/// ```
/// %dest = tensor.empty() : tensor<10x10xf32>
/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> {
/// %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 11)>(%iv)
/// %3:3 = arith.delinearize_index %iv into (3, 7, 11)
///
/// // This function takes %3 (multiIndices) and the parameters for the slice below.
/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
///
/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
/// tensor<1x1x1x10xf32> into tensor<1x10xf32>
/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
/// tensor<1x10xf32> into tensor<10x10xf32>
/// scf.yield %6 : tensor<10x10xf32>
/// }
/// ```
// clang-format on
class SliceFromCollapseHelper {
public:
  /// Captures the reassociation and the reified input/output shapes of the
  /// collapse_shape op together with the extract_slice parameters, and
  /// precomputes which output dimensions are linearized and which are sliced
  /// (via getLinearizedDimensions/getSlicedDimensions).
  SliceFromCollapseHelper(ArrayRef<ReassociationIndices> reassociationIndices,
                          ArrayRef<OpFoldResult> collapseShapeInputShape,
                          ArrayRef<OpFoldResult> collapseShapeOutputShape,
                          ArrayRef<Range> extractSliceParams)
      : reassociationIndices(reassociationIndices),
        collapseShapeInputShape(collapseShapeInputShape),
        collapseShapeOutputShape(collapseShapeOutputShape),
        sliceParams(extractSliceParams),
        linearizedDimensions(getLinearizedDimensions(reassociationIndices)),
        slicedDimensions(getSlicedDimensions(collapseShapeOutputShape,
                                             extractSliceParams)) {}
  /// This function takes multi-indices and maps them to ExtractSlice parameters
  /// in the index space of the CollapseShape's source tensor. This function's
  /// signature can be described by `(D_0, D_1,.. D_{n-1}) -> (offsets, sizes,
  /// strides)` where `n` the number of "tiled dimensions", which are the
  /// dimensions of the output that are linearized by the collapse shape op and
  /// are also sliced. Each `D_i` is a tuple that must represent a valid
  /// multi-index for the `i-th` tiled dimension. In the example above, there is
  /// only one tiled dimension (D_0) and `arith.delinearize_index` produces the
  /// multi-index (%3) that would be passed to this function to generate the
  /// parameters for the `tensor.extract_slice` op (%4).
  SmallVector<Range> getExtractSliceParams(MLIRContext *ctx,
                                           ArrayRef<ValueRange> multiIndices);
  /// This function takes indices in the index space of the "tiled dimensions"
  /// described above and returns a set of Range variables that describe how the
  /// slice should be inserted into the destination. In the example above, `%iv`
  /// would be passed to this function to generate the parameters for the
  /// `tensor.insert_slice` op producing %6.
  SmallVector<Range> getInsertSliceParams(MLIRContext *ctx,
                                          ValueRange tileIndices);
private:
  /// Reassociation indices of the collapse_shape op.
  SmallVector<ReassociationIndices> reassociationIndices;
  /// Reified shape of the collapse_shape source tensor.
  SmallVector<OpFoldResult> collapseShapeInputShape;
  /// Reified shape of the collapse_shape result tensor.
  SmallVector<OpFoldResult> collapseShapeOutputShape;
  /// Offsets/sizes/strides of the extract_slice applied to the result.
  SmallVector<Range> sliceParams;
  /// Output dimensions that the collapse_shape linearizes (computed by
  /// getLinearizedDimensions).
  llvm::SmallBitVector linearizedDimensions;
  /// Output dimensions actually restricted by the slice (computed by
  /// getSlicedDimensions; conservative for dynamic sizes).
  llvm::SmallBitVector slicedDimensions;
};
/// Parameters required to simplify a collapsing reshape op with a rank-reducing
/// slice operation. See `getSimplifyCollapseShapeWithRankReducingSliceInfo`.
struct CollapseShapeRankReducingSliceSimplificationInfo {
  /// The shape of the output of the rank-reducing slice. When a new collapse
  /// shape op is still required, this is the type of its source.
  RankedTensorType sliceResultType;
  /// The reassociation indices for the new collapse shape op, if required. If
  /// `std::nullopt`, the slice should replace the collapse shape op entirely.
  std::optional<SmallVector<ReassociationIndices>> newReassociationIndices;
};
/// A collapsing reshape operation can sometimes be simplified or eliminated by
/// inserting a single rank-reducing slice operation between it and the source
/// tensor. The slice op will either take the place of the source, allowing for
/// a new, simpler reshape op to replace the original, or the reshape op will be
/// completely replaced by the slice result.
///
/// This function returns the parameters required to implement this pattern. If
/// the pattern is not applicable, then failure is returned.
///
/// ### Example:
/// ```
/// %result = tensor.collapse_shape %0 [[0, 1], [2, 3]]
/// : tensor<?x1x30x10xf32> to tensor<?x300xf32>
/// ```
/// can be transformed to
/// ```
/// %tmp = tensor.extract_slice %0 [0, 0, 0, 0]
///                                [%dim1, 1, 30, 10]
///                                [1, 1, 1, 1]
///          : tensor<?x1x30x10xf32> to tensor<?x30x10xf32>
/// %result = tensor.collapse_shape %tmp [[0], [1, 2]]
/// : tensor<?x30x10xf32> to tensor<?x300xf32>
/// ```
///
/// ### Example:
/// ```
/// %result = tensor.collapse_shape %1 [[0, 1], [2]]
/// : tensor<?x1x30xf32> to tensor<?x30xf32>
/// ```
/// can be transformed to
/// ```
/// %result = tensor.extract_slice %1 [0, 0, 0]
/// [%dim2, 1, 30]
/// [1, 1, 1]
/// : tensor<?x1x30xf32> to tensor<?x30xf32>
/// ```
FailureOr<CollapseShapeRankReducingSliceSimplificationInfo>
getSimplifyCollapseShapeWithRankReducingSliceInfo(
RankedTensorType sourceType,
ArrayRef<ReassociationIndices> reassociationIndices);
/// Result of `computePackingMetadata` (see below).
struct PackingMetadata {
  // Final positions in the packed (target) shape computed from the requested
  // insertion points.
  SmallVector<int64_t> insertPositions;
  // NOTE(review): presumably the positions of the remaining (outer)
  // dimensions in the packed shape — confirm against computePackingMetadata's
  // implementation.
  SmallVector<int64_t> outerPositions;
  // Reassociation groups for the reshape associated with the packing.
  SmallVector<ReassociationIndices> reassociations;
};
/// Given a vector of `positions` indices representing desired packing insertion
/// points into a target vector (i.e. pack/unpack.inner_dim_pos), compute the
/// final positions in the target shape as well as the reshape reassociations.
// Note: This should not be called with a large positions array (or the
// implementation needs to be updated to use an N.log N sort instead of
// repeated N^2 counts).
PackingMetadata computePackingMetadata(int64_t packedRank,
ArrayRef<int64_t> innerDimPos);
} // namespace mlir
#endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H