| //===- RankReductionPatterns.cpp - Patterns related to rank reductions ----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Arith/Utils/Utils.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/Support/LogicalResult.h" |
| |
| using namespace mlir; |
| using namespace mlir::tensor; |
| |
| namespace { |
| /// Fold expand_shape(extract_slice) ops that cancel itself out. |
| struct FoldExpandOfRankReducingExtract |
| : public OpRewritePattern<ExpandShapeOp> { |
| using OpRewritePattern<ExpandShapeOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp, |
| PatternRewriter &rewriter) const override { |
| RankedTensorType resultType = expandShapeOp.getResultType(); |
| auto extractSliceOp = |
| expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>(); |
| if (!extractSliceOp) |
| return failure(); |
| RankedTensorType srcType = extractSliceOp.getSourceType(); |
| |
| // Only cases where the ExpandShapeOp can be folded away entirely are |
| // supported. Moreover, only simple cases where the resulting ExtractSliceOp |
| // has no rank-reduction anymore are supported at the moment. |
| RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType( |
| srcType, extractSliceOp.getStaticOffsets(), |
| extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides()); |
| if (nonReducingExtractType != resultType) |
| return failure(); |
| |
| SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets(); |
| SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes(); |
| SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides(); |
| rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>( |
| expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes, |
| mixedStrides); |
| return success(); |
| } |
| }; |
| |
| /// Fold collapse_shape which only removes static dimensions of size `1` |
| /// into extract_slice. |
| struct FoldUnPaddingCollapseIntoExtract |
| : public OpRewritePattern<tensor::CollapseShapeOp> { |
| using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp, |
| PatternRewriter &rewriter) const override { |
| auto extractSliceOp = |
| collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>(); |
| // Collapse cannot be folded away with multiple users of the extract slice |
| // and it is not necessarily beneficial to only convert the collapse into |
| // another extract slice. |
| if (!extractSliceOp || !extractSliceOp->hasOneUse()) |
| return failure(); |
| |
| // Only fold away simple collapse where all removed dimensions have static |
| // size `1`. |
| SliceVerificationResult res = isRankReducedType( |
| collapseShapeOp.getSrcType(), collapseShapeOp.getResultType()); |
| if (res != SliceVerificationResult::Success) |
| return rewriter.notifyMatchFailure(collapseShapeOp, |
| "expected unpadding collapse"); |
| |
| Value unPaddedExtractSlice = tensor::ExtractSliceOp::create( |
| rewriter, extractSliceOp.getLoc(), collapseShapeOp.getResultType(), |
| extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(), |
| extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides()); |
| rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice); |
| return success(); |
| } |
| }; |
| |
| /// Fold insert_slice(collapse_shape) ops that cancel itself out. |
| template <typename OpTy> |
| struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> { |
| using OpRewritePattern<OpTy>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(OpTy insertSliceOp, |
| PatternRewriter &rewriter) const override { |
| auto collapseShapeOp = |
| insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>(); |
| if (!collapseShapeOp) |
| return failure(); |
| RankedTensorType srcType = collapseShapeOp.getSrcType(); |
| |
| // Only cases where the CollapseShapeOp can be folded away entirely are |
| // supported. Moreover, only simple cases where the resulting InsertSliceOp |
| // has no rank-reduction anymore are supported at the moment. |
| RankedTensorType nonReducingInsertType = |
| RankedTensorType::get(insertSliceOp.getStaticSizes(), |
| insertSliceOp.getDestType().getElementType()); |
| if (nonReducingInsertType != srcType) |
| return failure(); |
| |
| SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets(); |
| SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes(); |
| SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides(); |
| rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(), |
| insertSliceOp.getDest(), mixedOffsets, |
| mixedSizes, mixedStrides); |
| return success(); |
| } |
| }; |
| |
| /// Fold expand_shape which only adds static dimensions of size `1` |
| /// into insert_slice. |
| template <typename OpTy> |
| struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> { |
| using OpRewritePattern<OpTy>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(OpTy insertSliceOp, |
| PatternRewriter &rewriter) const override { |
| auto expandShapeOp = insertSliceOp.getSource() |
| .template getDefiningOp<tensor::ExpandShapeOp>(); |
| if (!expandShapeOp) |
| return failure(); |
| |
| // Only fold away simple expansion where all added dimensions have static |
| // size `1`. |
| SliceVerificationResult res = isRankReducedType( |
| expandShapeOp.getResultType(), expandShapeOp.getSrcType()); |
| if (res != SliceVerificationResult::Success) |
| return rewriter.notifyMatchFailure(insertSliceOp, |
| "expected rank increasing expansion"); |
| |
| rewriter.modifyOpInPlace(insertSliceOp, [&]() { |
| insertSliceOp.getSourceMutable().assign(expandShapeOp.getSrc()); |
| }); |
| return success(); |
| } |
| }; |
| |
| /// Pattern to bubble up a tensor.expand_shape op through a producer |
| /// tensor.collapse_shape op that has non intersecting reassociations. |
| struct BubbleUpExpandThroughParallelCollapse |
| : public OpRewritePattern<tensor::ExpandShapeOp> { |
| using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, |
| PatternRewriter &rewriter) const override { |
| auto collapseOp = |
| expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>(); |
| if (!collapseOp) |
| return failure(); |
| auto expandReInds = expandOp.getReassociationIndices(); |
| auto collapseReInds = collapseOp.getReassociationIndices(); |
| |
| // Special case where the collapsed tensor to expand is a 0-D tensor, |
| // then the reassociation maps will be empty and not produce valid results. |
| if (expandReInds.size() == 0) { |
| return failure(); |
| } |
| |
| // Reshapes are parallel to each other (by construction the number of |
| // reassociations specified in the collapse and expand are the same), if at |
| // any position |
| // 1. either the reassociation indices are of the same size, or |
| // 2. either the reassociation in the collapse or the expand is of size 1. |
| ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape(); |
| ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape(); |
| for (auto [expandReassociation, collapseReassociation] : |
| llvm::zip_equal(expandReInds, collapseReInds)) { |
| if (collapseReassociation.size() == expandReassociation.size()) { |
| // Even if the reassociations are the same, the collapse/expand should |
| // result in the same dimensions. i.e 4x8x2 into 64 should be expanded |
| // into 4x8x2 again. In presense of dynamic dimensions one can only |
| // verify "equality" when there is only one dynamic dimension present, |
| // and all other static dimensions are equal. |
| ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice( |
| collapseReassociation.front(), collapseReassociation.size()); |
| int64_t numCollapsedDynamic = |
| llvm::count_if(collapsedStaticShapes, ShapedType::isDynamic); |
| ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice( |
| expandReassociation.front(), expandReassociation.size()); |
| int64_t numExpandedDynamic = |
| llvm::count_if(expandedStaticShapes, ShapedType::isDynamic); |
| if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 || |
| collapsedStaticShapes != expandedStaticShapes) { |
| return failure(); |
| } |
| continue; |
| } |
| // If the reassociations are not same, one or the other needs to be of |
| // size one. |
| if (collapseReassociation.size() != 1 && expandReassociation.size() != 1) |
| return failure(); |
| } |
| |
| // Compute new reassociation indices and expanded/collaped shapes. |
| SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds; |
| Location loc = expandOp->getLoc(); |
| SmallVector<OpFoldResult> sourceSizes = |
| tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc()); |
| SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape(); |
| SmallVector<OpFoldResult> newExpandSizes; |
| |
| int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0, |
| resultSizeIndex = 0; |
| |
| for (size_t idx = 0, idxEnd = collapseReInds.size(); idx < idxEnd; idx++) { |
| auto &collapseReassociation = collapseReInds[idx]; |
| auto &expandReassociation = expandReInds[idx]; |
| |
| // Case 1. The reassociations are same in the collapse producer |
| // and expand consumer. In the swapped expand, each of the final |
| // dimensions are kept as is in the expand and the collapse. So, |
| // for every element in the `ReassocationIndices` vector add a new |
| // `ReassociationIndices` vector for the swapped expand and collapse |
| // (of size 1). |
| if (collapseReassociation.size() == expandReassociation.size()) { |
| for (size_t i = 0; i < collapseReassociation.size(); ++i) { |
| newCollapseReInds.push_back({newCollapseIndex++}); |
| newExpandReInds.push_back({newExpandIndex++}); |
| newExpandSizes.push_back(resultSizes[resultSizeIndex++]); |
| sourceSizeIndex++; |
| } |
| continue; |
| } |
| |
| // Case 2. The `ReassociationIndices` in the collapse is of size > 1 (and |
| // in the expand is of size == 1). In this case, the original dimensions |
| // are preserved on expansion and collapsed subsequently. |
| if (collapseReassociation.size() != 1) { |
| ReassociationIndices newCollapseReassociation; |
| for (size_t i = 0; i < collapseReassociation.size(); ++i) { |
| newCollapseReassociation.push_back(newCollapseIndex++); |
| newExpandReInds.push_back({newExpandIndex++}); |
| newExpandSizes.push_back(sourceSizes[sourceSizeIndex++]); |
| } |
| resultSizeIndex++; |
| newCollapseReInds.push_back(newCollapseReassociation); |
| continue; |
| } |
| |
| // Case 3. The `ReassociationIndices` in the expand is of size > 1 (and |
| // in the collapse is of size == 1). In this case, the expansion happens |
| // first and the expanded dimensions are preserved on collapse. |
| ReassociationIndices newExpandReassociation; |
| for (size_t i = 0; i < expandReassociation.size(); ++i) { |
| newExpandReassociation.push_back(newExpandIndex++); |
| newCollapseReInds.push_back({newCollapseIndex++}); |
| newExpandSizes.push_back(resultSizes[resultSizeIndex++]); |
| } |
| newExpandReInds.push_back(newExpandReassociation); |
| sourceSizeIndex++; |
| } |
| |
| // Swap reshape order. |
| SmallVector<Value> dynamicSizes; |
| SmallVector<int64_t> staticSizes; |
| dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes); |
| auto expandResultType = expandOp.getResultType().clone(staticSizes); |
| Value newCollapseSrc = collapseOp.getSrc(); |
| // If the number of reassociation indices in the new `expand_shape` op |
| // matches the number of dimensions of the result, then the expand_shape |
| // is a no-op. |
| if (newExpandReInds.size() != newExpandSizes.size()) { |
| newCollapseSrc = tensor::ExpandShapeOp::create( |
| rewriter, loc, expandResultType, newCollapseSrc, newExpandReInds, |
| newExpandSizes); |
| } |
| |
| // If the number of reassociation indices in the new `collapse_shape` op |
| // matches the number of dimensions of the source, then the collapse_shape |
| // is a no-op. |
| Value replacement = newCollapseSrc; |
| if (newCollapseReInds.size() != newExpandSizes.size()) { |
| replacement = tensor::CollapseShapeOp::create( |
| rewriter, loc, newCollapseSrc, newCollapseReInds); |
| } |
| rewriter.replaceOp(expandOp, replacement); |
| return success(); |
| } |
| }; |
| |
| /// Converts `tensor.extract_slice(tensor.expand_shape)` to |
| /// `tensor.expand_shape(tensor.extract_slice)`. |
| /// |
| /// For this transformation to be possible, the slice must be fully contiguous |
| /// within each reassociation group of the expand_shape. A slice is defined as |
| /// fully contiguous within a reassociation group if after flattening the |
| /// reassociation group to a single 1D range, then the slice taken out of the |
| /// group could be defined as a single contiguous subrange within that range. |
| /// |
| /// Rank reducing slices are not supported. |
| /// |
| /// Example: |
| /// The transformation is possible because each reassociation group has a |
| /// contiguous slice (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4]). |
| /// ``` |
| /// BEFORE: |
| /// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]] |
| /// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32> |
| /// %slice = tensor.extract_slice %reshape ... |
| /// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32> |
| /// |
| /// AFTER: |
| /// %slice = tensor.extract_slice %in ... |
| /// tensor<8x16x32xf32> to tensor<8x5x4xf32> |
| /// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]] |
| /// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32> |
| /// ``` |
| /// |
| /// Note - this pattern could be extended to be a swap pattern between |
| /// `tensor.expand_shape` and `tensor.extract_slice`, but is currently |
| /// implemented only as a bubble up pattern for `tensor.extract_slice`. |
| struct BubbleUpExtractSliceThroughExpandShape |
| : public OpRewritePattern<tensor::ExtractSliceOp> { |
| using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, |
| PatternRewriter &rewriter) const override { |
| auto expandShapeOp = |
| sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>(); |
| if (!expandShapeOp) { |
| return rewriter.notifyMatchFailure( |
| sliceOp, "tensor.extract_slice source not produced by expand_shape"); |
| } |
| SmallVector<ReassociationIndices> reassociation = |
| expandShapeOp.getReassociationIndices(); |
| |
| SmallVector<OpFoldResult> offsets, sizes, strides; |
| if (failed(getCollapsedExtractSliceInfo(rewriter, sliceOp, reassociation, |
| offsets, sizes, strides))) |
| return failure(); |
| |
| // The shape of the result can be obtained from the sizes passed in. |
| SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes(); |
| RankedTensorType resultType = sliceOp.getResultType(); |
| |
| // Create a new ExtractSliceOp and ExpandShapeOp. |
| Location loc = sliceOp.getLoc(); |
| Value newSliceOp = tensor::ExtractSliceOp::create( |
| rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides); |
| rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( |
| sliceOp, resultType, newSliceOp, |
| expandShapeOp.getReassociationIndices(), expandedSizes); |
| return success(); |
| } |
| }; |
| |
| /// Converts `tensor.extract_slice(tensor.collapse_shape)` to |
| /// `tensor.collapse_shape(tensor.extract_slice)`. |
| /// |
| /// For this transformation to be possible - after bubbling up, the extraction |
| /// of the contiguous slice must be representable as a single slice obtained via |
| /// tensor.extract_slice within each reassociation group of the src. |
| /// |
| /// In case the size and offset extracted are static then this is possible if |
| /// the following conditions are met within each reassociation group: |
| /// Let T be a tensor of shape [A0, A1, ..., An] (these are the sizes of the |
| /// dimensions in the reassociation group), and let S = [S0, S1, ..., Sn] be the |
| /// shape of a desired slice. A slice of shape S can be extracted as a |
| /// contiguous span of elements if and only if there exists an index k in {0, 1, |
| /// ..., n} such that: |
| /// S_i = 1 for all i < k (that is, all leading dimensions are singleton), |
| /// 1 <= S_k <= A_k (that is, non trivial slicing occurs along exactly |
| /// one dimension), |
| /// S_i = A_i for all i > k (that is, all trailing dimensions are preserved |
| /// in full). |
| /// In other words, the slice shape S must be of the form: |
| /// [ 1, 1, ..., 1, Sk, Ak + 1, Ak + 2, ...,An ] |
| /// |
| /// In case the size and/or offset extracted are dynamic then this is possible |
| /// only if there is single dimension in the reassociation group that has a size |
| /// not equal to 1. |
| /// In other words, the tensor shape must be of the form: |
| /// [ 1, 1, ..., 1, A, 1, ...,1 ] |
| /// Note - it might be possible to enable this pattern for more cases when the |
| /// size/offset are dynamic via performing an analysis of the possible values |
| /// that could be given to the size/offset. |
| /// |
| /// Example: |
| /// The transformation is possible because each reassociation group can be |
| /// represented as a contiguous slice (i.e., [8x16->2x16], [1x7->1x?], |
| /// [20->10]). |
| /// ``` |
| /// BEFORE: |
| /// %collapse = tensor.collapse_shape %src [[0, 1], [2, 3], [4]] ... |
| /// tensor<8x16x1x7x20f32> to tensor<128x7x20xf32> |
| /// %slice = tensor.extract_slice %slice [0, 0, 0][32, %size, 10][1, 1, 1] |
| /// tensor<128x7x20xf32> to tensor<32x?x10xf32> |
| /// |
| /// AFTER: |
| /// %slice = tensor.extract_slice %src [0, 0, 0, 0, 0][2, 16, 1, %size, 10] |
| // [1, 1, 1, 1, 1] : tensor<8x16x1x7x20f32> to tensor<2x16x1x?x10xf32> |
| /// %collapse = tensor.collapse_shape %slice [[0, 1], [2, 3], [4]] ... |
| /// tensor<2x16x1x?x10xf32> to tensor<32x?x10xf32> |
| /// ``` |
| /// |
| /// Negative example: |
| /// The transformation is not possible because we cannot use a single slice to |
| /// represent the reassociation group [2x3x10->???]. If we would want the |
| /// collapse to be after the extraction, we would need to extract multiple |
| /// slices and concat them together. |
| /// ``` |
| /// %collapse = tensor.collapse_shape %src [[0, 1, 2]] : tensor<2x3x10xf32> into |
| /// tensor<60xf32> %extract = tensor.extract_slice %collapse[0][15][1] : |
| /// tensor<60xf32> to tensor<15xf32> |
| /// ``` |
| /// If we would want the collapse to be after the extraction, a possible |
| /// alternate transformation could be to extract multiple slices and concat them |
| /// together: |
| /// ``` |
| /// %extract_1 = tensor.extract_slice %src[0, 0, 0][1, 1, 10] : |
| /// tensor<2x3x10xf32> to tensor <1x1x10xf32> |
| /// %extract_2 = tensor.extract_slice %src[0, 1, 0][1, 1, 5] : |
| /// tensor<2x3x10xf32> to tensor <1x1x5xf32> |
| /// %concat = tosa.concat %extract_1, %extract_2 {axis = 0 : i32} : |
| /// (<1x1x10xf32>, <1x1x5xf32>) -> <1x1x15xf32> |
| /// %collapse = tensor.collapse_shape %concat [[0, 1, 2]] : tensor<1x1x15xf32> |
| /// to tensor<15xf32> |
| /// ``` |
| /// But this is not the intended purpose of the transformation. |
| struct BubbleUpExtractSliceThroughCollapseShape |
| : public OpRewritePattern<tensor::ExtractSliceOp> { |
| using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern; |
| |
| LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, |
| PatternRewriter &rewriter) const override { |
| auto collapseShapeOp = |
| sliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>(); |
| if (!collapseShapeOp) { |
| return rewriter.notifyMatchFailure( |
| sliceOp, |
| "tensor.extract_slice source not produced by tensor.collapse_shape"); |
| } |
| |
| SmallVector<OpFoldResult> offsets, sizes, strides; |
| if (failed(getExpandedExtractSliceInfo( |
| rewriter, sliceOp, collapseShapeOp.getReassociationIndices(), |
| collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides))) |
| return failure(); |
| |
| Value newSliceOp = tensor::ExtractSliceOp::create( |
| rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets, |
| sizes, strides); |
| rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( |
| sliceOp, sliceOp.getResultType(), newSliceOp, |
| collapseShapeOp.getReassociationIndices()); |
| |
| return success(); |
| } |
| }; |
| |
| } // namespace |
| |
| LogicalResult mlir::tensor::getCollapsedExtractSliceInfo( |
| OpBuilder &b, tensor::ExtractSliceOp sliceOp, |
| ArrayRef<ReassociationIndices> reassociation, |
| SmallVectorImpl<OpFoldResult> &collapsedOffsets, |
| SmallVectorImpl<OpFoldResult> &collapsedSizes, |
| SmallVectorImpl<OpFoldResult> &collapsedStrides) { |
| if (!sliceOp.hasUnitStride()) { |
| return failure(); |
| } |
| |
| SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets(); |
| SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); |
| |
| if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) { |
| return failure(); |
| } |
| |
| auto isZeroOffsetAndFullSize = [&](OpFoldResult offset, |
| OpFoldResult sliceSize, int64_t inputDim) { |
| if (!isZeroInteger(offset)) |
| return false; |
| ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim); |
| FailureOr<bool> maybeEqual = |
| ValueBoundsConstraintSet::areEqual(sliceSize, inputSize); |
| return llvm::succeeded(maybeEqual) && maybeEqual.value(); |
| }; |
| |
| // Check that the slice is contiguous within each reassociation group. |
| // The slice is contiguous only if after the first dimension where a non |
| // unit slice is taken, the slice size on all subsequent dimensions of the |
| // group is equal to the entire size of the dimension. |
| // Examples of contiguous slices: |
| // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10] |
| // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10] |
| // Examples of non contiguous slices: |
| // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5] |
| // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5] |
| for (const ReassociationIndices &indices : reassociation) { |
| int64_t i = 0; |
| int64_t e = indices.size(); |
| // Find the first expanded dim after the first dim with non-unit extracted |
| // size. |
| for (; i < e; ++i) { |
| if (!isOneInteger(sizes[indices[i]])) { |
| // +1 to skip the first non-unit size dim. |
| i++; |
| break; |
| } |
| } |
| |
| // Verify that all subsequent dimensions extract the full size of the |
| // source tensor. |
| for (; i < e; ++i) { |
| int64_t expandedDim = indices[i]; |
| if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim], |
| expandedDim)) { |
| return failure(); |
| } |
| } |
| } |
| |
| // The tensor.extract_slice before applying the pattern works on the result |
| // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp) |
| // referring to the state before applying the pattern are named with the |
| // prefix "expanded", and ones referring to the state after applying the |
| // pattern are named with the prefix "collapsed". |
| Location loc = sliceOp.getLoc(); |
| SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets(); |
| SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes(); |
| SmallVector<OpFoldResult> expandedShape = |
| getMixedSizes(b, loc, sliceOp.getSource()); |
| |
| // Helper variables and function for accumulating the size values. |
| AffineExpr d0, d1, d2; |
| bindDims(b.getContext(), d0, d1, d2); |
| // Multiply two integers. |
| auto mul = [&](OpFoldResult v1, OpFoldResult v2) { |
| auto mulMap = AffineMap::get(2, 0, {d0 * d1}); |
| return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2}); |
| }; |
| |
| // Compute new offsets, sizes, and strides for tensor.extract_slice. |
| // The new tensor.extract_slice will work on a tensor that has has a rank of |
| // ReassociationIndices.size(). In the loop a single offset, size, and |
| // stride value is computed per reassociation group. |
| for (const ReassociationIndices &indices : reassociation) { |
| // collapsedSize will hold the size of the single dim that represents the |
| // reassociation group in the non expanded tensor. |
| OpFoldResult collapsedSize = b.getIndexAttr(1); |
| // The reassocGroupSizes and reassocGroupOffsets are used to create an |
| // affine.linearize_index op to linearize the single offset value required |
| // for this reassociation group. |
| SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets; |
| |
| for (long expandedDim : indices) { |
| // reassocGroupSizes and reassocGroupOffsets can be obtained directly |
| // from the expanded state, but the collapsed size requires calculation |
| // as it did not previously exist. |
| reassocGroupSizes.push_back(expandedShape[expandedDim]); |
| reassocGroupOffsets.push_back(expandedOffsets[expandedDim]); |
| collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]); |
| } |
| |
| SmallVector<Value> offsetVals = |
| llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) { |
| return getValueOrCreateConstantIndexOp(b, loc, ofr); |
| }); |
| OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create( |
| b, loc, offsetVals, reassocGroupSizes, |
| /*disjoint=*/true) |
| .getResult(); |
| collapsedOffsets.push_back(collapsedOffset); |
| collapsedSizes.push_back(collapsedSize); |
| |
| // Only unit stride is supported. |
| collapsedStrides.push_back(b.getIndexAttr(1)); |
| } |
| return success(); |
| } |
| |
| LogicalResult mlir::tensor::getExpandedExtractSliceInfo( |
| OpBuilder &b, tensor::ExtractSliceOp sliceOp, |
| ArrayRef<ReassociationIndices> reassociation, |
| ArrayRef<int64_t> expandedShape, |
| SmallVectorImpl<OpFoldResult> &expandedOffsets, |
| SmallVectorImpl<OpFoldResult> &expandedSizes, |
| SmallVectorImpl<OpFoldResult> &expandedStrides) { |
| if (!sliceOp.hasUnitStride()) { |
| return failure(); |
| } |
| |
| // The tensor.extract_slice before applying the pattern works on the result |
| // of the tensor.collapse_shape, so variables (i.e. inputs for |
| // ExtractSliceOp) referring to the state before applying the pattern are |
| // named with the prefix "collapsed", and ones referring to the state after |
| // applying the pattern are named with the prefix "expanded". |
| SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets(); |
| SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes(); |
| if (static_cast<size_t>(sliceOp.getResultType().getRank()) != |
| collapsedSizes.size()) { |
| return failure(); |
| } |
| |
| // Compute new offsets, sizes, and strides for tensor.extract_slice. |
| // The new tensor.extract_slice will work on a tensor that has has a rank |
| // equal to the rank of the src of the collapse_shape. In each iteration of |
| // the loop, the offsets and sizes will be computed per reassociation group. |
| expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1)); |
| for (auto [collapsedSize, collapsedOffset, reassocIndices] : |
| llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) { |
| // CASE #1 - size and/or offset are dynamic. |
| // In this case, the slice can be represented as a contiguous slice only |
| // if there is a single dimension in the reassociation group that has a |
| // size not equal to 1. |
| if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) { |
| int nonUnitSizeCount = 0; |
| for (int64_t expandedShapeIdx : reassocIndices) { |
| if (expandedShape[expandedShapeIdx] != 1) { |
| nonUnitSizeCount++; |
| expandedSizes.push_back(collapsedSize); |
| expandedOffsets.push_back(collapsedOffset); |
| continue; |
| } |
| |
| expandedSizes.push_back(b.getIndexAttr(1)); |
| expandedOffsets.push_back(b.getIndexAttr(0)); |
| } |
| |
| if (nonUnitSizeCount != 1) { |
| return failure(); |
| } |
| continue; |
| } |
| |
| // CASE #2 = size and offset are static. |
| // Verify that the slice can be represented as a contiguous slice of the |
| // src of the collapse_shape. |
| // Checking this is done on order of most internal dimensions first, |
| // so traversal is done in reverse order of the reassociation group. |
| // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2, |
| // ...,An] then we first find the size and offset for n...k+1 then for k |
| // and then for k-1...0. |
| |
| // currentCollapsedsize and currentCollapsedOffset are initialized with |
| // the original collapsed size and offset and divided by the expanded |
| // shape size in each dimension as we go along the reassociation group. |
| // In essence we are spreading the original collapsed size and offset over |
| // the various expanded slice dimensions. |
| // The variables are used both to check the validity of the slice and to |
| // compute the expanded sizes and offsets. |
| int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value(); |
| int64_t currentCollapsedOffset = |
| getConstantIntValue(collapsedOffset).value(); |
| SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets; |
| ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(), |
| reassocIndices.rend()); |
| int64_t idx = 0; |
| int64_t reassocGroupSize = reassocIndices.size(); |
| |
| // First handle the trailing dimensions where the slice size should be |
| // equal to the tensor shape and the offset should be 0 (n...k+1). |
| for (; idx < reassocGroupSize; ++idx) { |
| int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]]; |
| |
| if (currentCollapsedsize < expandedShapeSize) |
| break; |
| |
| // We need to make sure that the slice size can be set to the shape size |
| // and the offset to 0. |
| if ((currentCollapsedsize % expandedShapeSize) != 0 || |
| (currentCollapsedOffset % expandedShapeSize) != 0) { |
| return failure(); |
| } |
| |
| groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize)); |
| groupExpandedOffsets.push_back(b.getIndexAttr(0)); |
| |
| currentCollapsedsize /= expandedShapeSize; |
| currentCollapsedOffset /= expandedShapeSize; |
| } |
| |
| // Now handle the first dim where slicing occurs on (k). |
| if (idx < reassocGroupSize) { |
| int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]]; |
| int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; |
| // We need to make sure that the slice size in this dim + offset will |
| // not exceed the shape size. |
| if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) { |
| return failure(); |
| } |
| groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize)); |
| groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim)); |
| currentCollapsedOffset /= expandedShapeSize; |
| } |
| |
| // Now handle the leading dimensions where the slice size is equal to 1 |
| // (k-1...0). |
| // The size for these dimensions must be 1 because of how we constructed |
| // the slice size of the expanded shape. We spread the original collapsed |
| // size over the expanded shape sizes until we reached dimension k where |
| // the remaining size was smaller than the expanded shape size, and spread |
| // the remaining size on it. So, now we are left with only 1s. |
| for (idx++; idx < reassocGroupSize; ++idx) { |
| int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]]; |
| int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; |
| groupExpandedSizes.push_back(b.getIndexAttr(1)); |
| groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim)); |
| currentCollapsedOffset /= expandedShapeSize; |
| } |
| expandedSizes.append(groupExpandedSizes.rbegin(), |
| groupExpandedSizes.rend()); |
| expandedOffsets.append(groupExpandedOffsets.rbegin(), |
| groupExpandedOffsets.rend()); |
| } |
| return success(); |
| } |
| |
| void mlir::tensor::populateReassociativeReshapeFoldingPatterns( |
| RewritePatternSet &patterns) { |
| patterns |
| .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract, |
| FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>, |
| FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>, |
| FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>, |
| FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>( |
| patterns.getContext()); |
| } |
| |
| void mlir::tensor::populateBubbleUpExpandShapePatterns( |
| RewritePatternSet &patterns) { |
| patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext()); |
| } |
| |
| void mlir::tensor::populateBubbleUpExtractSliceOpPatterns( |
| RewritePatternSet &patterns) { |
| patterns.add<BubbleUpExtractSliceThroughExpandShape, |
| BubbleUpExtractSliceThroughCollapseShape>(patterns.getContext()); |
| } |