blob: 448e677e7ac130fe40bd81bcd9d23ce973bf64fc [file] [log] [blame]
//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements linalg fusion on tensors
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
using namespace linalg;
//===----------------------------------------------------------------------===//
// StructuredOp specific helpers.
//===----------------------------------------------------------------------===//
/// Returns the tiled slice dimensions given the tiled consumer loop dimensions.
/// The slice defines a hyper rectangular iteration space and fusing the
/// producer is always possible. However, depending on the consumer indexing
/// map, not all slice elements may be consumed and the tiles may overlap. In
/// these cases, fusion introduces redundant computation.
SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
ArrayRef<int64_t> tiledLoopDims) {
// Get the consumer operand indexing map.
LinalgOp consumerOp = consumerOperand->getOwner();
AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);
// Search the slice dimensions tiled by a tile loop dimension.
DenseSet<int64_t> tiledSliceDims;
for (auto en : enumerate(indexingMap.getResults())) {
for (auto tiledLoopDim : tiledLoopDims) {
if (en.value().isFunctionOfDim(tiledLoopDim))
tiledSliceDims.insert(en.index());
}
}
return {tiledSliceDims.begin(), tiledSliceDims.end()};
}
/// Returns the producer fused in place of `sliceOp`. Tile the producer operands
/// along the `tiledSliceDims` and clone the producer. Consider the case of
/// fusion of an output tensor:
/// ```
/// %1 = producer ins(...) outs(%0)
/// %2 = consumer ins(...) outs(%1)
/// ```
/// When consumer is tiled, %1 appears in the loop iter_args:
/// ```
/// %1 = producer ins(...) outs(%0)
/// %2 = scf.for ... iter_args(%1) .. (%bbarg) {
/// %t1 = tensor.extract_slice %bbarg[..]
/// %t2 = consumer ins(...) outs(%t1)
/// %r = tensor.insert_slice %t2, %bbarg[...]
/// }
/// ```
/// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0):
/// ```
/// %2 = scf.for ... iter_args(%0) .. (%bbarg) {
/// %t0 = tensor.extract_slice %bbarg[..]
/// %t1 = producer ins(...) outs(%t0)
/// %t2 = consumer ins(...) outs(%t1)
/// %r = tensor.insert_slice %t2, %bbarg[...]
/// }
/// ```
/// This transformation is only valid if %bbarg is exclusively used by the
/// output ExtractSliceOp / InsertSliceOp pair, which is checked by the
/// `fuseProducer` method.
/// TODO: instead of check and failure, insert new iter_args each time a
/// producer is fused into a consumer and fold away unused iter_args.
static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
tensor::ExtractSliceOp sliceOp,
ArrayRef<int64_t> tiledSliceDims,
OpOperand *iterArg) {
// Clone the producer after `sliceOp` since the slice may be reused to pass in
// the producer result.
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointAfter(sliceOp);
// Get the producer.
LinalgOp producerOp = producerResult.getOwner();
Location loc = producerOp.getLoc();
// Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
SmallVector<Value> producerLoopBounds;
transform(producerOp.createLoopRanges(b, loc),
std::back_inserter(producerLoopBounds),
[](Range range) { return range.size; });
SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);
// Get the producer result indexing map.
AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
producerOp.getOutputOperand(producerResult.getResultNumber()));
// Tile the producer operands given the `sliceOp` ranges. Iterate the
// `tiledSliceDims` and store the tile offset and size for the tiled slice
// dimension. Assumes the mapping from slice dimensions to producer loops is a
// permutation.
auto zero = b.create<ConstantIndexOp>(loc, 0);
SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr);
for (int64_t tiledSliceDim : tiledSliceDims) {
AffineExpr result = producerIndexingMap.getResults()[tiledSliceDim];
assert(result.isa<AffineDimExpr>() &&
"expect producer indexing map is a projected permutation");
int64_t tiledProducerLoop = result.cast<AffineDimExpr>().getPosition();
tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset;
tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size;
allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
}
erase_value(tileIvs, nullptr);
SmallVector<Value> tiledOperands = producerOp.getInputAndOutputOperands();
tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
tileSizes, producerLoopBounds);
// Output fusion has to update the iteration arguments of the tile loop nest.
// In particular, the iteration argument of the outermost tile loop needs to
// be set to the producer output instead of the producer result and `clonedOp`
// shall use the existing `sliceOp` result instead of the tiled producer
// output operand.
if (iterArg) {
OpOperand *outputOperand =
producerOp.getOutputOperand(producerResult.getResultNumber());
iterArg->set(outputOperand->get());
tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult();
}
// Clone the producer using the tiled producer operands.
TypeRange resultTypes = ValueRange(tiledOperands)
.take_back(producerOp.getNumOutputs())
.getTypes();
LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);
// Shift all IndexOp results by the tile offset.
addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);
return clonedOp;
}
//===----------------------------------------------------------------------===//
// TileLoopNest specific helpers.
//===----------------------------------------------------------------------===//
bool TileLoopNest::isEmpty() { return loopOps.empty(); }
bool TileLoopNest::isValid() {
// Check if the number of `tileLoopOps` and `tileLoopDims` match.
if (loopOps.size() != loopDims.size())
return false;
// Check if the innermost tile loop is the parent of `tiledOp`.
if (rootOp->getParentOp() != loopOps.back())
return false;
// Check if the tile loops are directly nested.
return std::adjacent_find(loopOps.begin(), loopOps.end(),
[](Operation *op1, Operation *op2) {
return op1 != op2->getParentOp();
}) == loopOps.end();
}
SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
assert(bbArg && "expect the block argument to be non-zero");
SmallVector<BlockArgument> bbArgs;
// Search all tile loop block arguments from inner to outer.
for (auto tileLoop : reverse(loopOps)) {
if (bbArg.getOwner()->getParentOp() != tileLoop)
return {};
bbArgs.push_back(bbArg);
OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg);
bbArg = iterArg->get().dyn_cast<BlockArgument>();
}
// Reverse the block arguments to order them from outer to inner.
return {bbArgs.rbegin(), bbArgs.rend()};
}
OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) {
// Search all block arguments and return the matching iteration argument.
SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
if (bbArgs.size() != loopOps.size())
return nullptr;
return &loopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
}
bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
tensor::ExtractSliceOp sliceOp) {
// Check the innermost block argument is either used by the ExtractSliceOp
// `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses
// conservatively.
for (Operation *op : bbArg.getUsers()) {
if (!isa<tensor::DimOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(op))
return false;
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
if (extractSliceOp != sliceOp)
return false;
}
if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
SetVector<Operation *> backwardSlice;
getBackwardSlice(insertSliceOp.source(), &backwardSlice,
[](Operation *op) {
return isa<LinalgOp, tensor::InsertSliceOp>(op);
});
if (backwardSlice.empty() || backwardSlice.front() != sliceOp)
return false;
}
}
// Check the block arguments, except for the innermost one, have one use.
SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
return !all_of(bbArgs, [&](BlockArgument bbArg) {
return bbArg.hasOneUse() || bbArg == bbArgs.back();
});
}
LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange) {
// Exit if all tile sizes are zero.
if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
return success();
// Tile the root operation.
LinalgTilingOptions tilingOptions;
tilingOptions = tilingOptions
.setInterchange(SmallVector<unsigned>(
tileInterchange.begin(), tileInterchange.end()))
.setTileSizes(tileSizes)
.setLoopType(LinalgTilingLoopType::Loops);
Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);
// Exit if tiling the root operation fails.
if (!tiledRootOp.hasValue())
return failure();
// Replace all uses of the root operation if it has been tiled before. All
// uses of the original untiled root operation are updated by the calling pass
// or pattern.
if (!isEmpty())
rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
// Update the root operation and append the loops and tile loop dimensions.
rootOp = tiledRootOp->op;
loopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
for (auto en : enumerate(tileSizes)) {
// Copy only the tiled loop dimensions with non-zero tile size.
if (en.value() == 0)
continue;
loopDims.push_back(tileInterchange[en.index()]);
}
assert(isValid() && "expect tile loop nest to be valid after tiling");
return success();
}
FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
OpOperand *rootOpOperand) {
assert(rootOpOperand->getOwner() == rootOp &&
"expect the root op to be the owner of the operand to fuse");
assert(this->isValid() &&
"expect the tile loop nest to satisfy all invariants");
// Check the tile loop nest is non-empty.
if (isEmpty())
return failure();
// Check `rootOpOperand` is defined by an ExtractSliceOp.
auto sliceOp = rootOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp)
return failure();
// Check `sliceOp` is tiled by the tile loop nest.
if (sliceOp->getParentOp() != rootOp->getParentOp())
return failure();
// Check if the producer is a LinalgOp possibly passed by iteration argument.
OpOperand *iterArg = nullptr;
auto producerResult = sliceOp.source().dyn_cast<OpResult>();
if (auto bbArg = sliceOp.source().dyn_cast<BlockArgument>()) {
iterArg = getTiedIterArg(bbArg);
// Check the iteration argument may be used to pass in the producer output.
if (!iterArg || hasOtherUses(bbArg, sliceOp))
return failure();
producerResult = iterArg->get().dyn_cast<OpResult>();
}
if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
return failure();
// Compute the tiled producer slice dimensions given the tiled root operation
// loop dimensions `loopDims`.
SmallVector<int64_t> tiledSliceDims =
getTiledSliceDims(rootOpOperand, loopDims);
if (tiledSliceDims.empty())
return failure();
// Tile the producer operands and clone the producer in place of `sliceOp`.
LinalgOp clonedOp =
getTiledProducer(b, producerResult, sliceOp, tiledSliceDims, iterArg);
// Cast the `clonedOp` result to gap type mismatches before canonicalization.
Type consumerOperandType = rootOpOperand->get().getType();
Value newResult = clonedOp->getResult(producerResult.getResultNumber());
if (newResult.getType() != consumerOperandType) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointAfter(clonedOp);
newResult = b.create<tensor::CastOp>(producerResult.getLoc(),
consumerOperandType, newResult);
}
// Replace the `sliceOp` uses except for the `clonedOp` output uses.
sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp);
return clonedOp;
}
ValueRange TileLoopNest::getRootOpReplacementResults() {
assert(!isEmpty() && "expect tile loop nest to be non-empty");
return loopOps.front()->getOpResults();
}
//===----------------------------------------------------------------------===//
// Tile and fuse entry-points.
//===----------------------------------------------------------------------===//
FailureOr<TileLoopNest>
mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> tileInterchange) {
assert(tileSizes.size() == tileInterchange.size() &&
"expect the number of tile sizes and interchange dims to match");
assert(isPermutation(tileInterchange) &&
"expect tile interchange is a permutation");
// Create an empty tile loop nest.
TileLoopNest tileLoopNest(consumerOp);
// Search the number of outer parallel loops to separate them from possible
// inner reduction dimensions.
SmallVector<StringAttr> iterTypes =
llvm::to_vector<6>(consumerOp.iterator_types().getAsRange<StringAttr>());
applyPermutationToVector(iterTypes, tileInterchange);
auto *it = find_if(iterTypes, [&](StringAttr iterType) {
return !isParallelIterator(iterType);
});
int64_t split = std::distance(iterTypes.begin(), it);
// Tile the outer parallel loops and fuse the output operands.
SmallVector<int64_t> outerTileSizes;
outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
outerTileSizes.append(tileSizes.size() - split, 0);
if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
return failure();
for (OpOperand *opOperand : tileLoopNest.getRootOp().getOutputOperands())
(void)tileLoopNest.fuseProducer(b, opOperand);
// Tile the remaining loops and fuse the input operands.
SmallVector<int64_t> innerTileSizes;
innerTileSizes.append(split, 0);
innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
return failure();
SmallVector<OpOperand *> inputOperands =
tileLoopNest.getRootOp().getInputOperands();
for (OpOperand *opOperand : tileLoopNest.getRootOp().getInputOperands())
(void)tileLoopNest.fuseProducer(b, opOperand);
return tileLoopNest;
}
namespace {
struct LinalgTileAndFuseTensorOps
: public LinalgTileAndFuseTensorOpsBase<LinalgTileAndFuseTensorOps> {
void notifyFailure(StringRef message) {
llvm::errs() << " - LinalgTileAndFuseTensorOps: " << message << "\n";
signalPassFailure();
}
void runOnFunction() override {
FuncOp funcOp = getFunction();
OpBuilder b(funcOp.getContext());
// Heuristic to find a good operation to tile and start fusion. Walk all
// operations and select the one with the maximal backward slice of fusion
// candidates.
LinalgOp rootOp = nullptr;
int64_t numFusionCandidates = -1;
funcOp.walk([&](LinalgOp linalgOp) {
SetVector<Operation *> backwardSlice;
getBackwardSlice(linalgOp, &backwardSlice);
int64_t backwardSliceSize = count_if(
backwardSlice, [](Operation *op) { return isa<LinalgOp>(op); });
if (backwardSliceSize > numFusionCandidates) {
rootOp = linalgOp;
numFusionCandidates = backwardSliceSize;
}
});
if (!rootOp)
return notifyFailure("expect to find a root operation");
// Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
if (tileSizes.size() < rootOp.getNumLoops())
return notifyFailure("expect #tile sizes >= #loops");
// Check `tileInterchange` contains no entries or as many as `tileSizes`.
if (!tileInterchange.empty() &&
tileInterchange.size() != tileSizes.size()) {
return notifyFailure(
"expect the number of tile sizes and interchange dims to match");
}
// Copy the `tileSizes` and `tileInterchange` prefixes needed to tile
// `rootOp` or use the identity interchange if `tileInterchange` is empty.
SmallVector<int64_t> rootTileSizes(
tileSizes.begin(), tileSizes.begin() + rootOp.getNumLoops());
SmallVector<int64_t> rootInterchange =
tileInterchange.empty()
? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
: SmallVector<int64_t>(tileInterchange.begin(),
tileInterchange.begin() +
rootOp.getNumLoops());
// Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
// It has to be a permutation since the tiling cannot tile the same loop
// dimension multiple times.
if (!isPermutation(rootInterchange))
return notifyFailure(
"expect the tile interchange permutes the root loops");
// Tile `rootOp` and fuse its producers.
FailureOr<TileLoopNest> tileLoopNest =
tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, rootInterchange);
if (failed(tileLoopNest))
return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly");
// Replace all uses of the tiled loop operation.
rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::createLinalgTileAndFuseTensorOpsPass() {
return std::make_unique<LinalgTileAndFuseTensorOps>();
}