//===- ExtractSliceFromReshapeUtils.cpp - Slice reshape rewrites ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements rewrites that replace slices of reshape results with
// aggregated slices of the reshape source.
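//
// For example (shapes chosen purely for illustration):
//
//   %collapse = tensor.collapse_shape %src [[0, 1], [2]]
//       : tensor<4x8x16xf32> into tensor<32x16xf32>
//   %slice = tensor.extract_slice %collapse[1, 0] [10, 16] [1, 1]
//       : tensor<32x16xf32> to tensor<10x16xf32>
//
// In general the slice of the linearized dimension cannot be expressed as a
// single slice of %src, so the rewrite emits a loop nest that extracts
// contiguous sub-slices of %src and assembles them into the result.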
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::affine;
using namespace mlir::tensor;

/// A tuple that represents (dimension number, index value in that dimension).
using DimAndIndex = std::tuple<unsigned, Value>;

/// Transform `dimAndIndex` from the output index space of a
/// (non-rank-reducing) slice described by `sliceParams` into the input index
/// space.
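///
/// For illustration (values assumed, not taken from real IR): with slice
/// offset 4 and stride 2 along `dim`, output index `%i` maps to input index
/// `4 + %i * 2`, materialized below with the affine map `s0 + d0 * s1`.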
static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc,
                                       ArrayRef<Range> sliceParams,
                                       const DimAndIndex &dimAndIndex) {
  AffineExpr d0, s0, s1;
  bindDims(b.getContext(), d0);
  bindSymbols(b.getContext(), s0, s1);
  auto [dim, indexValue] = dimAndIndex;
  assert(dim < sliceParams.size() && "slice should be non-rank-reducing");
  return std::make_pair(
      dim, affine::makeComposedAffineApply(
               b, loc, s0 + d0 * s1,
               {indexValue, sliceParams[dim].offset, sliceParams[dim].stride}));
}

/// Transform `dimAndIndex` from the result tensor index space of a
/// CollapseShapeOp to the source tensor index space.
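///
/// For illustration (sizes assumed): if `dim`'s reassociation group collapses
/// source dims of sizes [8, 16], a result index `%i` delinearizes to the
/// multi-index (`%i` floordiv 16, `%i` mod 16).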
static ValueRange invertCollapseShapeIndexing(
    OpBuilder &b, Location loc, ArrayRef<ReassociationIndices> reassociation,
    ArrayRef<OpFoldResult> reshapeSourceShape, const DimAndIndex &dimAndIndex) {
  const auto &[dim, indexValue] = dimAndIndex;
  SmallVector<OpFoldResult> basis;
  for (int64_t i : reassociation[dim])
    basis.push_back(reshapeSourceShape[i]);
  auto delinearized =
      AffineDelinearizeIndexOp::create(b, loc, indexValue, basis);
  return delinearized->getResults();
}

FailureOr<ExtractSliceFromCollapseHelper>
tensor::ExtractSliceFromCollapseHelper::create(
    OpBuilder &b, tensor::CollapseShapeOp collapseOp,
    tensor::ExtractSliceOp extractOp) {
  if (extractOp.getSource().getDefiningOp<tensor::CollapseShapeOp>() !=
      collapseOp)
    return failure();
  SmallVector<Range> ranges;
  ranges.reserve(extractOp.getSourceType().getRank());
  for (const auto &[o, s, st] :
       llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
                 extractOp.getMixedStrides())) {
    ranges.push_back({o, s, st});
  }
  return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges);
}

FailureOr<ExtractSliceFromCollapseHelper>
tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b,
                                               tensor::CollapseShapeOp op,
                                               ArrayRef<Range> sliceParams) {
  // Don't perform this pattern if the collapse op can be simplified by
  // a rank-reducing extract slice.
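  // For example (a sketch), collapsing `tensor<?x1x32xf32>` into
  // `tensor<?x32xf32>` only folds away a unit dimension and is better handled
  // by `simplifyCollapseShapeWithRankReducingExtractSlice` below.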
  if (succeeded(mlir::getSimplifyCollapseShapeWithRankReducingSliceInfo(
          op.getSrcType(), op.getReassociationIndices())))
    return failure();

  // Materialize the output shape of the collapse_shape operation. This will
  // create IR describing the output shape in terms of the input shape.
  ReifiedRankedShapedTypeDims reifiedShapes;
  if (failed(reifyResultShapes(b, op, reifiedShapes)))
    return failure();
  SmallVector<OpFoldResult> &collapseShapeOutputShape = reifiedShapes[0];
  SmallVector<ReassociationIndices> reassociationIndices =
      op.getReassociationIndices();

  // Determine which of the CollapseShapeOp's result dimensions are sliced
  // and/or linearized.
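  // A result dimension is "linearized" if its reassociation group contains
  // more than one source dimension, and "sliced" if the slice parameters do
  // not cover that dimension's full extent.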
  llvm::SmallBitVector linearizedDimensions =
      getLinearizedDimensions(reassociationIndices);
  llvm::SmallBitVector slicedDimensions =
      getSlicedDimensions(collapseShapeOutputShape, sliceParams);

  auto collapseShapeInputShape =
      tensor::getMixedSizes(b, op.getLoc(), op.getSrc());

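  // Only dimensions that are both sliced and linearized need to be tiled; each
  // such dimension contributes one loop to the generated nest, with a trip
  // count equal to the slice size along that dimension.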
  SmallVector<Value> tileSizes;
  for (unsigned i = 0; i < sliceParams.size(); i++) {
    if (slicedDimensions[i] && linearizedDimensions[i])
      tileSizes.push_back(
          getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size));
  }

  return ExtractSliceFromCollapseHelper(
      op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams,
      linearizedDimensions, slicedDimensions, tileSizes);
}

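// Emits the body of one iteration of the tile loop nest. The caller is
// expected to supply one induction variable per tiled dimension (typically
// iterating from zero up to the corresponding tile size with unit step) and
// to insert the returned collapsed tile into the final result using the
// returned insert parameters, e.g. with `tensor.insert_slice`.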
std::pair<Value, SmallVector<Range>>
tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody(
    OpBuilder &builder, Location loc, ValueRange tileInductionVars) {
  // Create the helper class for forming the slice parameters.
  const SmallVector<ReassociationIndices> reassociationIndices =
      collapseShapeOp.getReassociationIndices();
  SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape,
                                 collapseShapeOutputShape, sliceParams);

  // Get the indices of the tiled dims (linearized by the collapse_shape and
  // sliced by the extract_slice) and invert the index space transformations.
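  // For example (illustrative values): if tiled result dim `i` has slice
  // offset 3 and stride 2 and collapses source dims of sizes [8, 16], then
  // induction variable `%iv` first maps to the linearized index `3 + %iv * 2`,
  // which is then delinearized into (idx floordiv 16, idx mod 16).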
  SmallVector<ValueRange> multiIndices;
  unsigned loopIdx = 0;
  for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) {
    if (linearizedDimensions[i] && slicedDimensions[i]) {
      DimAndIndex tb =
          invertSliceIndexing(builder, loc, sliceParams,
                              std::make_tuple(i, tileInductionVars[loopIdx++]));
      multiIndices.push_back(invertCollapseShapeIndexing(
          builder, loc, reassociationIndices, collapseShapeInputShape, tb));
    }
  }

  SmallVector<Range> extractParams =
      helper.getExtractSliceParams(builder.getContext(), multiIndices);

  Value subTileResult = tensor::ExtractSliceOp::create(
      builder, loc, collapseShapeOp.getSrc(), extractParams);

  SmallVector<Range> insertParams =
      helper.getInsertSliceParams(builder.getContext(), tileInductionVars);

  // Collapse the dimensions of the source slice back down.
  Value collapsedResult = tensor::CollapseShapeOp::create(
      builder, loc, subTileResult, reassociationIndices);
  return std::make_pair(collapsedResult, insertParams);
}

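// For example (a sketch), a collapse of `tensor<?x1xf32>` into `tensor<?xf32>`
// with reassociation [[0, 1]] becomes a single rank-reducing extract_slice
// that drops the unit dimension; a residual collapse_shape is emitted only
// when the simplification info provides updated reassociation indices.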
FailureOr<Operation *>
tensor::simplifyCollapseShapeWithRankReducingExtractSlice(
    tensor::CollapseShapeOp op, RewriterBase &rewriter) {
  SmallVector<ReassociationIndices> reassociationIndices =
      op.getReassociationIndices();
  RankedTensorType sourceType = op.getSrcType();
  FailureOr<CollapseShapeRankReducingSliceSimplificationInfo> info =
      getSimplifyCollapseShapeWithRankReducingSliceInfo(sourceType,
                                                        reassociationIndices);
  if (failed(info))
    return failure();

  // Create the rank-reducing extract slice op.
  auto zero = rewriter.getIndexAttr(0);
  auto one = rewriter.getIndexAttr(1);
  SmallVector<OpFoldResult> offsets(sourceType.getRank(), zero);
  SmallVector<OpFoldResult> sizes =
      tensor::getMixedSizes(rewriter, op.getLoc(), op.getSrc());
  SmallVector<OpFoldResult> strides(sourceType.getRank(), one);
  auto sliceOp = tensor::ExtractSliceOp::create(
      rewriter, op.getLoc(), info->sliceResultType, op.getSrc(), offsets, sizes,
      strides);

  if (!info->newReassociationIndices.has_value()) {
    rewriter.replaceOp(op, sliceOp.getResult());
    return sliceOp.getOperation();
  }

  return rewriter
      .replaceOpWithNewOp<tensor::CollapseShapeOp>(
          op, sliceOp.getResult(), *info->newReassociationIndices)
      .getOperation();
}