| //===- FusionOnTensors.cpp - Implementation of linalg Fusion --------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements linalg fusion on tensors |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "PassDetail.h" |
| #include "mlir/Analysis/SliceAnalysis.h" |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Linalg/IR/LinalgOps.h" |
| #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" |
| #include "mlir/Dialect/Linalg/Passes.h" |
| #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| #include "mlir/Dialect/Linalg/Utils/Utils.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/Support/LLVM.h" |
| |
| using namespace mlir; |
| using namespace linalg; |
| |
| //===----------------------------------------------------------------------===// |
| // StructuredOp specific helpers. |
| //===----------------------------------------------------------------------===// |
| |
| /// Returns the tiled slice dimensions given the tiled consumer loop dimensions. |
| /// The slice defines a hyper rectangular iteration space and fusing the |
| /// producer is always possible. However, depending on the consumer indexing |
| /// map, not all slice elements may be consumed and the tiles may overlap. In |
| /// these cases, fusion introduces redundant computation. |
| static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand, |
| ArrayRef<int64_t> tiledLoopDims) { |
| // Get the consumer operand indexing map. |
| LinalgOp consumerOp = consumerOperand->getOwner(); |
| AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand); |
| |
| // Search the slice dimensions tiled by a tile loop dimension. |
| DenseSet<int64_t> tiledSliceDimIndices; |
| for (auto en : enumerate(indexingMap.getResults())) { |
| for (auto tiledLoopDim : tiledLoopDims) { |
| if (en.value().isFunctionOfDim(tiledLoopDim)) |
| tiledSliceDimIndices.insert(en.index()); |
| } |
| } |
| return {tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()}; |
| } |
| |
| /// Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions |
| /// of the producer result slice returns the tiled producer loop dimensions. |
| /// Example: |
| /// ``` |
| /// %res = linalg.fill(%cst, %input) |
| /// scf.for %i |
| /// scf.for %j |
| /// %slice = tensor.extract_slice %res[%i, %j] |
| /// ``` |
| /// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1]. |
| static SmallVector<int64_t> |
| getTiledProducerLoops(OpResult producerResult, |
| ArrayRef<int64_t> tiledSliceDimIndices) { |
| LinalgOp producerOp = producerResult.getOwner(); |
| |
| // Get the indexing map of the `producerOp` output operand that matches |
| // ´producerResult´. |
| AffineMap producerIndexingMap = producerOp.getTiedIndexingMap( |
| producerOp.getOutputOperand(producerResult.getResultNumber())); |
| |
| // Keep only the tiled result slice dimensions of `producerIndexingMap`. |
| AffineMap tiledProducerIndexingSubMap = |
| producerIndexingMap.getSubMap(SmallVector<unsigned>( |
| tiledSliceDimIndices.begin(), tiledSliceDimIndices.end())); |
| |
| // Compute the producer loop indices mapped to the tiled result slice |
| // dimensions. As the output indexing map of structured operations are |
| // projected permutations, `tiledProducerIndexingSubMap` has to be a |
| // projected permutation as well. We can thus obtain the producer loop indices |
| // by getting the positions of the result dimensions. |
| // Example: |
| // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2]. |
| assert(tiledProducerIndexingSubMap.isProjectedPermutation() && |
| "expect slice and producer loop dimensions map one-to-one"); |
| SmallVector<int64_t> tiledProducerLoopIndices; |
| transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()), |
| std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) { |
| return tiledProducerIndexingSubMap.getDimPosition(idx); |
| }); |
| |
| return tiledProducerLoopIndices; |
| } |
| |
| /// Returns the producer fused in place of `sliceOp`. Tile the producer operands |
| /// along the `tiledSliceDimIndices` and clone the producer. Consider the case |
| /// of fusion of an output tensor: |
| /// ``` |
| /// %1 = producer ins(...) outs(%0) |
| /// %2 = consumer ins(...) outs(%1) |
| /// ``` |
| /// When consumer is tiled, %1 appears in the loop iter_args: |
| /// ``` |
| /// %1 = producer ins(...) outs(%0) |
| /// %2 = scf.for ... iter_args(%1) .. (%bbarg) { |
| /// %t1 = tensor.extract_slice %bbarg[..] |
| /// %t2 = consumer ins(...) outs(%t1) |
| /// %r = tensor.insert_slice %t2, %bbarg[...] |
| /// } |
| /// ``` |
| /// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0): |
| /// ``` |
| /// %2 = scf.for ... iter_args(%0) .. (%bbarg) { |
| /// %t0 = tensor.extract_slice %bbarg[..] |
| /// %t1 = producer ins(...) outs(%t0) |
| /// %t2 = consumer ins(...) outs(%t1) |
| /// %r = tensor.insert_slice %t2, %bbarg[...] |
| /// } |
| /// ``` |
| /// This transformation is only valid if %bbarg is exclusively used by the |
| /// output ExtractSliceOp / InsertSliceOp pair, which is checked by the |
| /// `fuseProducer` method. |
| /// TODO: instead of check and failure, insert new iter_args each time a |
| /// producer is fused into a consumer and fold away unused iter_args. |
| static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, |
| tensor::ExtractSliceOp sliceOp, |
| ArrayRef<int64_t> tiledSliceDimIndices, |
| ArrayRef<int64_t> tiledProducerLoopIndices, |
| OpOperand *iterArg) { |
| // Clone the producer after `sliceOp` since the slice may be reused to pass in |
| // the producer result. |
| OpBuilder::InsertionGuard guard(b); |
| b.setInsertionPointAfter(sliceOp); |
| |
| // Get the producer. |
| LinalgOp producerOp = producerResult.getOwner(); |
| Location loc = producerOp.getLoc(); |
| |
| // Obtain the `producerOp` loop bounds and the `sliceOp` ranges. |
| SmallVector<Value> producerLoopBounds; |
| transform(producerOp.createLoopRanges(b, loc), |
| std::back_inserter(producerLoopBounds), |
| [](Range range) { return range.size; }); |
| SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc); |
| |
| // Tile the producer operands given the `sliceOp` ranges. Iterate the |
| // `tiledSliceDimIndices` and store the tile offset and size for the tiled |
| // slice dimension. |
| auto zero = b.create<arith::ConstantIndexOp>(loc, 0); |
| SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr); |
| SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero); |
| SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr); |
| for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) { |
| int64_t tiledSliceDim = std::get<0>(it); |
| int64_t tiledProducerLoop = std::get<1>(it); |
| tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset; |
| tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size; |
| allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop]; |
| } |
| erase_value(tileIvs, nullptr); |
| SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands(); |
| tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs, |
| tileSizes, producerLoopBounds); |
| |
| // Output fusion has to update the iteration arguments of the tile loop nest. |
| // In particular, the iteration argument of the outermost tile loop needs to |
| // be set to the producer output instead of the producer result and `clonedOp` |
| // shall use the existing `sliceOp` result instead of the tiled producer |
| // output operand. |
| if (iterArg) { |
| OpOperand *outputOperand = |
| producerOp.getOutputOperand(producerResult.getResultNumber()); |
| iterArg->set(outputOperand->get()); |
| tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult(); |
| } |
| |
| // Clone the producer using the tiled producer operands. |
| TypeRange resultTypes = ValueRange(tiledOperands) |
| .take_back(producerOp.getNumOutputs()) |
| .getTypes(); |
| LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands); |
| |
| // Shift all IndexOp results by the tile offset. |
| addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs); |
| |
| return clonedOp; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TileLoopNest specific helpers. |
| //===----------------------------------------------------------------------===// |
| |
| bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); } |
| |
| bool TileLoopNest::isValid() { |
| // Check if `rootOp` has been tiled at least once. |
| if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0) |
| return false; |
| |
| // Check if the number of loop operations and dimensions match. |
| if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size()) |
| return false; |
| |
| // Check if the innermost tile loop is the parent of `tiledOp`. |
| if (rootOp->getParentOp() != tileLoopOps.back()) |
| return false; |
| |
| // Check if the tile loops are directly nested. |
| return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(), |
| [](Operation *op1, Operation *op2) { |
| return op1 != op2->getParentOp(); |
| }) == tileLoopOps.end(); |
| } |
| |
| SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) { |
| assert(bbArg && "expect the block argument to be non-zero"); |
| SmallVector<BlockArgument> bbArgs; |
| |
| // Search all tile loop block arguments from inner to outer. |
| for (auto tileLoop : reverse(tileLoopOps)) { |
| if (bbArg.getOwner()->getParentOp() != tileLoop) |
| return {}; |
| bbArgs.push_back(bbArg); |
| OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg); |
| bbArg = iterArg->get().dyn_cast<BlockArgument>(); |
| } |
| |
| // Reverse the block arguments to order them from outer to inner. |
| return {bbArgs.rbegin(), bbArgs.rend()}; |
| } |
| |
| OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) { |
| // Search all block arguments and return the matching iteration argument. |
| SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg); |
| if (bbArgs.size() != tileLoopOps.size()) |
| return nullptr; |
| return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front()); |
| } |
| |
| bool TileLoopNest::hasOtherUses(BlockArgument bbArg, |
| tensor::ExtractSliceOp sliceOp) { |
| // Check the innermost block argument is either used by the ExtractSliceOp |
| // `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses |
| // conservatively. |
| for (Operation *op : bbArg.getUsers()) { |
| if (!isa<tensor::DimOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(op)) |
| return false; |
| if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) { |
| if (extractSliceOp != sliceOp) |
| return false; |
| } |
| if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) { |
| SetVector<Operation *> backwardSlice; |
| getBackwardSlice(insertSliceOp.source(), &backwardSlice, |
| [](Operation *op) { |
| return isa<LinalgOp, tensor::InsertSliceOp>(op); |
| }); |
| if (backwardSlice.empty() || backwardSlice.front() != sliceOp) |
| return false; |
| } |
| } |
| |
| // Check the block arguments, except for the innermost one, have one use. |
| SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg); |
| return !all_of(bbArgs, [&](BlockArgument bbArg) { |
| return bbArg.hasOneUse() || bbArg == bbArgs.back(); |
| }); |
| } |
| |
| LogicalResult TileLoopNest::tileRootOp(OpBuilder &b, |
| ArrayRef<int64_t> tileSizes, |
| ArrayRef<int64_t> tileInterchange) { |
| // Exit if all tile sizes are zero. |
| if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0))) |
| return success(); |
| |
| // Tile the root operation. |
| LinalgTilingOptions tilingOptions; |
| tilingOptions = tilingOptions |
| .setInterchange(SmallVector<unsigned>( |
| tileInterchange.begin(), tileInterchange.end())) |
| .setTileSizes(tileSizes) |
| .setLoopType(LinalgTilingLoopType::Loops); |
| Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions); |
| |
| // Exit if tiling the root operation fails. |
| if (!tiledRootOp.hasValue()) |
| return failure(); |
| |
| // Replace all uses of the root operation if it has been tiled before. All |
| // uses of the original untiled root operation are updated by the calling pass |
| // or pattern. |
| if (!isEmpty()) |
| rootOp->replaceAllUsesWith(tiledRootOp->tensorResults); |
| |
| // Transfer the stored `rootOp` loop dimensions if it has been tiled before. |
| if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) { |
| tiledRootAndFusedOpsLoops[tiledRootOp->op] = |
| tiledRootAndFusedOpsLoops[rootOp]; |
| } |
| |
| // Update the root operation and append the loops and tile loop dimensions. |
| rootOp = tiledRootOp->op; |
| tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); |
| for (auto en : enumerate(tileSizes)) { |
| // Copy only the tiled loop dimensions with non-zero tile size. |
| if (en.value() == 0) |
| continue; |
| tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]); |
| } |
| assert(isValid() && "expect tile loop nest to be valid after tiling"); |
| return success(); |
| } |
| |
| FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b, |
| OpOperand *consumerOpOperand) { |
| // Check if the consumer has been tiled before. For example, it may not have |
| // been tiled if the outermost tile loop is a reduction loop. |
| if (tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) == 0) |
| return failure(); |
| |
| assert(this->isValid() && |
| "expect the tile loop nest to satisfy all invariants"); |
| |
| // Check the tile loop nest is non-empty. |
| if (isEmpty()) |
| return failure(); |
| |
| // Check `consumerOpOperand` is defined by an ExtractSliceOp. |
| auto sliceOp = |
| consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); |
| if (!sliceOp) |
| return failure(); |
| |
| // Check `sliceOp` and `consumerOp` are in the same block. |
| LinalgOp consumerOp = consumerOpOperand->getOwner(); |
| if (sliceOp->getBlock() != rootOp->getBlock() || |
| consumerOp->getBlock() != rootOp->getBlock()) |
| return failure(); |
| |
| // Check if the producer is a LinalgOp possibly passed by iteration argument. |
| OpOperand *iterArg = nullptr; |
| auto producerResult = sliceOp.source().dyn_cast<OpResult>(); |
| if (auto bbArg = sliceOp.source().dyn_cast<BlockArgument>()) { |
| iterArg = getTiedIterArg(bbArg); |
| // Check the iteration argument may be used to pass in the producer output. |
| if (!iterArg || hasOtherUses(bbArg, sliceOp)) |
| return failure(); |
| producerResult = iterArg->get().dyn_cast<OpResult>(); |
| } |
| if (!producerResult || !isa<LinalgOp>(producerResult.getOwner())) |
| return failure(); |
| |
| // Compute the tiled producer slice dimensions given the tiled consumer loops. |
| SmallVector<int64_t> tiledSliceDimIndices = getTiledSliceDims( |
| consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]); |
| if (tiledSliceDimIndices.empty()) |
| return failure(); |
| |
| // Compute the tiled producer loop indices. |
| SmallVector<int64_t> tiledProducerLoopIndices = |
| getTiledProducerLoops(producerResult, tiledSliceDimIndices); |
| |
| // Tile the producer operands and clone the producer in place of `sliceOp`. |
| LinalgOp clonedOp = |
| getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices, |
| tiledProducerLoopIndices, iterArg); |
| tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices; |
| |
| // Cast the `clonedOp` result to gap type mismatches before canonicalization. |
| Type consumerOperandType = consumerOpOperand->get().getType(); |
| Value newResult = clonedOp->getResult(producerResult.getResultNumber()); |
| if (newResult.getType() != consumerOperandType) { |
| OpBuilder::InsertionGuard guard(b); |
| b.setInsertionPointAfter(clonedOp); |
| newResult = b.create<tensor::CastOp>(producerResult.getLoc(), |
| consumerOperandType, newResult); |
| } |
| |
| // Replace the `sliceOp` uses except for the `clonedOp` output uses. |
| sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp); |
| return clonedOp; |
| } |
| |
| ValueRange TileLoopNest::getRootOpReplacementResults() { |
| assert(!isEmpty() && "expect tile loop nest to be non-empty"); |
| return tileLoopOps.front()->getOpResults(); |
| } |
| |
| SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() { |
| SmallVector<LinalgOp> result; |
| for (const auto &kvp : tiledRootAndFusedOpsLoops) { |
| auto linalgOp = dyn_cast<LinalgOp>(kvp.getFirst()); |
| assert(linalgOp && |
| "expect all tiled and fused operations are linalg operations"); |
| result.push_back(linalgOp); |
| } |
| return result; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Tile and fuse entry-points. |
| //===----------------------------------------------------------------------===// |
| |
| FailureOr<TileLoopNest> |
| mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp, |
| ArrayRef<int64_t> tileSizes, |
| ArrayRef<int64_t> tileInterchange) { |
| assert(tileSizes.size() == tileInterchange.size() && |
| "expect the number of tile sizes and interchange dims to match"); |
| assert(isPermutation(tileInterchange) && |
| "expect tile interchange is a permutation"); |
| |
| // Create an empty tile loop nest. |
| TileLoopNest tileLoopNest(consumerOp); |
| |
| // Search the number of outer parallel loops to separate them from possible |
| // inner reduction dimensions. |
| SmallVector<StringAttr> iterTypes = |
| llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>()); |
| applyPermutationToVector(iterTypes, tileInterchange); |
| auto *it = find_if(iterTypes, [&](StringAttr iterType) { |
| return !isParallelIterator(iterType); |
| }); |
| int64_t split = std::distance(iterTypes.begin(), it); |
| |
| // Helper to fuse the producers greedily using a queue of fusion candidates. |
| auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) { |
| SmallVector<OpOperand *> candidates(operands.begin(), operands.end()); |
| while (!candidates.empty()) { |
| FailureOr<LinalgOp> fusedProducer = |
| tileLoopNest.fuseProducer(b, candidates.pop_back_val()); |
| if (failed(fusedProducer)) |
| continue; |
| candidates.append(fusedProducer->getInputAndOutputOperands()); |
| } |
| }; |
| |
| // Tile the outer parallel loops and fuse the output operands. |
| SmallVector<int64_t> outerTileSizes; |
| outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split); |
| outerTileSizes.append(tileSizes.size() - split, 0); |
| if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange))) |
| return failure(); |
| fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands()); |
| |
| // Tile the remaining loops and fuse the input operands. |
| SmallVector<int64_t> innerTileSizes; |
| innerTileSizes.append(split, 0); |
| innerTileSizes.append(tileSizes.begin() + split, tileSizes.end()); |
| if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange))) |
| return failure(); |
| fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands()); |
| |
| return tileLoopNest; |
| } |