| //===- Split.cpp - Structured op splitting --------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/Attributes.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/OpDefinition.h" |
| #include "mlir/Interfaces/TilingInterface.h" |
| |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| |
| using namespace mlir; |
| using namespace mlir::linalg; |
| |
| /// Creates a part of the given `op` split along the iteration space `dimension` |
| /// with the given `size` and an optional `offset` (default 0). Makes slices |
| /// of operands, using the input operands of the original op and the output |
| /// operands provided as `resultOperands`. Expects `offsets` and `sizes` to |
| /// define the shape of the iteration space of the original op. Returns the |
| /// split-out op as well as the output operand values updated with the partial |
| /// results produced by this op through `results`. |
| static TilingInterface |
| createSplitPart(RewriterBase &b, Location loc, TilingInterface op, |
| ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes, |
| ValueRange resultOperands, unsigned dimension, |
| OpFoldResult size, OpFoldResult offset, |
| SmallVectorImpl<Value> &results) { |
| // Iteration space of the current part. |
| SmallVector<OpFoldResult> sizesCopy = llvm::to_vector(sizes); |
| SmallVector<OpFoldResult> offsetsCopy = llvm::to_vector(offsets); |
| sizesCopy[dimension] = size; |
| offsetsCopy[dimension] = offset; |
| |
| // Create the part as it it were a single tile. |
| SmallVector<Operation *> tiled = |
| op.getTiledImplementation(b, offsetsCopy, sizesCopy); |
| assert(tiled.size() == 1 && "expected a single result from tiling"); |
| auto part = cast<TilingInterface>(tiled.front()); |
| |
| // Insert the results back and populate the `results` list. |
| for (auto i : llvm::seq<unsigned>(0, part->getNumResults())) { |
| SmallVector<OpFoldResult> resultOffsets, resultSizes; |
| if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy, |
| resultOffsets, resultSizes))) |
| return nullptr; |
| SmallVector<OpFoldResult> resultStrides(resultOffsets.size(), |
| b.getIndexAttr(1)); |
| Value inserted = b.create<tensor::InsertSliceOp>( |
| loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes, |
| resultStrides); |
| results.push_back(inserted); |
| } |
| |
| return part; |
| } |
| |
| std::pair<TilingInterface, TilingInterface> |
| linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, |
| OpFoldResult splitPoint) { |
| // Compute the iteration space. |
| SmallVector<Range> iterationSpace = op.getIterationDomain(rewriter); |
| |
| // Bail out on dimension overflow. |
| if (dimension >= iterationSpace.size()) |
| return std::make_pair(op, TilingInterface()); |
| |
| SmallVector<OpFoldResult> offsets = llvm::to_vector(llvm::map_range( |
| iterationSpace, [](const Range &range) { return range.offset; })); |
| SmallVector<OpFoldResult> sizes = llvm::to_vector(llvm::map_range( |
| iterationSpace, [](const Range &range) { return range.size; })); |
| |
| // Adjust the split point so that it doesn't overflow the size. |
| AffineExpr d0, d1, d2; |
| bindDims(rewriter.getContext(), d0, d1, d2); |
| OpFoldResult minSplitPoint = makeComposedFoldedAffineMin( |
| rewriter, op.getLoc(), |
| AffineMap::inferFromExprList(ArrayRef<AffineExpr>{d0, d1 + d2}).front(), |
| {splitPoint, offsets[dimension], sizes[dimension]}); |
| |
| // Compute the size of the second part. Return early if the second part would |
| // have an empty iteration space. |
| OpFoldResult remainingSize = makeComposedFoldedAffineApply( |
| rewriter, op.getLoc(), d0 + d1 - d2, |
| {iterationSpace[dimension].offset, iterationSpace[dimension].size, |
| minSplitPoint}); |
| if (auto attr = remainingSize.dyn_cast<Attribute>()) { |
| if (attr.cast<IntegerAttr>().getValue().isZero()) |
| return {op, TilingInterface()}; |
| } |
| |
| // Compute destination tensors. |
| SmallVector<Value> destinationTensors; |
| LogicalResult destStatus = tensor::getOrCreateDestinations( |
| rewriter, op.getLoc(), op, destinationTensors); |
| (void)destStatus; |
| assert(succeeded(destStatus) && "failed to get destination tensors"); |
| |
| // Create the first part. |
| SmallVector<Value> firstResults; |
| TilingInterface firstPart = createSplitPart( |
| rewriter, op.getLoc(), op, offsets, sizes, destinationTensors, dimension, |
| minSplitPoint, iterationSpace[dimension].offset, firstResults); |
| |
| // Need to pretend that the original op now takes as operands firstResults, |
| // otherwise tiling interface implementation will take the wrong value to |
| // produce data tiles. |
| rewriter.updateRootInPlace(op, [&]() { |
| unsigned numTotalOperands = op->getNumOperands(); |
| unsigned numOutputOperands = firstResults.size(); |
| op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands, |
| firstResults); |
| }); |
| |
| // Create the second part. |
| OpFoldResult totalOffset = makeComposedFoldedAffineApply( |
| rewriter, op.getLoc(), d0 + d1, {offsets[dimension], minSplitPoint}); |
| SmallVector<Value> secondResults; |
| TilingInterface secondPart = |
| createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults, |
| dimension, remainingSize, totalOffset, secondResults); |
| |
| // Replace the original op with the results of the two newly created ops. |
| rewriter.replaceOp(op, secondResults); |
| return {firstPart, secondPart}; |
| } |