blob: 91862d2e17d7136da616f4b61eaa8bb04d63aefd [file] [log] [blame]
//===- Tiling.cpp - Implementation of tiling using TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the tiling using TilingInterface.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
#define DEBUG_TYPE "tile-using-interface"
using namespace mlir;
/// Install a tile-size callback that always hands back a copy of `ts`,
/// independent of the builder and operation it is invoked with. Asserts that
/// no tile-size callback was installed before.
scf::SCFTilingOptions &
scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<OpFoldResult> capturedSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [capturedSizes](OpBuilder & /*b*/,
                                                Operation * /*op*/) {
    return capturedSizes;
  };
  return *this;
}
/// Install a thread-count callback that always hands back a copy of `nt`
/// (one entry per loop), independent of the builder and operation it is
/// invoked with. Asserts that no thread-count callback was installed before.
scf::SCFTilingOptions &
scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
  // The guarded option is the number of threads, not tiles.
  assert(!numThreadsComputationFunction && "num threads already set");
  auto numThreads = llvm::to_vector(nt);
  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
    return numThreads;
  };
  return *this;
}
/// Adapt `interchangeVector` to an iteration domain of `iterationDomainSize`
/// loops. Entries beyond the domain size are dropped. Missing trailing
/// positions are filled with the identity mapping (position `i` maps to `i`).
static SmallVector<int64_t>
fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
                      size_t iterationDomainSize) {
  // `take_front` returns the whole array when it is already short enough.
  SmallVector<int64_t> result =
      llvm::to_vector(interchangeVector.take_front(iterationDomainSize));
  for (size_t pos = result.size(); pos < iterationDomainSize; ++pos)
    result.push_back(static_cast<int64_t>(pos));
  return result;
}
//===----------------------------------------------------------------------===//
// tileUsingSCF implementation.
//===----------------------------------------------------------------------===//
/// Check that the tiling options are mutually consistent:
/// - a thread count may only be requested when generating `scf.forall`;
/// - a non-empty interchange vector must be a permutation.
/// Returns a match failure, reported at `loc`, for the first violated rule.
static LogicalResult
verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
                      const scf::SCFTilingOptions &options) {
  const bool generatesForall =
      options.loopType == scf::SCFTilingOptions::LoopType::ForallOp;
  if (options.numThreadsComputationFunction && !generatesForall) {
    return rewriter.notifyMatchFailure(
        loc, "number of threads can only by specified when loop type is "
             "set to use `scf.forall`");
  }
  const bool hasInterchange = !options.interchangeVector.empty();
  if (hasInterchange && !isPermutationVector(options.interchangeVector)) {
    return rewriter.notifyMatchFailure(
        loc, "invalid interchange vector, not a permutation of the entire "
             "iteration space");
  }
  return success();
}
/// Materialize the user-specified tile sizes and/or thread counts for `op`.
///
/// Each returned vector, when non-empty, is padded with zeros or truncated to
/// exactly one entry per loop of `iterationDomain`. A zero entry means the
/// corresponding loop is not tiled. The three cases are:
/// - Only tile sizes are given: they are returned and the thread-count vector
///   is empty.
/// - Both thread counts and tile sizes are given: both are returned as
///   provided, after padding.
/// - Only thread counts are given: each tile size is derived as
///   `ceilDiv(size - offset, numThreads)`. Loops with a constant-zero thread
///   count keep a zero tile size.
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
                              ArrayRef<Range> iterationDomain,
                              const scf::SCFTilingOptions &options) {
  const size_t numLoops = iterationDomain.size();
  OpFoldResult zero = rewriter.getIndexAttr(0);
  SmallVector<OpFoldResult> tileSizes;
  SmallVector<OpFoldResult> numThreads;

  // Common case: only tile sizes. The convention that "tiling by zero" skips
  // a dimension is much simpler than adjusting affine maps for missing
  // dimensions.
  if (!options.numThreadsComputationFunction) {
    assert(options.tileSizeComputationFunction &&
           "expected tile sizes to be specified");
    tileSizes = options.tileSizeComputationFunction(rewriter, op);
    tileSizes.resize(numLoops, zero);
    return {tileSizes, numThreads};
  }

  numThreads = options.numThreadsComputationFunction(rewriter, op);
  numThreads.resize(numLoops, zero);

  // Thread counts and tile sizes are both given: use them as is.
  if (options.tileSizeComputationFunction) {
    tileSizes = options.tileSizeComputationFunction(rewriter, op);
    tileSizes.resize(numLoops, zero);
    return {tileSizes, numThreads};
  }

  // Derive the tile sizes from the thread counts:
  //   niters   = ub - lb        (TODO: the loop step is assumed to be 1)
  //   tileSize = ceilDiv(niters, numThreads)
  AffineExpr lbSym, ubSym, ntSym;
  bindSymbols(rewriter.getContext(), lbSym, ubSym, ntSym);
  AffineExpr tileSizeExpr = (ubSym - lbSym).ceilDiv(ntSym);
  tileSizes.assign(numLoops, zero);
  for (size_t dim = 0; dim < numLoops; ++dim) {
    if (isConstantIntValue(numThreads[dim], 0))
      continue;
    const Range &range = iterationDomain[dim];
    tileSizes[dim] = affine::makeComposedFoldedAffineApply(
        rewriter, op.getLoc(), tileSizeExpr,
        {range.offset, range.size, numThreads[dim]});
  }
  return {tileSizes, numThreads};
}
/// Emit a warning on `op` for every tiled loop whose iterator type is not
/// `parallel`, because distributing such a loop across `scf.forall`
/// iterations is not thread safe.
/// - When `numThreads` is non-empty, a loop counts as tiled if its constant
///   thread count is greater than one.
/// - Otherwise, a loop counts as tiled if its constant tile size is positive.
/// Non-constant values never trigger a warning.
static void checkSafeToTileToForall(TilingInterface op,
                                    ArrayRef<OpFoldResult> tileSizes,
                                    ArrayRef<OpFoldResult> numThreads) {
  auto iterators = op.getLoopIteratorTypes();
  assert(iterators.size() == tileSizes.size() &&
         "expected as many tile size values as number of loops");
  assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
         "when specified, expected number of threads to use for each loop");
  const bool useNumThreads = !numThreads.empty();
  // More than one thread, or any positive tile size, distributes the loop.
  const int64_t threshold = useNumThreads ? 1 : 0;
  for (size_t axis = 0, e = iterators.size(); axis < e; ++axis) {
    if (iterators[axis] == utils::IteratorType::parallel)
      continue;
    std::optional<int64_t> distribution =
        useNumThreads ? getConstantIntValue(numThreads[axis])
                      : getConstantIntValue(tileSizes[axis]);
    if (distribution && *distribution > threshold)
      op.emitWarning() << "tiling is not thread safe at axis #" << axis;
  }
}
/// Returns true if `loopRange.stride` is statically known to evenly divide the
/// trip count `loopRange.size - loopRange.offset`.
///
/// Returns false in two cases: when any of offset, size or stride is not a
/// compile-time constant, and when the stride is zero. A zero stride would
/// otherwise make the `%` below undefined behavior.
static bool tileDividesIterationDomain(Range loopRange) {
  std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
  std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
  std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
  if (!offsetAsInt || !sizeAsInt || !strideAsInt)
    return false;
  // Guard against `x % 0` (UB); callers then fall back to a bounded size.
  if (*strideAsInt == 0)
    return false;
  return (*sizeAsInt - *offsetAsInt) % *strideAsInt == 0;
}
/// Returns the size of the tile that starts at `offset` without running past
/// the end of `loopRange`: `min(tileSize, loopRange.size - offset)`, where
/// `loopRange.size` acts as the loop upper bound.
///
/// The `min` is skipped, and `tileSize` returned directly, in two cases: when
/// the tile size is the constant 1, and when `tileSize` is statically known to
/// evenly divide `loopRange.size - loopRange.offset`.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
                                       Range loopRange, OpFoldResult offset,
                                       OpFoldResult tileSize) {
  if (isConstantIntValue(tileSize, 1))
    return tileSize;
  Range divisibilityQuery{loopRange.offset, loopRange.size, tileSize};
  if (tileDividesIterationDomain(divisibilityQuery))
    return tileSize;

  // Guard against out-of-bounds access: min(ub - iv, tileSize), where `iv` is
  // the induction variable of the tiled loop.
  MLIRContext *ctx = b.getContext();
  AffineExpr ivDim = getAffineDimExpr(0, ctx);
  AffineExpr ubSym = getAffineSymbolExpr(0, ctx);
  AffineExpr tileSym = getAffineSymbolExpr(1, ctx);
  AffineMap minMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/2,
                                    {ubSym - ivDim, tileSym}, ctx);
  Value upperBound = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
  return affine::makeComposedFoldedAffineMin(
      b, loc, minMap, SmallVector<OpFoldResult>{offset, upperBound, tileSize});
}
/// Returns true if it is statically known that
/// `tileSize * (numThreads - 1)`, the offset of the last thread relative to
/// the loop start, is strictly less than `iterationSize`.
/// Returns false if any of the three arguments is not a compile-time constant.
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
                                           OpFoldResult numThreads,
                                           OpFoldResult iterationSize) {
  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
  std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
  std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
  if (tileSizeConst && numThreadsConst && iterSizeConst) {
    const int64_t lastThreadOffset = *tileSizeConst * (*numThreadsConst - 1);
    return lastThreadOffset < *iterSizeConst;
  }
  return false;
}
/// Compute the `OpFoldResult`s that represent the multi-dimensional `offset`s
/// and `size`s of the tile of the iteration space that the innermost loop body
/// of the generated tiled loops corresponds to.
///
/// `ivs` holds one value per materialized loop, in dimension order. A
/// dimension has no materialized loop when its tile size is zero, or, when
/// `numThreads` is non-empty, when its thread count is zero. Such dimensions
/// keep their full `loopRange.offset` / `loopRange.size`. As elsewhere in this
/// file, `loopRange.offset` and `loopRange.size` serve as the lower and upper
/// bound of the loop.
/// - Without `numThreads`, the induction variable is the tile offset, and the
///   size is bounded by `getBoundedTileSize`.
/// - With `numThreads`, the induction variable is a thread id `id`. The offset
///   is `lb + id * tileSize`. The size is `tileSize`, bounded above by
///   `ub - offset` and clamped at zero from below when needed.
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
                      ArrayRef<Range> iterationDomain,
                      ArrayRef<OpFoldResult> tileSizes,
                      ArrayRef<OpFoldResult> numThreads) {
  SmallVector<OpFoldResult> offsets, sizes;
  int materializedLoopNum = 0;

  if (numThreads.empty()) {
    for (auto [tileSize, loopRange] :
         llvm::zip_equal(tileSizes, iterationDomain)) {
      // Non-tiled case: the tile covers the whole loop range.
      if (isConstantIntValue(tileSize, 0)) {
        offsets.push_back(loopRange.offset);
        sizes.push_back(loopRange.size);
        continue;
      }
      Value iv = ivs[materializedLoopNum++];
      OpFoldResult offset = getAsOpFoldResult(iv);
      offsets.push_back(offset);
      sizes.push_back(
          getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize));
    }
    return {offsets, sizes};
  }

  AffineExpr d0, d1, s0, s1;
  bindDims(rewriter.getContext(), d0, d1);
  bindSymbols(rewriter.getContext(), s0, s1);
  // offset   = lb + id * tileSize          (d0 = lb, d1 = id, s0 = tileSize)
  AffineExpr offsetExpr = d0 + d1 * s0;
  // residual = ub - (lb + nt * tileSize)   (d1 = numThreads, s1 = ub)
  AffineExpr residualTileSizeExpr = s1 - (d0 + d1 * s0);
  for (auto [nt, tileSize, loopRange] :
       llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
    // Non-tiled case: the tile covers the whole loop range.
    if (isConstantIntValue(nt, 0)) {
      offsets.push_back(loopRange.offset);
      sizes.push_back(loopRange.size);
      continue;
    }
    Value iv = ivs[materializedLoopNum++];
    OpFoldResult offset = affine::makeComposedFoldedAffineApply(
        rewriter, loc, offsetExpr,
        ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
    OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
        rewriter, loc, residualTileSizeExpr,
        {loopRange.offset, nt, tileSize, loopRange.size});
    // Unless the threads cover the range exactly, bound the size by the
    // remaining extent: size = min(ub - offset, tileSize).
    OpFoldResult size = tileSize;
    if (!isConstantIntValue(residualTileSize, 0)) {
      OpFoldResult sizeMinusOffsetPerThread =
          affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
                                                {offset, loopRange.size});
      size = affine::makeComposedFoldedAffineMin(
          rewriter, loc,
          AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
          {sizeMinusOffsetPerThread, tileSize});
    }
    // Consider the case where the original loop was `[0, 10)` and 7 threads
    // are used. The tile size is `ceilDiv(10, 7) = 2`. For the last thread
    // (thread_id = 6):
    //   - `offset = 0 + 6 * 2 = 12`
    //   - `tileSize = min(2, 10 - 12) = -2`
    // To avoid negative tile sizes, a further
    // `nonNegativeTileSize = affine.max(0, tileSize)` is needed.
    // The `max` can be omitted if the last thread starts inside the range,
    // i.e. if `tileSize * (numThreads - 1) < ub - lb`. This is only checked
    // when both bounds are constants. The extent `ub - lb`, not `ub` alone,
    // must be used so that a non-zero lower bound is accounted for.
    bool omitMax = false;
    std::optional<int64_t> lbConst = getConstantIntValue(loopRange.offset);
    std::optional<int64_t> ubConst = getConstantIntValue(loopRange.size);
    if (lbConst && ubConst) {
      omitMax = canOmitTileOffsetInBoundsCheck(
          tileSize, nt, rewriter.getIndexAttr(*ubConst - *lbConst));
    }
    if (!omitMax) {
      AffineMap maxMap =
          AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
      size = affine::makeComposedFoldedAffineMax(
          rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
    }
    offsets.push_back(offset);
    sizes.push_back(size);
  }
  return {offsets, sizes};
}
/// Returns the `(lower bounds, upper bounds, steps)` of the loops to generate.
/// One loop is produced for every dimension whose tile size is not the
/// constant zero. Its bounds are `[loopRange.offset, loopRange.size)` and its
/// step is the tile size. `rewriter` and `loc` are currently unused.
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
                  SmallVector<OpFoldResult>>
getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
              ArrayRef<OpFoldResult> tileSizes) {
  SmallVector<OpFoldResult> lowerBounds, upperBounds, loopSteps;
  for (auto rangeAndTile : llvm::zip_equal(loopRanges, tileSizes)) {
    const Range &range = std::get<0>(rangeAndTile);
    OpFoldResult tileSize = std::get<1>(rangeAndTile);
    const bool isTiled = !isConstantIntValue(tileSize, 0);
    if (!isTiled)
      continue;
    lowerBounds.push_back(range.offset);
    upperBounds.push_back(range.size);
    loopSteps.push_back(tileSize);
  }
  return std::make_tuple(lowerBounds, upperBounds, loopSteps);
}
/// Callback that generates the tiled computation and returns additional
/// values to yield. It is used by `yieldTiledValuesAndReplaceLoop` and as the
/// innermost loop-body generator of the `generateLoopNest*` helpers.
/// - `ivs` the induction variable(s) of the loop(s). These are the innermost
///   `scf.for` ivs, the `scf.forall` ivs, or empty when no loop is generated.
/// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
/// - `tiledValues` the tiled values to return. Must be of the same size as
///   `newBbArgs`; each element of this array is inserted into the
///   corresponding element of `newBbArgs`.
/// - `resultOffsets` is of the same size as `tiledValues` and represents
///   the offsets to use when inserting the corresponding element from
///   `tiledValues` into the element from `newBbArgs`.
/// - `resultSizes` is of the same size as `tiledValues` and represents
///   the size of the corresponding element from `tiledValues` inserted into
///   the element from `newBbArgs`.
/// In case the method needs to return `failure()`, the method is expected
/// to clean up any inserted operations.
using YieldTiledValuesFn = std::function<LogicalResult(
    RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange newBbArgs,
    SmallVector<Value> &tiledValues,
    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
/// Clone `op` at the rewriter's current insertion point. When `newDestArgs`
/// is non-empty and the clone implements `DestinationStyleOpInterface`, the
/// clone's DPS init operands are replaced with `newDestArgs`.
/// Returns the clone.
static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
                                                  Operation *op,
                                                  ValueRange newDestArgs) {
  Operation *clonedOp = rewriter.clone(*op);
  if (!newDestArgs.empty()) {
    if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
      dpsOp.getDpsInitsMutable().assign(newDestArgs);
  }
  return clonedOp;
}
/// Generate the tile-loop nest using `scf.for` operations, with one loop per
/// tiled dimension, nested outermost-first.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
///   `getLoopBounds` uses each range's `offset` as the lower bound and its
///   `size` as the upper bound.
/// - `tileSizes` is the tile sizes to use. Zero represents an untiled loop,
///   for which no `scf.for` is generated.
/// - `destinationTensors` are the init values to use for the outermost loop.
///   Each inner loop is initialized with the region iter_args of the loop
///   that encloses it.
/// - `yieldTiledValuesFn` is called to generate the loop body of the innermost
///   loop.
/// - `loops` is an in-out parameter to which the generated loops are appended,
///   outermost first.
/// The innermost loop yields `tensor.insert_slice`s of the tiled values into
/// its iter_args, and every outer loop yields the results of the loop nested
/// in it.
static LogicalResult generateLoopNestUsingForOp(
    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
    ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
    YieldTiledValuesFn yieldTiledValuesFn,
    SmallVector<LoopLikeOpInterface> &loops) {
  assert(!loopRanges.empty() && "unexpected empty loop ranges");
  assert(loopRanges.size() == tileSizes.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(rewriter);
  SmallVector<OpFoldResult> lbs, ubs, steps;
  std::tie(lbs, ubs, steps) =
      getLoopBounds(rewriter, loc, loopRanges, tileSizes);
  SmallVector<Value> lbVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
  SmallVector<Value> ubVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
  SmallVector<Value> stepVals =
      getValueOrCreateConstantIndexOp(rewriter, loc, steps);
  SmallVector<Value> ivs;
  // Build the nest from the outside in. Each new loop is created at the end of
  // the previous loop's body, and its iter_args become the destinations of
  // the next level. The body builder is empty; the `scf.yield` terminators are
  // created explicitly below for every loop.
  for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
    auto loop =
        rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
                                    [](OpBuilder &bodyBuilder, Location bodyLoc,
                                       Value iv, ValueRange /*iterArgs*/) {});
    loops.push_back(loop);
    ivs.push_back(loop.getInductionVar());
    rewriter.setInsertionPointToEnd(loop.getBody());
    destinationTensors = loop.getRegionIterArgs();
  }
  SmallVector<Value> tiledResults;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  // NOTE(review): on failure, the loops created above are not erased here.
  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
                                tiledResults, resultOffsets, resultSizes))) {
    return rewriter.notifyMatchFailure(
        loc, "failed to generate inner tile loop body");
  }
  // Without any loop there is nothing to yield from.
  if (loops.empty())
    return success();
  assert(tiledResults.size() == destinationTensors.size() &&
         "Number of results of body should be equal to number of iter args");
  // Yield all the results of the tiled operation from the innermost loop,
  // each inserted into its iter_arg at the reported offsets/sizes (unit
  // strides).
  SmallVector<Value> yieldedValues;
  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
       llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
        loc, tiledValue, destinationTensor, resultOffset, resultSize,
        resultStride);
    yieldedValues.push_back(insertSlice);
  }
  rewriter.create<scf::YieldOp>(loc, yieldedValues);
  // Add the scf.yield operations for all the outer loops: each one yields the
  // results of the loop nested directly inside it.
  for (auto [outerLoop, innerLoop] :
       llvm::zip_equal(MutableArrayRef(loops).drop_back(),
                       MutableArrayRef(loops).drop_front())) {
    rewriter.setInsertionPointToEnd(
        cast<scf::ForOp>(outerLoop.getOperation()).getBody());
    rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults());
  }
  return success();
}
/// Generate the tile-loop nest using a single `scf.forall` operation.
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizes` is the tile sizes to use. Zero represents untiled loops.
/// - `numThreads`, when non-empty, gives the number of threads per loop. The
///   `scf.forall` is then built over the non-zero thread counts, so its
///   induction variables are thread ids. Otherwise it is built over the
///   bounds from `getLoopBounds`, i.e. `[offset, size)` with the tile size as
///   step.
/// - `mappingVector` is the mapping attributes to use for loop construction.
///   Can be empty.
/// - `destinationTensors` are the shared outputs of the `scf.forall`.
/// - `tiledBodyFn` is called to generate the loop body.
/// - `loops` is an in-out parameter to which the generated loop is appended.
/// Each tiled value is written back with a `tensor.parallel_insert_slice` in
/// the `scf.forall` terminator.
static LogicalResult generateLoopNestUsingForallOp(
    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
    ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
  assert(!loopRanges.empty() && "unexpected empty loop ranges");
  assert(loopRanges.size() == tileSizes.size() &&
         "expected as many tile sizes as loop ranges");
  OpBuilder::InsertionGuard guard(rewriter);

  std::optional<ArrayAttr> mappingAttr;
  if (!mappingVector.empty())
    mappingAttr = rewriter.getArrayAttr(mappingVector);

  scf::ForallOp forallOp;
  if (!numThreads.empty()) {
    // Loops with zero threads are not distributed; drop them.
    SmallVector<OpFoldResult> nonZeroNumThreads;
    for (OpFoldResult nt : numThreads) {
      if (!isConstantIntValue(nt, 0))
        nonZeroNumThreads.push_back(nt);
    }
    forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
                                              destinationTensors, mappingAttr);
  } else {
    SmallVector<OpFoldResult> lbs, ubs, steps;
    std::tie(lbs, ubs, steps) =
        getLoopBounds(rewriter, loc, loopRanges, tileSizes);
    forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
                                              destinationTensors, mappingAttr);
  }
  loops.push_back(forallOp);

  // The tiled computation goes right before the terminator and operates on
  // the region's shared outputs.
  rewriter.setInsertionPoint(forallOp.getTerminator());
  destinationTensors = forallOp.getRegionOutArgs();

  SmallVector<Value> tiledResults;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
                         destinationTensors, tiledResults, resultOffsets,
                         resultSizes)))
    return rewriter.notifyMatchFailure(loc, "failed to generate loop body");

  // Write each tiled result into its shared output from inside the
  // terminator, using unit strides.
  rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
  for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
       llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    rewriter.create<tensor::ParallelInsertSliceOp>(
        loc, tiledValue, destinationTensor, resultOffset, resultSize,
        resultStride);
  }
  return success();
}
/// Generate the tile-loop nest using the loop construct selected by
/// `options.loopType` (`scf.for` or `scf.forall`).
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizes` is the tile sizes to use. Zero represents untiled loops.
/// - `numThreads` is the optional per-loop thread count (used only for
///   `scf.forall`).
/// - `destinationTensors` are the init values to use for the outermost loop.
/// - `tiledBodyFn` is called to generate the body of the innermost loop.
/// - `loops` is an in-out parameter to which the generated loops are appended.
/// If every tile size is zero, no loop is created. `tiledBodyFn` is then
/// invoked at the current insertion point with no induction variables.
static LogicalResult generateLoopNest(
    RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
    ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
    ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
  const bool nothingToTile = llvm::all_of(tileSizes, isZeroIndex);
  if (nothingToTile) {
    SmallVector<Value> ignoredResults;
    SmallVector<SmallVector<OpFoldResult>> ignoredOffsets, ignoredSizes;
    return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
                       ignoredResults, ignoredOffsets, ignoredSizes);
  }

  using LoopType = scf::SCFTilingOptions::LoopType;
  const LoopType loopType = options.loopType;
  if (loopType == LoopType::ForallOp) {
    return generateLoopNestUsingForallOp(
        rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
        destinationTensors, tiledBodyFn, loops);
  }
  if (loopType == LoopType::ForOp) {
    return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
                                      destinationTensors, tiledBodyFn, loops);
  }
  return rewriter.notifyMatchFailure(loc, "unhandled loop type");
}
/// Create the initial values for the destination operands of the tiled loop
/// nest, according to `options.reductionStrategy`.
/// - `FullReduction`: the destinations of `op`, via
///   `tensor::getOrCreateDestinations`.
/// - `PartialReductionOuterReduction`: the initial partial-reduction tensors
///   produced by `PartialReductionOpInterface`, given `tileSizes` and the loop
///   dimensions whose iterator type is `reduction`.
/// Returns failure if the destinations cannot be created, if `op` does not
/// implement `PartialReductionOpInterface` when that is required, or if the
/// strategy is not handled here.
static FailureOr<SmallVector<Value>>
createInitialTensorsForTiling(RewriterBase &rewriter, TilingInterface op,
                              ArrayRef<OpFoldResult> tileSizes,
                              const scf::SCFTilingOptions &options) {
  SmallVector<Value> initTensors;
  Location loc = op->getLoc();
  switch (options.reductionStrategy) {
  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
    if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, initTensors)))
      return failure();
    return initTensors;
  case scf::SCFTilingOptions::ReductionTilingStrategy::
      PartialReductionOuterReduction: {
    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
    if (!redOp) {
      // The literals are concatenated; keep the space between them.
      return rewriter.notifyMatchFailure(
          op, "PartialReductionOuterReduction tiling strategy is only supported "
              "for operations implementing PartialReductionOpInterface");
    }
    // Get reduction dimensions.
    // TODO: PartialReductionOpInterface should really query TilingInterface
    // itself and find reduction dimensions.
    SmallVector<int> reductionDims;
    for (auto [idx, iteratorType] :
         llvm::enumerate(op.getLoopIteratorTypes())) {
      if (iteratorType == utils::IteratorType::reduction)
        reductionDims.push_back(idx);
    }
    return redOp.generateInitialTensorForPartialReduction(
        rewriter, loc, tileSizes, reductionDims);
  }
  default:
    return rewriter.notifyMatchFailure(op,
                                       "unhandled reduction tiling strategy");
  }
}
/// Produce the tiled implementation of `op` for the iteration-space tile given
/// by `offsets` and `sizes`, according to `options.reductionStrategy`.
/// - `FullReduction`: delegates to `TilingInterface::getTiledImplementation`.
///   `regionIterArg` is not used in this case.
/// - `PartialReductionOuterReduction`: delegates to
///   `PartialReductionOpInterface::tileToPartialReduction`, passing
///   `regionIterArg` and the loop dimensions whose iterator type is
///   `reduction`.
/// Returns a match failure for any other strategy, or when `op` does not
/// implement `PartialReductionOpInterface` when that is required.
static FailureOr<TilingResult>
getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
                       ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
                       ArrayRef<OpFoldResult> sizes,
                       const scf::SCFTilingOptions &options) {
  using Strategy = scf::SCFTilingOptions::ReductionTilingStrategy;
  const Strategy strategy = options.reductionStrategy;
  if (strategy == Strategy::FullReduction)
    return op.getTiledImplementation(rewriter, offsets, sizes);
  if (strategy != Strategy::PartialReductionOuterReduction) {
    return rewriter.notifyMatchFailure(op,
                                       "unhandled reduction tiling strategy");
  }

  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
  if (!redOp) {
    return rewriter.notifyMatchFailure(
        op, "PartialReductionOuterReduction tiling strategy is only "
            "supported for operations "
            "implementing PartialReductionOpInterface");
  }
  // Collect the reduction dimensions.
  // TODO: PartialReductionOpInterface should really query TilingInterface
  // itself and find reduction dimensions.
  auto iteratorTypes = op.getLoopIteratorTypes();
  SmallVector<int> reductionDims;
  for (unsigned dim = 0, e = iteratorTypes.size(); dim < e; ++dim) {
    if (iteratorTypes[dim] == utils::IteratorType::reduction)
      reductionDims.push_back(dim);
  }
  return redOp.tileToPartialReduction(rewriter, op.getLoc(), regionIterArg,
                                      offsets, sizes, reductionDims);
}
/// Compute the offsets (`resultOffset`) and sizes (`resultSize`) at which the
/// tiled value for result `index` of `op` is inserted into its destination,
/// given the iteration-space tile described by `offsets` and `sizes`.
/// - `FullReduction`: delegates to `TilingInterface::getResultTilePosition`.
/// - `PartialReductionOuterReduction`: delegates to
///   `PartialReductionOpInterface::getPartialResultTilePosition`, passing the
///   loop dimensions whose iterator type is `reduction`.
/// `tiledResult` is not used by this function.
/// Returns a match failure for any other strategy, or when `op` does not
/// implement `PartialReductionOpInterface` when that is required.
static LogicalResult
getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
                      TilingInterface op, ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      SmallVector<OpFoldResult> &resultOffset,
                      SmallVector<OpFoldResult> &resultSize,
                      const scf::SCFTilingOptions &options) {
  switch (options.reductionStrategy) {
  case scf::SCFTilingOptions::ReductionTilingStrategy::FullReduction:
    return op.getResultTilePosition(rewriter, index, offsets, sizes,
                                    resultOffset, resultSize);
  case scf::SCFTilingOptions::ReductionTilingStrategy::
      PartialReductionOuterReduction: {
    auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
    if (!redOp) {
      // The literals are concatenated; keep the space between them.
      return rewriter.notifyMatchFailure(
          op, "PartialReductionOuterReduction tiling strategy is only supported "
              "for operations implementing PartialReductionOpInterface");
    }
    // Get reduction dimensions.
    // TODO: PartialReductionOpInterface should really query TilingInterface
    // itself and find reduction dimensions.
    SmallVector<int> reductionDims;
    for (auto [idx, iteratorType] :
         llvm::enumerate(op.getLoopIteratorTypes())) {
      if (iteratorType == utils::IteratorType::reduction)
        reductionDims.push_back(idx);
    }
    return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
                                              resultOffset, resultSize,
                                              reductionDims);
  }
  default:
    return rewriter.notifyMatchFailure(op,
                                       "unhandled reduction tiling strategy");
  }
}
/// Combine the values produced by the tiled loop nest, according to
/// `options.reductionStrategy`.
/// - `FullReduction`: nothing needs merging. The returned `MergeResult` has an
///   empty first member, and `partialResults` is passed through as its second
///   member.
/// - `PartialReductionOuterReduction`: delegates to
///   `PartialReductionOpInterface::mergeReductions` over the loop dimensions
///   whose iterator type is `reduction`.
/// Returns a match failure for any other strategy, or when `op` does not
/// implement `PartialReductionOpInterface` when that is required.
static FailureOr<MergeResult>
mergeTilingResults(RewriterBase &rewriter, TilingInterface op,
                   ValueRange partialResults,
                   const scf::SCFTilingOptions &options) {
  using Strategy = scf::SCFTilingOptions::ReductionTilingStrategy;
  const Strategy strategy = options.reductionStrategy;
  if (strategy == Strategy::FullReduction) {
    // A full reduction already produces the final values.
    return MergeResult{{}, partialResults};
  }
  if (strategy != Strategy::PartialReductionOuterReduction) {
    return rewriter.notifyMatchFailure(op,
                                       "unhandled reduction tiling strategy");
  }

  auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
  if (!redOp) {
    return rewriter.notifyMatchFailure(
        op, "PartialReductionOuterReduction tiling strategy is only "
            "supported for operations "
            "implementing PartialReductionOpInterface");
  }
  // Collect the reduction dimensions.
  // TODO: PartialReductionOpInterface should really query TilingInterface
  // itself and find reduction dimensions.
  auto iteratorTypes = op.getLoopIteratorTypes();
  SmallVector<int> reductionDims;
  for (unsigned dim = 0, e = iteratorTypes.size(); dim < e; ++dim) {
    if (iteratorTypes[dim] == utils::IteratorType::reduction)
      reductionDims.push_back(dim);
  }
  return redOp.mergeReductions(rewriter, op.getLoc(), partialResults,
                               reductionDims);
}
/// Append the specified additional `newInitOperands` operands to the
/// loop's existing `init` operands (or similar), and replace `loopOp` with
/// the new loop that has the additional init operands. The loop body of
/// this loop is moved over to the new loop. `yieldTiledValuesFn`
/// is called to get the new tiled values returned, and the offsets
/// and sizes at which the tiled values are inserted into the
/// new region iter_args that correspond to the newly added init operands.
///
/// This primary template is the fallback for loop types without an explicit
/// specialization. It only reports an "unhandled loop type" match failure
/// and creates no new loop.
template <typename LoopType>
FailureOr<LoopLikeOpInterface>
yieldTiledValuesAndReplaceLoop(LoopType loopOp, RewriterBase &rewriter,
                               ValueRange newInitOperands,
                               YieldTiledValuesFn yieldTiledValuesFn) {
  return rewriter.notifyMatchFailure(loopOp, "unhandled loop type");
}
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.for`.
/// The steps are:
/// 1. Create a new `scf.for` with the same bounds and step, whose init args
///    are `loopOp`'s init args followed by `newInitOperands`.
/// 2. Move `loopOp`'s body into the new loop.
/// 3. Call `yieldTiledValuesFn` with only the new iter_args.
/// 4. Append a `tensor.insert_slice` of each tiled value into its new
///    iter_arg to the yielded values.
/// 5. Replace `loopOp` with the leading results of the new loop.
/// Returns the new loop.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
    scf::ForOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);
  auto inits = llvm::to_vector(loopOp.getInitArgs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForOp>(
      loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(),
      inits, [](OpBuilder &, Location, Value, ValueRange) {});
  // Move the loop body to the new op. The old block arguments (induction
  // variable plus original iter_args) are mapped to the leading arguments of
  // the new block.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
  // The new loop was built with an empty body builder, so the terminator here
  // is expected to be the `scf.yield` moved over from `loopOp`.
  auto yieldOp = cast<scf::YieldOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(yieldOp);
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  // Only the iter_args added for `newInitOperands` are handed to the callback.
  ValueRange newRegionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVar(),
                                newRegionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // NOTE(review): `mergeBlocks` above has already moved `loopOp`'s body into
    // `newLoop`, so erasing `newLoop` here presumably drops that body too and
    // leaves `loopOp` without one. Confirm callers treat this failure as
    // unrecoverable.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp, "failed to get tiled values");
  }
  // Keep the original yielded values, then append one insert_slice (unit
  // strides) per new iter_arg.
  SmallVector<Value> newYieldValues = llvm::to_vector(yieldOp.getOperands());
  for (auto [tiledValue, regionIterArg, resultOffset, resultSize] :
       llvm::zip_equal(tiledValues, newRegionIterArgs, resultOffsets,
                       resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    Value insert = rewriter.create<tensor::InsertSliceOp>(
        yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize,
        resultStride);
    newYieldValues.push_back(insert);
  }
  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
  // Users of the old loop see only the results it originally had.
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
/// Implementation of `yieldTiledValuesAndReplaceLoop` for `scf.forall`.
/// The steps are:
/// 1. Create a new `scf.forall` with the same bounds, steps and mapping, whose
///    outputs are `loopOp`'s outputs followed by `newInitOperands`.
/// 2. Move `loopOp`'s body into the new loop.
/// 3. Call `yieldTiledValuesFn` with only the new shared outputs.
/// 4. Add a `tensor.parallel_insert_slice` for each tiled value to the
///    `scf.forall.in_parallel` terminator.
/// 5. Replace `loopOp` with the leading results of the new loop.
/// Returns the new loop.
template <>
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>(
    scf::ForallOp loopOp, RewriterBase &rewriter, ValueRange newInitOperands,
    YieldTiledValuesFn yieldTiledValuesFn) {
  OpBuilder::InsertionGuard g(rewriter);
  Location loc = loopOp.getLoc();
  rewriter.setInsertionPoint(loopOp);
  auto inits = llvm::to_vector(loopOp.getOutputs());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  auto newLoop = rewriter.create<scf::ForallOp>(
      loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(),
      loopOp.getMixedStep(), inits, loopOp.getMapping(),
      [](OpBuilder &, Location, ValueRange) {});
  // Move the region of the current block to the newly created op. The old
  // block arguments are mapped to the leading arguments of the new block.
  Block *loopBody = loopOp.getBody();
  Block *newLoopBody = newLoop.getBody();
  rewriter.mergeBlocks(
      loopBody, newLoopBody,
      newLoopBody->getArguments().take_front(loopBody->getNumArguments()));
  auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
  rewriter.setInsertionPoint(terminator);
  SmallVector<Value> tiledValues;
  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
  // Only the shared outputs added for `newInitOperands` are handed to the
  // callback.
  ValueRange regionIterArgs =
      newLoop.getRegionIterArgs().take_back(newInitOperands.size());
  if (failed(yieldTiledValuesFn(rewriter, loc, newLoop.getInductionVars(),
                                regionIterArgs, tiledValues, resultOffsets,
                                resultSizes))) {
    // NOTE(review): `mergeBlocks` above has already moved `loopOp`'s body into
    // `newLoop`, so erasing `newLoop` here presumably drops that body too and
    // leaves `loopOp` without one. Confirm callers treat this failure as
    // unrecoverable.
    rewriter.eraseOp(newLoop);
    return rewriter.notifyMatchFailure(loopOp,
                                       "failed to get yielded tiled values");
  }
  // Update the terminator: one parallel_insert_slice (unit strides) per new
  // shared output.
  rewriter.setInsertionPointToEnd(terminator.getBody());
  for (auto [tiledValue, iterArg, resultOffset, resultSize] : llvm::zip_equal(
           tiledValues, regionIterArgs, resultOffsets, resultSizes)) {
    SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                           rewriter.getIndexAttr(1));
    rewriter.create<tensor::ParallelInsertSliceOp>(
        terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize,
        resultStride);
  }
  // Users of the old loop see only the results it originally had.
  rewriter.replaceOp(loopOp,
                     newLoop->getResults().take_front(loopOp.getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
/// Generic `LoopLikeOpInterface` entry point for
/// `yieldTiledValuesAndReplaceLoop`. Forwards to the `scf.for` or `scf.forall`
/// specific implementation; any other loop kind produces a match failure.
FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop(
    LoopLikeOpInterface loopLikeOp, RewriterBase &rewriter,
    ValueRange newInitOperands, YieldTiledValuesFn yieldTiledValuesFn) {
  Operation *loopOperation = loopLikeOp.getOperation();

  if (auto forOp = dyn_cast<scf::ForOp>(loopOperation))
    return yieldTiledValuesAndReplaceLoop(forOp, rewriter, newInitOperands,
                                          yieldTiledValuesFn);

  if (auto forallOp = dyn_cast<scf::ForallOp>(loopOperation))
    return yieldTiledValuesAndReplaceLoop(forallOp, rewriter, newInitOperands,
                                          yieldTiledValuesFn);

  return rewriter.notifyMatchFailure(loopOperation, "unhandled loop type");
}
/// Method to add new init values to a loop nest. Updates `loops` in-place
/// with new loops that use the `newInitValues`. The outer loops are updated
/// to yield the new result values of the inner loop. For the innermost loop,
/// the callback `getNewTiledYieldsFn` is invoked to get the additional values
/// to yield from the innermost loop.
///
/// All loops except the innermost one must be `scf.for` (they are `cast` to
/// it). The innermost loop is rewritten through
/// `yieldTiledValuesAndReplaceLoop`. Returns failure (with an error emitted on
/// the innermost loop) if that rewrite fails.
static LogicalResult addInitOperandsToLoopNest(
    RewriterBase &rewriter, MutableArrayRef<LoopLikeOpInterface> loops,
    ValueRange newInitValues, YieldTiledValuesFn getNewTiledYieldsFn) {
  if (loops.empty())
    return success();
  OpBuilder::InsertionGuard g(rewriter);

  for (auto &loop : loops.drop_back()) {
    rewriter.setInsertionPoint(loop);

    // If loops.size() > 1, the outer loops are assumed to be scf.for.
    auto forLoop = cast<scf::ForOp>(loop.getOperation());

    // Create a new loop whose init values are the existing ones followed by
    // the new init values.
    SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs());
    newInits.append(newInitValues.begin(), newInitValues.end());
    auto newLoop = rewriter.create<scf::ForOp>(
        forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(),
        forLoop.getStep(), newInits,
        [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});

    // Merge the body of the old loop into the new loop. The old block
    // arguments map onto the induction variable and the leading iter_args;
    // the trailing iter_args correspond to the new init values.
    SmallVector<Value> sourceBlockArgs;
    sourceBlockArgs.push_back(newLoop.getInductionVar());
    auto newRegionIterArgs = newLoop.getRegionIterArgs();
    sourceBlockArgs.append(
        newRegionIterArgs.begin(),
        std::next(newRegionIterArgs.begin(), forLoop.getNumResults()));
    rewriter.mergeBlocks(forLoop.getBody(), newLoop.getBody(), sourceBlockArgs);
    rewriter.replaceOp(
        forLoop, newLoop.getResults().take_front(forLoop.getNumResults()));
    loop = newLoop;

    // The next (inner) loop is initialized from this loop's new iter_args.
    newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
  }

  // Update the loop body of the innermost loop to get new yield values.
  LoopLikeOpInterface innerMostLoop = loops.back();
  FailureOr<LoopLikeOpInterface> newInnerMostLoop =
      yieldTiledValuesAndReplaceLoop(innerMostLoop, rewriter, newInitValues,
                                     getNewTiledYieldsFn);

  if (failed(newInnerMostLoop))
    return innerMostLoop.emitOpError("failed to return additional yields");
  loops.back() = newInnerMostLoop.value();

  // Make all other loops except the innermost loops yield the values returned
  // by the inner loop.
  for (auto [outerLoop, innerLoop] :
       llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
    // Again assume that all the outer loops are scf.for operations.
    auto outerForLoop = cast<scf::ForOp>(outerLoop);
    auto outerLoopYield =
        cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
    SmallVector<Value> newYields =
        llvm::to_vector(outerLoopYield.getOperands());
    ValueRange additionalYields =
        innerLoop->getResults().take_back(newInitValues.size());
    newYields.append(additionalYields.begin(), additionalYields.end());
    rewriter.setInsertionPoint(outerLoopYield);
    rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
  }
  return success();
}
/// Implementation of tiling transformation of `op` that implements the
/// `TilingInterface`. `options.loopType` selects whether `scf.for` or
/// `scf.forall` loops are used to iterate over the tiles.
///
/// Returns the tiled operations, the initial destination tensors, the
/// generated loops, the merge result that provides replacements for the
/// results of `op`, and the slices generated while tiling. Returns failure if
/// the options are invalid, or if tiling, loop generation or merging of
/// partial results fails.
FailureOr<scf::SCFTilingResult>
mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
                        const scf::SCFTilingOptions &options) {
  if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
    return failure();
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfter(op);

  // 1. Get the range of the loops that are represented by the operation.
  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);

  // 2. Materialize the tile sizes and/or number of threads;
  SmallVector<OpFoldResult> tileSizes, numThreads;
  std::tie(tileSizes, numThreads) =
      getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);

  // Check if it is safe to tile. This is a holdover from previous iterations
  // of tile to for-all. Consider dropping it.
  if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
    checkSafeToTileToForall(op, tileSizes, numThreads);
  }

  // 3. If there is an interchange specified, permute the iteration domain and
  // the tile sizes.
  SmallVector<int64_t> interchangeVector;
  if (!options.interchangeVector.empty()) {
    interchangeVector = fillInterchangeVector(options.interchangeVector,
                                              iterationDomain.size());
    assert(isPermutationVector(interchangeVector) &&
           "expected interchange vector to be a permutation");

    applyPermutationToVector(iterationDomain, interchangeVector);
    applyPermutationToVector(tileSizes, interchangeVector);
    if (!numThreads.empty())
      applyPermutationToVector(numThreads, interchangeVector);
  }

  // Filled in by the callback below once the innermost loop body is built.
  FailureOr<TilingResult> tilingResult;
  // 4. Define the lambda function used later to generate the body of the
  // innermost tiled loop.
  YieldTiledValuesFn innerYieldTiledValuesFn =
      [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
          ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
          SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
          SmallVector<SmallVector<OpFoldResult>> &resultSizes)
      -> LogicalResult {
    // 4a. Compute the `offsets` and `sizes` to use for tiling.
    SmallVector<OpFoldResult> offsets, sizes;
    std::tie(offsets, sizes) = getTileOffsetAndSizes(
        rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);

    // 4b. If interchange was provided, apply inverse of the interchange
    //     to get back the offsets/sizes in the order to be specified.
    if (!interchangeVector.empty()) {
      auto inversePermutation = invertPermutationVector(interchangeVector);
      applyPermutationToVector(offsets, inversePermutation);
      applyPermutationToVector(sizes, inversePermutation);
    }

    // 5. Generate the tiled implementation within the inner most loop.

    // 5a. Clone the operation within the loop body.
    auto clonedOp = cast<TilingInterface>(
        cloneOpAndUpdateDestinationArgs(rewriter, op, regionIterArgs));

    // 5b. Early return cloned op if tiling is not happening. We can not
    // return the original op because it could lead to `rewriter.replaceOp(op,
    // op->getResults())` and users would get crash.
    if (llvm::all_of(tileSizes, isZeroIndex)) {
      tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
      tilingResult =
          TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
                       /*generatedSlices=*/{}};
      return success();
    }

    // 5c. Tile the cloned operation.
    tilingResult = getTiledImplementation(rewriter, clonedOp, regionIterArgs,
                                          offsets, sizes, options);
    if (failed(tilingResult)) {
      rewriter.eraseOp(clonedOp);
      return op.emitOpError("failed to tile operation");
    }

    // 5d. Delete the cloned operation.
    rewriter.eraseOp(clonedOp);

    // 5e. Compute the offsets at which the result values are to be inserted
    //     back into its destinations.
    for (auto [index, tiledValue] :
         llvm::enumerate(tilingResult->tiledValues)) {
      tiledResults.push_back(tiledValue);
      SmallVector<OpFoldResult> resultOffset, resultSize;
      if (failed(getResultTilePosition(rewriter, index, tiledValue, op, offsets,
                                       sizes, resultOffset, resultSize,
                                       options))) {
        // Named `tiledOp` so it does not shadow the `op` being tiled, which is
        // used in the diagnostic below.
        for (Operation *tiledOp : tilingResult->tiledOps) {
          rewriter.eraseOp(tiledOp);
        }
        return rewriter.notifyMatchFailure(
            op, "failed to get slice of result produced");
      }
      resultOffsets.emplace_back(std::move(resultOffset));
      resultSizes.emplace_back(std::move(resultSize));
    }

    return success();
  };

  // 6. Find the destination tensors to use for the operation.
  FailureOr<SmallVector<Value>> maybeInits =
      createInitialTensorsForTiling(rewriter, op, tileSizes, options);
  if (failed(maybeInits)) {
    return rewriter.notifyMatchFailure(
        op, "unable to create initial tensors for tiling");
  }
  SmallVector<Value> &initTensors = maybeInits.value();

  // 7. Generate the tiled loops nest using the callback defined above.
  SmallVector<LoopLikeOpInterface> loops;
  if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
                              tileSizes, numThreads, initTensors,
                              innerYieldTiledValuesFn, loops)))
    return op.emitOpError("failed to generate tiling loops");
  assert(succeeded(tilingResult) &&
         "expected tiling result to be computed after loop generation");

  SmallVector<Value> partialResults;
  if (loops.empty()) {
    // If loops are empty, the tiled op is used as the replacement for the
    // untiled op.
    partialResults = tilingResult->tiledValues;
  } else {
    partialResults = llvm::map_to_vector(loops.front()->getResults(),
                                         [](OpResult r) -> Value { return r; });
  }

  FailureOr<MergeResult> mergeResult =
      mergeTilingResults(rewriter, op, partialResults, options);
  if (failed(mergeResult)) {
    return rewriter.notifyMatchFailure(
        op, "Failed to merge partial results from tiling");
  }

  return scf::SCFTilingResult{tilingResult->tiledOps, initTensors, loops,
                              mergeResult.value(),
                              tilingResult->generatedSlices};
}
/// Tiles `op` with `tileSizes` using `scf.for` loops and the
/// `PartialReductionOuterReduction` reduction tiling strategy. `op` must also
/// implement `TilingInterface`; otherwise a match failure is reported.
FailureOr<scf::SCFTilingResult>
mlir::scf::tileReductionUsingScf(RewriterBase &b,
                                 PartialReductionOpInterface op,
                                 ArrayRef<OpFoldResult> tileSizes) {
  auto tilingInterfaceOp = dyn_cast<TilingInterface>(op.getOperation());
  if (!tilingInterfaceOp)
    return b.notifyMatchFailure(
        op,
        "Operation implementing PartialReductionOpInterface should implement "
        "TilingInterface");

  SCFTilingOptions reductionOptions;
  reductionOptions.setLoopType(SCFTilingOptions::LoopType::ForOp);
  reductionOptions.setReductionTilingStrategy(
      SCFTilingOptions::ReductionTilingStrategy::PartialReductionOuterReduction);
  reductionOptions.setTileSizes(tileSizes);
  return tileUsingSCF(b, tilingInterfaceOp, reductionOptions);
}
//===----------------------------------------------------------------------===//
// tileConsumerAndFuseProducersUsingSCF implementation.
//===----------------------------------------------------------------------===//
/// Walks from `source` outwards through the `iter_args` of the loop nest
/// `loops` (innermost loop last) and returns the value that eventually feeds
/// the slice, as an `OpResult` (null if it is not an op result).
///
/// The walk stops as soon as the current value is not a block argument of the
/// loop being visited. If every loop in `loops` was traversed, the init
/// operand of the outermost loop is returned as the second tuple element; it
/// indicates that the slice reads a destination operand of the consumer.
/// Otherwise the second element is empty.
static std::tuple<OpResult, std::optional<OpOperand *>>
getUntiledProducerFromSliceSource(OpOperand *source,
                                  ArrayRef<LoopLikeOpInterface> loops) {
  assert(!loops.empty() && "expected non empty loops container");
  std::optional<OpOperand *> destinationIterArg;

  bool traversedEveryLoop = true;
  for (LoopLikeOpInterface loop : llvm::reverse(loops)) {
    auto blockArg = dyn_cast<BlockArgument>(source->get());
    if (!blockArg || blockArg.getOwner()->getParentOp() != loop) {
      traversedEveryLoop = false;
      break;
    }
    source = loop.getTiedLoopInit(blockArg);
  }

  if (traversedEveryLoop)
    destinationIterArg = source;
  return {dyn_cast<OpResult>(source->get()), destinationIterArg};
}
/// Implementation of fusing producer of a single slice by computing the
/// slice of the producer in-place.
///
/// Returns `std::nullopt` if the slice source is not produced by an operation
/// (after walking through loop `iter_args`), or if the producer cannot be
/// tiled to compute the slice. Temporary clones created during a failed
/// attempt are removed again.
std::optional<scf::SCFFuseProducerOfSliceResult>
mlir::scf::tileAndFuseProducerOfSlice(
    RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
    MutableArrayRef<LoopLikeOpInterface> loops) {
  // 1. Get the producer of the source (potentially walking through
  // `iter_args` of nested `scf.for`)
  auto [fusableProducer, destinationInitArg] =
      getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
                                        loops);
  if (!fusableProducer)
    return std::nullopt;
  unsigned resultNumber = fusableProducer.getResultNumber();

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(candidateSliceOp);

  // 2. Clone the fused producer
  // 2a. Compute the destination operands to use for the cloned operation.
  SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
  Operation *fusableProducerOp = fusableProducer.getOwner();
  if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
      failed(tensor::getOrCreateDestinations(
          rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
          origDestinationTensors)))
    return std::nullopt;

  clonedOpDestinationTensors = origDestinationTensors;
  if (destinationInitArg &&
      isa<DestinationStyleOpInterface>(fusableProducerOp)) {
    // 2b. If the producer is also destination style, then to maintain the
    // destination passing style, update the destination of the producer to be
    // the source of the slice.
    clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
  }
  // 2c. Clone the fused producer.
  Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
      rewriter, fusableProducerOp, clonedOpDestinationTensors);
  // 2d. Update the source of the candidateSlice to be the cloned producer.
  //     Easier to just clone the slice with different source since
  //     replacements and DCE of cloned ops becomes easier
  SmallVector<Value> candidateSliceOpOperands =
      llvm::to_vector(candidateSliceOp->getOperands());
  candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
  tensor::ExtractSliceOp clonedCandidateSliceOp =
      mlir::clone(rewriter, candidateSliceOp,
                  candidateSliceOp->getResultTypes(), candidateSliceOpOperands);

  // 3. Generate the tiled implementation of the producer of the source
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceExtractSliceWithTiledProducer(
          rewriter, clonedCandidateSliceOp,
          clonedProducerOp->getResult(resultNumber));
  if (failed(tileAndFuseResult)) {
    // Do not leave the temporary clones behind when fusion fails. The cloned
    // slice uses the cloned producer, so it has to be erased first.
    rewriter.eraseOp(clonedCandidateSliceOp);
    if (clonedProducerOp->use_empty())
      rewriter.eraseOp(clonedProducerOp);
    return std::nullopt;
  }
  // Note: Do not delete the candidateSliceOp, since it's passed in from the
  // caller.
  rewriter.replaceAllUsesWith(candidateSliceOp,
                              tileAndFuseResult->tiledValues[0]);
  rewriter.eraseOp(clonedCandidateSliceOp);
  rewriter.eraseOp(clonedProducerOp);

  // 4. If the slice is for a destination operand, for example,
  //
  // ```mlir
  // %0 = linalg.init
  // %1 = linalg.fill .. outs(%0 : )
  // %2 = scf.for .. iter_args(%arg0 = %1) {
  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %4 = tensor.extract_slice %arg1 [..]
  //     .. = linalg.matmul .. outs(%4 : )
  //   }
  // }
  // ```
  //
  // the IR is currently
  //
  // ```
  // %0 = linalg.init
  // %1 = linalg.fill
  // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
  //   %3 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %4 = tensor.extract_slice %arg1[..]
  //     %5 = linalg.fill .. outs(%4 : )
  //     .. = linalg.matmul .. outs(%5 : )
  //   }
  // }
  // ```
  //
  // The untiled `linalg.fill` is still used as the `init_value` since it
  // was originally a destination operand of the untiled `linalg.matmul`.
  // When fusing an operand that is a destination operand, the iter_arg of
  // the outer most loop should be changed to use the destination of the
  // fused operation. With this the IR will be.
  //
  // ```
  // %0 = linalg.init
  // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
  //   %2 = scf.for .. iter_args(%arg1 = %arg0) {
  //     %3 = tensor.extract_slice %arg1[..]
  //     %4 = linalg.fill .. outs(%3 : )
  //     .. = linalg.matmul .. outs(%4 : )
  //   }
  // }
  // ```
  if (destinationInitArg &&
      isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
    loops.front()
        ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
        .set(origDestinationTensors[resultNumber]);
  }
  return scf::SCFFuseProducerOfSliceResult{
      fusableProducer, tileAndFuseResult->tiledValues[0],
      tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices};
}
/// Reconstruct the fused producer from within the tiled-and-fused code.
///
/// Adds one new init operand per selected producer result to every loop in
/// `loops`, and makes the loop nest yield the corresponding results of the
/// tiled producer. `yieldResultNumber` selects which results of the original
/// producer are yielded; when it is empty, all results are yielded. Returns
/// the `tensor.extract_slice` ops created on the new `iter_args` (these are
/// only created when the tiled producer is destination-style). With an empty
/// loop nest there is nothing to yield and an empty list is returned.
FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
    MutableArrayRef<LoopLikeOpInterface> loops,
    ArrayRef<unsigned> yieldResultNumber) {
  // A `FailureOr` must not be constructed from `success()`; doing so asserts
  // in debug builds and produces a failure in release builds. Return an
  // explicit (empty) value instead.
  if (loops.empty())
    return SmallVector<Operation *>();

  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
            *tiledOwner = fusedProducerInfo.tiledOps[0];

  Location loc = originalOwner->getLoc();
  // a. collect all init Value to be appended
  SmallVector<unsigned> initNumberList =
      yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
                                      0, originalOwner->getNumResults()))
                                : llvm::to_vector(yieldResultNumber);
  SmallVector<Value> initValueList;
  for (unsigned resultNumber : initNumberList) {
    FailureOr<Value> initValue = tensor::getOrCreateDestination(
        rewriter, loc, originalOwner->getResult(resultNumber));
    if (failed(initValue))
      return failure();
    initValueList.push_back(initValue.value());
  }

  SmallVector<Operation *> generatedSlices;
  YieldTiledValuesFn newYieldValuesFn =
      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
    OpBuilder::InsertionGuard g(innerRewriter);

    // get sliceOp tile information
    SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
                              sliceSizes = sliceOp.getMixedSizes();

    // expect all strides of sliceOp being 1
    if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
          return !isConstantIntValue(ofr, 1);
        }))
      return failure();

    unsigned sliceResultNumber =
        fusedProducerInfo.origProducer.getResultNumber();

    auto tilableOp = cast<TilingInterface>(originalOwner);
    // b. get iterDomain Offset and Sizes based on sliceOp tile
    SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
    // skip tensor.pack/unpack/pad, which expects single opResult
    if (tilableOp->getNumResults() > 1 &&
        failed(tilableOp.getIterationDomainTileFromResultTile(
            innerRewriter, sliceResultNumber, sliceOffset, sliceSizes,
            iterDomainOffset, iterDomainSizes))) {
      // In theory, it is unnecessary to raise an error here. Actually
      // although it fails to reconstruct the result tensor, it should not
      // break current fusion anyway. The reason why we must return failure
      // currently is that the callback function `newYieldValuesFn` will be
      // called after new init operand(s) has already been appended. It will
      // take more refactoring to make sure the init operands are added
      // consistently in the future. For more details, please refer to:
      // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
      return failure();
    }

    // c. calculate offsets and sizes info of all OpResults respectively based
    // on iteration Domain Tile
    SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
    for (unsigned resultNumber : initNumberList) {
      if (resultNumber == sliceResultNumber) {
        offsetList.push_back(sliceOffset);
        sizesList.push_back(sliceSizes);
      } else {
        assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
        // infer result tile according to the iteration domain tile
        SmallVector<OpFoldResult> offset, sizes;
        if (failed(tilableOp.getResultTilePosition(
                innerRewriter, resultNumber, iterDomainOffset, iterDomainSizes,
                offset, sizes))) {
          return failure();
        }
        offsetList.push_back(offset);
        sizesList.push_back(sizes);
      }
    }

    // d. create `extract_slice` for `iter_args` for DPS operation if
    // necessary
    if (auto tiledDestStyleOp =
            dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
      innerRewriter.setInsertionPoint(tiledDestStyleOp);
      for (auto [index, newRegionArg] : llvm::enumerate(newRegionIterArgs)) {
        auto destSlice = innerRewriter.create<tensor::ExtractSliceOp>(
            loc, newRegionArg, offsetList[index], sizesList[index],
            SmallVector<OpFoldResult>(offsetList[index].size(),
                                      innerRewriter.getIndexAttr(1)));
        generatedSlices.push_back(destSlice);
        unsigned resultNumber = initNumberList[index];
        innerRewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
          tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
        });
      }
    }

    // e. prepare tiled offset and sizes for later `insert_slice` creation by
    // caller. Use the insertion block directly: dereferencing the insertion
    // point would be invalid if it were at the end of the block.
    Block *block = innerRewriter.getInsertionBlock();
    innerRewriter.setInsertionPoint(block->getTerminator());
    for (auto [index, resultNumber] : llvm::enumerate(initNumberList)) {
      tiledResult.push_back(tiledOwner->getResult(resultNumber));
      tiledOffset.emplace_back(offsetList[index]);
      tiledSizes.emplace_back(sizesList[index]);
    }
    return success();
  };

  if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList,
                                       newYieldValuesFn))) {
    return failure();
  }
  return generatedSlices;
}
namespace {
//===----------------------------------------------------------------------===//
// SliceTrackingListener
//===----------------------------------------------------------------------===//
/// This class is a listener for tracking the insertion and removal of
/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
/// fusion algorithm to apply cleanup patterns in between fusion steps.
class SliceTrackingListener : public RewriterBase::Listener {
public:
explicit SliceTrackingListener(
std::optional<FrozenRewritePatternSet> patterns);
SliceTrackingListener() = default;
/// Adds the given list of operations to the worklist, and if present,
/// applies the list of `patterns` to the newly added operations. This only
/// processes the given operations and any newly inserted ones by the
/// pattern set.
LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
/// Add to the new operation worklist if it is an extract_slice.
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override;
/// Shared helper for operation removal from the worklist.
void removeOp(Operation *op);
/// Remove the operation from the worklist.
void notifyOperationErased(Operation *op) override;
/// Remove the operation from the worklist.
void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
/// The worklist for this transformation keeps track of the slices to visit
/// next for fusion.
std::deque<tensor::ExtractSliceOp> worklist;
private:
/// Optional pattern set to apply when adding new operations to the
/// worklist.
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
};
/// Builds a listener that runs `p` (when present) on newly added operations.
SliceTrackingListener::SliceTrackingListener(
    std::optional<FrozenRewritePatternSet> p)
    : patterns(std::move(p)) {}
/// Queues the slices among `ops`, then (if configured) greedily applies the
/// cleanup patterns to `ops` with this listener attached, so that slices the
/// patterns create or remove keep the worklist in sync.
LogicalResult
SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
  for (Operation *candidate : ops)
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate))
      worklist.push_back(sliceOp);

  if (!patterns.has_value())
    return success();

  GreedyRewriteConfig greedyConfig;
  greedyConfig.listener = this;
  greedyConfig.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
  return applyOpPatternsGreedily(ops, *patterns, greedyConfig);
}
/// Appends newly inserted `tensor.extract_slice` ops to the worklist; all
/// other operations are ignored.
void SliceTrackingListener::notifyOperationInserted(
    Operation *op, OpBuilder::InsertPoint previous) {
  if (auto insertedSlice = dyn_cast<tensor::ExtractSliceOp>(op))
    worklist.push_back(insertedSlice);
}
// Removes the first worklist entry that refers to `op`, if any. Only
// `tensor.extract_slice` ops can be in the worklist, so other ops are skipped
// right away. A linear scan is used on the assumption that the worklist stays
// small and removals are relatively rare.
void SliceTrackingListener::removeOp(Operation *op) {
  if (!isa<tensor::ExtractSliceOp>(op))
    return;
  auto match = llvm::find_if(worklist, [&](tensor::ExtractSliceOp tracked) {
    return tracked.getOperation() == op;
  });
  if (match != worklist.end())
    worklist.erase(match);
}
/// An erased slice must no longer be visited.
void SliceTrackingListener::notifyOperationErased(Operation *op) {
  this->removeOp(op);
}

/// A replaced slice must no longer be visited; the replacement values are
/// not needed.
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
                                                    ValueRange /*replacement*/) {
  this->removeOp(op);
}
//===----------------------------------------------------------------------===//
// ReplacementListener
//===----------------------------------------------------------------------===//
/// Forwarding listener that keeps a map of replacement values up to date.
/// When an operation is replaced, every map entry whose value is one of the
/// replaced op's results is redirected to the corresponding new value. Every
/// notification is still forwarded to the wrapped listener, so listeners
/// installed by external users keep working.
class ReplacementListener : public RewriterBase::ForwardingListener {
public:
  ReplacementListener(DenseMap<Value, Value> &replacements,
                      OpBuilder::Listener *listener)
      : ForwardingListener(listener), replacements(replacements) {}

  /// Rewrites map values: any value equal to `origValues[i]` becomes
  /// `replaceValues[i]`. A plain nested scan over the map and the pairs.
  void updateReplacementValues(ValueRange origValues,
                               ValueRange replaceValues) {
    for (auto &entry : replacements) {
      Value &mapped = entry.second;
      for (auto [oldValue, newValue] :
           llvm::zip_equal(origValues, replaceValues))
        if (mapped == oldValue)
          mapped = newValue;
    }
  }

  void notifyOperationReplaced(Operation *op, Operation *newOp) override {
    ForwardingListener::notifyOperationReplaced(op, newOp);
    updateReplacementValues(op->getResults(), newOp->getResults());
  }

  void notifyOperationReplaced(Operation *op, ValueRange values) override {
    ForwardingListener::notifyOperationReplaced(op, values);
    updateReplacementValues(op->getResults(), values);
  }

private:
  /// Map owned by the caller; updated in place.
  DenseMap<Value, Value> &replacements;
};
} // namespace
/// Implementation of tile consumer and fuse producer greedily.
///
/// Tiles `consumer` with `options.tilingOptions`. It then repeatedly pops
/// `tensor.extract_slice` ops from a worklist. When `options.fusionControlFn`
/// accepts a slice, the untiled producer of the slice source is fused into
/// the loop nest. The control function can also request that the producer's
/// results be yielded from the loop nest, in which case `replacements` maps
/// the untiled producer results to the new loop results.
FailureOr<scf::SCFTileAndFuseResult>
mlir::scf::tileConsumerAndFuseProducersUsingSCF(
    RewriterBase &rewriter, TilingInterface consumer,
    const scf::SCFTileAndFuseOptions &options) {
  // This transformation is only valid for ops that return values (i.e. not
  // valid to use with operations that have memref operands).
  if (!consumer->getNumResults()) {
    return rewriter.notifyMatchFailure(
        consumer, "invalid pattern for op with no results");
  }

  // 1. First tile the consumer.
  SetVector<Operation *> fusedProducers, tiledAndFusedOps;

  FailureOr<scf::SCFTilingResult> tilingResult =
      tileUsingSCF(rewriter, consumer, options.tilingOptions);

  if (failed(tilingResult))
    return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
  tiledAndFusedOps.insert_range(tilingResult->tiledOps);

  DenseMap<Value, Value> replacements;
  for (auto [origVal, replacement] : llvm::zip_equal(
           consumer->getResults(), tilingResult->mergeResult.replacements)) {
    replacements[origVal] = replacement;
  }

  // If there are no loops generated, fusion is immaterial.
  auto &loops = tilingResult->loops;
  if (loops.empty()) {
    return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                     replacements};
  }

  // Since the loop gets potentially replaced during fusion, we need to track
  // the mutation of replacement values. To do this, we attach a listener to
  // update the replacements as they happen.
  OpBuilder::Listener *previousListener = rewriter.getListener();
  auto resetListener =
      llvm::make_scope_exit([&]() { rewriter.setListener(previousListener); });
  ReplacementListener replaceListener(replacements, previousListener);
  rewriter.setListener(&replaceListener);

  // 2. Typically, the operands of the tiled operation are slices of the
  //    operands of the untiled operation. These are expressed in IR using
  //    `tensor.extract_slice` operations with source being the operands of
  //    the untiled operation. Create a worklist of these
  //    `tensor.extract_slice` operations. If the producers of the source of
  //    the `tensor.extract_slice` can be tiled such that the tiled value is
  //    generated in-place, that effectively tiles + fuses the operations.
  struct WorklistItem {
    tensor::ExtractSliceOp candidateSlice;
    SCFTileAndFuseOptions::ControlFnResult controlFnResult;
  };

  SliceTrackingListener sliceTracker =
      SliceTrackingListener(options.cleanupPatterns);

  if (failed(
          sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
    return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
  }
  OpBuilder::InsertionGuard g(rewriter);
  while (!sliceTracker.worklist.empty()) {
    auto candidateSlice = sliceTracker.worklist.front();
    sliceTracker.worklist.pop_front();

    auto [fusableProducer, destinationInitArg] =
        getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
                                          loops);
    if (!fusableProducer)
      continue;

    std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
        options.fusionControlFn(candidateSlice, fusableProducer,
                                destinationInitArg.has_value());
    if (!controlFnResult)
      continue;

    WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};

    std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
        tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice,
                                   loops);
    if (!fusedResult)
      continue;

    // The operands of the fused producer might themselves be slices of
    // values produced by operations that implement the `TilingInterface`.
    // Add these operations to the worklist.
    SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;

    if (worklistItem.controlFnResult.yieldProducerReplacement) {
      // Reconstruct and yield all opResult of fusableProducerOp by default.
      // The caller can specify which one to yield by designating the optional
      // argument named `yieldResultNumber` of
      // `yieldReplacementForFusedProducer`.
      Operation *fusableProducerOp = fusedResult->origProducer.getOwner();
      FailureOr<SmallVector<Operation *>> newSlices =
          yieldReplacementForFusedProducer(rewriter,
                                           worklistItem.candidateSlice,
                                           fusedResult.value(), loops);
      if (failed(newSlices)) {
        return rewriter.notifyMatchFailure(
            fusableProducerOp, "failed to replacement value for this "
                               "operation from within the tiled loop");
      }
      worklistCandidates.append(newSlices.value());
      // The yielded producer results are the trailing results of the
      // outermost loop, in producer result order.
      for (auto [index, result] :
           llvm::enumerate(fusableProducerOp->getResults())) {
        replacements[result] = loops.front()->getResult(
            loops.front()->getNumResults() -
            fusableProducerOp->getNumResults() + index);
      }
    }
    if (Operation *tiledAndFusedOp =
            fusedResult->tiledAndFusedProducer.getDefiningOp()) {
      fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
      tiledAndFusedOps.insert(tiledAndFusedOp);
    }

    if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
      return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
    }
  }

  return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps, loops,
                                   replacements};
}
//===----------------------------------------------------------------------===//
// tileAndFuseConsumerUsingSCF implementation.
//===----------------------------------------------------------------------===//
/// Succeeds only when the result of `candidateSliceOp` has exactly one use,
/// that use is an `scf.yield`, and the yield is in the same block as the
/// `tensor.insert_slice`. The reason for a rejection is printed under
/// LLVM_DEBUG.
static LogicalResult
checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
  Value sliceResult = candidateSliceOp.getResult();
  if (!sliceResult.hasOneUse()) {
    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
    return failure();
  }

  Operation *onlyUser = *sliceResult.getUsers().begin();
  if (!isa<scf::YieldOp>(onlyUser)) {
    LLVM_DEBUG(llvm::dbgs()
               << "Expected scf.yield to be the only user, but got -> "
               << (*onlyUser));
    return failure();
  }

  Block *sliceBlock = sliceResult.getDefiningOp()->getBlock();
  if (sliceBlock != onlyUser->getBlock()) {
    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
                               "be in the same block\n");
    return failure();
  }
  return success();
}
/// Returns the user of `loopOp` that comes first in `loopOp`'s block, or
/// nullptr if the loop has no users. Fails if `loopOp` is not a
/// `LoopLikeOpInterface` op, or if any user (after the mapping below) lives in
/// a different block than `loopOp`.
///
/// A `tensor.parallel_insert_slice` user sits inside the region of an
/// `scf.forall.in_parallel` op, so it never shares a block with other ops.
/// Its enclosing `scf.forall.in_parallel` op is compared instead, e.g.
///
/// ```
/// %1 = scf.for {
///   ...
/// }
/// %2 = consumerOp ins(%1, ...)
/// scf.forall.in_parallel {
///    tensor.parallel_insert_slice %1
/// }
/// ```
///
/// where the `scf.forall.in_parallel` op, not the
/// `tensor.parallel_insert_slice`, is in the same block as `consumerOp`.
static FailureOr<Operation *> getFirstUserOfLoop(Operation *loopOp) {
  if (!isa<LoopLikeOpInterface>(loopOp))
    return failure();

  Operation *earliestUser = nullptr;
  for (Operation *user : loopOp->getUsers()) {
    Operation *anchor = user;
    if (isa<tensor::ParallelInsertSliceOp>(user))
      anchor = user->getParentOfType<scf::InParallelOp>();

    if (anchor->getBlock() != loopOp->getBlock())
      return failure();

    if (earliestUser == nullptr || anchor->isBeforeInBlock(earliestUser))
      earliestUser = anchor;
  }
  return earliestUser;
}
/// Checks that the loop rooted at `loopOp` can be moved right before its
/// first user (`firstUserOfLoop`) without breaking any use-def chain of
/// `consumerOp`. Fusion moves the whole loop structure right before
/// `firstUserOfLoop`, so every operation feeding `consumerOp` must either
/// already properly dominate `firstUserOfLoop` or be hoisted in front of it.
/// Consider:
///
/// ```
/// %0 = scf.for() {
///   ...
/// }
/// ...
/// %1 = firstUserOfLoop(%0)
/// ...
/// %2 = lastDefOfConsumerOperand
/// ...
/// %3 = consumerOp(%2)
/// ```
///
/// If `firstUserOfLoop` is before `lastDefOfConsumerOperand`, moving `loopOp`
/// right before `firstUserOfLoop` (with the tiled consumer inside it) would
/// create a use-before-def violation:
///
/// ```
/// %0:2 = scf.for() {
///    // use before define error
///    %3 = tiledConsumerOp(%2)
/// }
/// %1 = firstUserOfLoop(%0)
/// ...
/// %2 = lastDefOfConsumerOperand
/// ```
///
/// @param loopOp: loop operation.
/// @param consumerOp: consumer operation.
/// @param reorderOperations: whether a non-empty backward slice (operations
/// that would have to be hoisted before `firstUserOfLoop`) is acceptable. When
/// false, any such operation makes the check fail.
/// @return: failure if the first user of `loopOp` cannot be determined, or if
/// the computed slice is non-empty and either `loopOp` was reached during the
/// traversal or `reorderOperations` is false. Otherwise, the backward slice of
/// `consumerOp`'s operands. The traversal stops at `loopOp` and at operations
/// that already properly dominate `firstUserOfLoop`, and neither is included.
static FailureOr<llvm::SetVector<Operation *>>
checkAssumptionForLoop(Operation *loopOp, Operation *consumerOp,
                       bool reorderOperations) {
  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(loopOp);
  if (failed(firstUserOfLoop))
    return failure();
  BackwardSliceOptions options;
  DominanceInfo dominanceInfo;
  // Keep each operand's defining op itself in the slice if the filter accepts
  // it.
  options.inclusive = true;
  options.omitBlockArguments = true;
  // Set by the filter below when the traversal reaches `loopOp`.
  bool includeLoopOp = false;
  // NOTE(review): if getBackwardSlice applies `filter` to each operand's
  // defining op (the inclusive root), then `includeLoopOp` becomes true
  // whenever `consumerOp` takes a result of `loopOp` directly as an operand.
  // Both callers in this file pick such a consumer, so a non-empty slice would
  // always fail and the reorder path in getConsumerFromLoopUses would be
  // unreachable. Confirm against getBackwardSlice before relying on reordering.
  options.filter = [&](Operation *op) {
    if (op == loopOp) {
      includeLoopOp = true;
      return false;
    }
    // Cut off the slice to not include any operation that already dominates
    // firstUserOfLoop.
    return !dominanceInfo.properlyDominates(op, *firstUserOfLoop);
  };
  llvm::SetVector<Operation *> slice;
  for (auto operand : consumerOp->getOperands()) {
    getBackwardSlice(operand, &slice, options);
  }
  if (!slice.empty()) {
    // If consumerOp has one producer, which is also the user of loopOp.
    // E.g.
    // ```
    //  %0 = %loopOp
    //  %1 = consumerOp1 ins(%0)
    //  %2 = consumerOp2 ins(%0, %1)
    // ```
    // We can not fuse consumerOp2 into loopOp due to UD chain, unless
    // consumerOp1 has already been fused into loopOp before.
    if (includeLoopOp || !reorderOperations)
      return failure();
  }
  return slice;
}
/// Returns the OpOperand through which the first suitable consumer uses
/// result `resultNumber` of `loopOp`. A user is suitable when it:
///   - implements both TilingInterface and DestinationStyleOpInterface,
///   - lives in the same block as `loopOp`,
///   - still has uses of its own (otherwise it is usually already tiled),
///   - passes `checkAssumptionForLoop` with reordering enabled.
/// Side effect: if that check reports a non-empty backward slice, those
/// operations are topologically sorted and moved right before the first user
/// of `loopOp`.
/// Returns failure if `loopOp` is not loop-like or no use qualifies.
static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter,
                                                      Operation *loopOp,
                                                      unsigned resultNumber) {
  if (!isa<LoopLikeOpInterface>(loopOp))
    return failure();

  Block *loopBlock = loopOp->getBlock();
  // Cheap structural filter applied before the dominance analysis.
  auto isCandidateUser = [&](Operation *user) {
    // TODO: We have to init result of consumer before scf.for, use
    // DestinationStyleOpInterface to get result shape from init for now.
    // Add support for other op such as op has InferTypeOpInterface.
    bool isTilableDps = isa<TilingInterface>(user) &&
                        isa<DestinationStyleOpInterface>(user);
    return isTilableDps && user->getBlock() == loopBlock && !user->use_empty();
  };

  Value loopResult = loopOp->getResult(resultNumber);
  for (OpOperand &use : loopResult.getUses()) {
    Operation *user = use.getOwner();
    if (!isCandidateUser(user))
      continue;

    FailureOr<llvm::SetVector<Operation *>> backwardSlice =
        checkAssumptionForLoop(loopOp, user, /*reorderOperations=*/true);
    if (failed(backwardSlice))
      continue;
    if (backwardSlice->empty())
      return &use;

    // Hoist the backward slice in front of the loop's first user so the loop
    // can later be moved there without breaking dominance.
    mlir::topologicalSort(*backwardSlice);
    FailureOr<Operation *> insertionAnchor = getFirstUserOfLoop(loopOp);
    assert(succeeded(insertionAnchor) && "First user of loop is not found");
    for (Operation *sliceOp : *backwardSlice)
      rewriter.moveOpBefore(sliceOp, *insertionAnchor);
    return &use;
  }
  return failure();
}
/// Returns true if `loops`, ordered from outermost to innermost, form a
/// perfectly nested chain of scf.for ops. A single loop qualifies if it is an
/// scf.for. For every adjacent (outer, inner) pair, both must be scf.for ops,
/// and:
///   - each region iter_arg of the outer loop has exactly one use and is the
///     matching init operand of the inner loop;
///   - each result of the inner loop has exactly one use and is the matching
///     operand yielded by the outer loop.
/// For example:
/// ```
///  %0 = scf.for()
///    %1 = scf.for()
///      %2 = scf.for()
///        %3 = ...
///        yield %3
///      yield %2
///    yield %1
/// ```
/// Here `loops` holds the ops producing %0, %1 and %2, outermost first.
static bool
isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) {
  assert(!loops.empty() && "unexpected empty loop nest");
  if (loops.size() == 1)
    return isa_and_nonnull<scf::ForOp>(loops.front().getOperation());

  // Verifies a single outer/inner link of the nest.
  auto isPerfectLink = [](LoopLikeOpInterface outer,
                          LoopLikeOpInterface inner) -> bool {
    auto outerForOp = dyn_cast_or_null<scf::ForOp>(outer.getOperation());
    auto innerForOp = dyn_cast_or_null<scf::ForOp>(inner.getOperation());
    if (!outerForOp || !innerForOp)
      return false;

    // Outer iter_args must feed the inner inits one-to-one and exclusively.
    auto outerIterArgs = outerForOp.getRegionIterArgs();
    auto innerInits = innerForOp.getInitArgs();
    if (outerIterArgs.size() != innerInits.size())
      return false;
    for (auto [iterArg, init] : llvm::zip_equal(outerIterArgs, innerInits)) {
      if (init != iterArg || !llvm::hasSingleElement(iterArg.getUses()))
        return false;
    }

    // Inner results must be yielded by the outer loop one-to-one and
    // exclusively.
    auto outerYieldOp =
        cast<scf::YieldOp>(outerForOp.getBody()->getTerminator());
    ValueRange yieldedValues = outerYieldOp->getOperands();
    ValueRange innerResults = innerForOp.getResults();
    if (yieldedValues.size() != innerResults.size())
      return false;
    for (auto [yielded, result] : llvm::zip_equal(yieldedValues, innerResults)) {
      if (yielded != result || !llvm::hasSingleElement(result.getUses()))
        return false;
    }
    return true;
  };

  for (auto [outer, inner] :
       llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
    if (!isPerfectLink(outer, inner))
      return false;
  }
  return true;
}
/// Fetches the untiled consumer of the outermost scf.for's result that is
/// yielded by `candidateSliceOp` from the innermost scf.for. Requirements:
///   1. `candidateSliceOp` is directly inside the innermost loop's body.
///   2. `loops` are perfectly nested scf.for operations.
///   3. The tensor.insert_slice has scf.yield as its only user.
/// The yield operand index selects the outermost loop result whose consumer
/// is looked up.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
                            tensor::InsertSliceOp candidateSliceOp,
                            MutableArrayRef<LoopLikeOpInterface> loops) {
  assert(!loops.empty() && "unexpected loops to be empty");

  // 1. The slice must live in the body of the innermost loop.
  Operation *parentOp = candidateSliceOp->getParentOp();
  if (parentOp != loops.back().getOperation()) {
    return rewriter.notifyMatchFailure(
        candidateSliceOp,
        "expected slice to be within body of inner-most loop");
  }

  // 2. The surrounding loops must form a perfect scf.for nest.
  if (!isPerfectlyNestedForLoops(loops)) {
    return rewriter.notifyMatchFailure(
        candidateSliceOp, "expected passed loops to be perfectly nested.");
  }

  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
    return failure();

  // 3. The single use is an scf.yield operand. In a perfect nest its index is
  // also the index of the matching outermost loop result.
  OpOperand &yieldUse = *candidateSliceOp.getResult().getUses().begin();
  auto outerMostForOp = cast<scf::ForOp>(loops.front().getOperation());
  return getConsumerFromLoopUses(rewriter, outerMostForOp,
                                 yieldUse.getOperandNumber());
}
/// Fetches the first untiled consumer of the scf.forall result that is
/// produced by `candidateSliceOp` (a tensor.parallel_insert_slice). `loops`
/// must consist of exactly one scf.forall, and the slice destination must be
/// one of that forall's region block arguments.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter,
                            tensor::ParallelInsertSliceOp candidateSliceOp,
                            MutableArrayRef<LoopLikeOpInterface> loops) {
  assert(!loops.empty() && "unexpected loops to be empty");

  // 1. Exactly one surrounding loop, and it must be an scf.forall.
  scf::ForallOp forallOp;
  if (loops.size() == 1)
    forallOp = dyn_cast<scf::ForallOp>(loops.front().getOperation());
  if (!forallOp) {
    return rewriter.notifyMatchFailure(
        candidateSliceOp, "expected single surrounding scf.forall");
  }

  // 2. Map the destination block argument back to the tied forall result.
  auto destArg = dyn_cast<BlockArgument>(candidateSliceOp.getDest());
  if (!destArg || destArg.getOwner()->getParentOp() != forallOp)
    return failure();
  auto tiedResult =
      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(destArg));
  return getConsumerFromLoopUses(rewriter, forallOp,
                                 tiedResult.getResultNumber());
}
/// Dispatches to the tensor.insert_slice or tensor.parallel_insert_slice
/// overload to fetch an untiled consumer of `sliceOp`. Any other kind of
/// operation yields failure.
static FailureOr<OpOperand *>
getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp,
                            MutableArrayRef<LoopLikeOpInterface> loops) {
  assert(!loops.empty() && "unexpected empty loops");
  if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(sliceOp))
    return getUntiledConsumerFromSlice(rewriter, insertSliceOp, loops);
  if (auto parallelInsertSliceOp =
          dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp))
    return getUntiledConsumerFromSlice(rewriter, parallelInsertSliceOp, loops);
  return failure();
}
/// Implementation of fusing the consumer of a single slice by computing the
/// slice of the consumer in-place for scf loops.
///
/// `candidateSliceOp` must be a tensor.insert_slice inside a perfectly nested
/// scf.for nest, or a tensor.parallel_insert_slice inside a single
/// scf.forall. The consumer of the corresponding loop result (a
/// TilingInterface + DestinationStyleOpInterface op) is tiled into the
/// innermost loop body. The loop nest is then rebuilt with the consumer's
/// inits as extra iter_args, and uses of the consumer's results are redirected
/// to the new loop results.
///
/// @param rewriter: rewriter used for every IR mutation.
/// @param candidateSliceOp: the insert_slice/parallel_insert_slice to fuse
/// through.
/// @param loops: the surrounding loop nest, ordered outermost to innermost.
/// @return: the original consumer operand, the matching operand of the tiled
/// consumer, and the tiled ops; failure otherwise.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlice(
    RewriterBase &rewriter, Operation *candidateSliceOp,
    MutableArrayRef<LoopLikeOpInterface> loops) {
  // Return if `loops` is empty, return an error for now. Caller is expected
  // to handle this case.
  if (loops.empty()) {
    return candidateSliceOp->emitOpError(
        "cannot call tile and fuse consumer with an empty loop nest");
  }
  if (!isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
          candidateSliceOp))
    return failure();

  // Only unit-stride insertions are supported. Check this before touching the
  // IR so that an unsupported slice does not leave a moved loop and dangling
  // clones behind. The slice built in step 4 reuses these exact strides.
  SmallVector<OpFoldResult> candidateStrides =
      cast<OffsetSizeAndStrideOpInterface>(candidateSliceOp).getMixedStrides();
  if (llvm::any_of(candidateStrides, [](OpFoldResult stride) {
        return !isConstantIntValue(stride, 1);
      })) {
    return rewriter.notifyMatchFailure(
        candidateSliceOp, "containingOp's result yield with stride");
  }

  // 1. Get the consumer of scf.for for the result yielded by
  // tensor.insert_slice/parallel_insert_slice.
  FailureOr<OpOperand *> maybeConsumerOpOperand =
      getUntiledConsumerFromSlice(rewriter, candidateSliceOp, loops);
  if (failed(maybeConsumerOpOperand)) {
    return rewriter.notifyMatchFailure(candidateSliceOp,
                                       "could not fetch consumer to fuse");
  }
  OpOperand *consumerOpOperand = *maybeConsumerOpOperand;
  Operation *consumerOp = consumerOpOperand->getOwner();
  unsigned operandNumber = consumerOpOperand->getOperandNumber();
  auto producerResult = dyn_cast<OpResult>(consumerOpOperand->get());
  if (!producerResult) {
    return rewriter.notifyMatchFailure(
        consumerOp, "consumer op's operand doesn't seem to be an OpResult");
  }
  unsigned resultNumber = producerResult.getResultNumber();

  LoopLikeOpInterface outerMostLoop = loops.front();
  LoopLikeOpInterface innerMostLoop = loops.back();

  // Check assumption for loop with `reorderOperations` disabled.
  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
    return rewriter.notifyMatchFailure(
        outerMostLoop, "the first user of loop should not dominate any define "
                       "of consumer operand(s)");
  }

  OpBuilder::InsertionGuard g(rewriter);

  // 2. Check consumer is not using scf loop's output as init.
  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
  if (!dstOp)
    return rewriter.notifyMatchFailure(consumerOp,
                                       "consumer op is not DPS operation");
  SmallVector<Value> dpsInits =
      llvm::map_to_vector(dstOp.getDpsInits(), [](Value v) { return v; });
  if (llvm::is_contained(dpsInits, outerMostLoop->getResult(resultNumber))) {
    return rewriter.notifyMatchFailure(
        consumerOp,
        "consumer op taking the result of scf.for as init is not supported");
  }
  SmallVector<Value> newInits = dpsInits;

  Location loc = outerMostLoop->getLoc();

  // 3. Move the whole loop structure right before firstUserOfLoop, the
  // dominance should be already ensured by `checkAssumptionForLoop`.
  FailureOr<Operation *> firstUserOfLoop = getFirstUserOfLoop(outerMostLoop);
  if (failed(firstUserOfLoop)) {
    return rewriter.notifyMatchFailure(
        outerMostLoop, "could not find the first user of outer most loop");
  }
  rewriter.moveOpBefore(outerMostLoop, *firstUserOfLoop);

  // 4. Set insertion point before terminator op of the loop and create a new
  // tensor.insert_slice. In the scf.for case this is a clone of the
  // candidateSliceOp whereas in the scf.forall case this is created from the
  // operands of tensor.parallel_insert_slice.
  tensor::InsertSliceOp clonedInsertSliceOp;
  if (auto sliceOp =
          dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
    auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
    rewriter.setInsertionPoint(newForallOp.getTerminator());
    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
  } else {
    rewriter.setInsertionPoint(candidateSliceOp);
    clonedInsertSliceOp =
        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
  }

  // 5.a. Clone consumer op.
  auto clonedConsumerOp = cast<TilingInterface>(rewriter.clone(*consumerOp));

  // 5.b. Replace all uses of the loop result with the result of the cloned
  // tensor.insert_slice.
  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
    operandToReplace.set(clonedInsertSliceOp.getResult());
  });

  // 6. Perform tiling of the cloned consumer and replace the operand at
  // `operandNumber` with the source of the cloned tensor.insert_slice op.
  auto ossSliceOp =
      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
  FailureOr<TilingResult> tileAndFuseResult =
      tensor::replaceInsertSliceWithTiledConsumer(
          rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
  if (failed(tileAndFuseResult) || tileAndFuseResult->tiledOps.empty()) {
    // Do not leave the dead clones from steps 4/5 in the loop body. The cloned
    // consumer's results have no users. The cloned slice is only erased if
    // nothing else (e.g. ops created by the tiler) still uses it.
    rewriter.eraseOp(clonedConsumerOp);
    if (clonedInsertSliceOp->use_empty())
      rewriter.eraseOp(clonedInsertSliceOp);
    return failure();
  }
  auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
                              clonedInsertSliceOp.getSource());

  // 7. Reconstruct [nested] loop with new inits. All IR created inside the
  // callback goes through `innerRewriter`, the rewriter the loop-nest utility
  // hands to it.
  YieldTiledValuesFn newYieldValuesFn =
      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
    OpBuilder::InsertionGuard innerGuard(innerRewriter);
    // 8. Set inner insertPoint right before tiled consumer op.
    innerRewriter.setInsertionPoint(tiledConsumerOp);

    // 9. Unit strides were verified before any IR was modified; offsets and
    // sizes of the cloned slice describe the operand tile.
    SmallVector<OpFoldResult> offsets = ossSliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> sizes = ossSliceOp.getMixedSizes();

    // 10. Try to get iter domain position from input position. Use
    // clonedConsumerOp instead of tiledConsumerOp, because the iteration
    // domain may require index computation based on the result size. The
    // sizes and offsets should be the same either way, but using
    // tiledConsumerOp could lead to some chained unnecessary extra index
    // computation.
    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
    if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
            innerRewriter, operandNumber, offsets, sizes, iterDomainOffsets,
            iterDomainSizes))) {
      return innerRewriter.notifyMatchFailure(
          clonedConsumerOp,
          "can't get iter domain position from input position");
    }

    // 11. Try to fetch the offset and size for all results of the cloned
    // consumer. This would then be used to form the corresponding
    // tensor.insert_slice/parallel_insert_slice later.
    unsigned totalNumResultsOfConsumer = tiledConsumerOp->getNumResults();
    SmallVector<SmallVector<OpFoldResult>> resultOffsets(
        totalNumResultsOfConsumer);
    SmallVector<SmallVector<OpFoldResult>> resultSizes(
        totalNumResultsOfConsumer);
    for (unsigned idx = 0; idx < totalNumResultsOfConsumer; ++idx) {
      if (failed(tiledConsumerOp.getResultTilePosition(
              innerRewriter, idx, iterDomainOffsets, iterDomainSizes,
              resultOffsets[idx], resultSizes[idx]))) {
        return innerRewriter.notifyMatchFailure(
            tiledConsumerOp,
            "can't get result domain position from iter domain position");
      }
    }

    // 12. Create `extract_slice` for `iter_args` for DPS operation if
    // necessary.
    if (auto tiledDestStyleOp = dyn_cast<DestinationStyleOpInterface>(
            tiledConsumerOp.getOperation())) {
      innerRewriter.setInsertionPoint(tiledDestStyleOp);
      for (const auto &&[index, newRegionArg] :
           llvm::enumerate(newRegionIterArgs)) {
        auto destSlice = innerRewriter.create<tensor::ExtractSliceOp>(
            loc, newRegionArg, resultOffsets[index], resultSizes[index],
            SmallVector<OpFoldResult>(resultOffsets[index].size(),
                                      innerRewriter.getIndexAttr(1)));
        // Make a copy of index to avoid a capturing structured binding, which
        // is a C++20 extension.
        auto dstNumber = index;
        innerRewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
          tiledDestStyleOp.getDpsInitsMutable()[dstNumber].set(destSlice);
        });
      }
    }

    // 13. Prepare tiled offset and sizes for later `insert_slice` creation by
    // caller.
    Block *block = innerRewriter.getInsertionBlock();
    innerRewriter.setInsertionPoint(block->getTerminator());
    for (const auto &&[index, result] :
         llvm::enumerate(tiledConsumerOp->getResults())) {
      tiledResult.push_back(result);
      tiledOffset.emplace_back(resultOffsets[index]);
      tiledSizes.emplace_back(resultSizes[index]);
    }
    return success();
  };

  // 14. Add new inits to [nested] loops.
  if (failed(addInitOperandsToLoopNest(rewriter, loops, newInits,
                                       newYieldValuesFn))) {
    return rewriter.notifyMatchFailure(tiledConsumerOp,
                                       "unable to add new inits to nest loop");
  }

  // 15. Replace the result of scf loop and consumer op with new loop's
  // results.
  for (auto &&[oldResult, newResult] :
       llvm::zip(consumerOp->getResults(),
                 loops.front()->getResults().take_back(newInits.size()))) {
    rewriter.replaceAllUsesWith(oldResult, newResult);
  }

  // 16. Erase the cloned (untiled) consumer op; the old loops were already
  // replaced by `addInitOperandsToLoopNest`.
  rewriter.eraseOp(clonedConsumerOp);

  return scf::SCFFuseConsumerOfSliceResult{
      consumerOpOperand,
      &(tileAndFuseResult->tiledOps[0]->getOpOperand(operandNumber)),
      tileAndFuseResult->tiledOps};
}
//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
/// Lowers `op` to a nest of scf.for loops, one per dimension of its iteration
/// domain (outermost first), and emits the scalar body of `op` inside the
/// innermost loop via `generateScalarImplementation`. Each loop takes the
/// range's offset as lower bound, its size as upper bound, and its stride as
/// step. Ops that produce results are rejected. The rewriter's insertion point
/// is not restored: it is left before the terminator of the innermost loop
/// whenever the domain is non-empty.
FailureOr<SmallVector<scf::ForOp>>
mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
                                     TilingInterface op) {
  // TODO: Handle cases where the op has results if needed.
  if (op->getNumResults() > 0) {
    return rewriter.notifyMatchFailure(
        op, "unable to lower to loops operations with return values");
  }

  SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
  Location loc = op.getLoc();
  // Materializes a bound as an index value at the current insertion point.
  auto toIndexValue = [&](OpFoldResult bound) -> Value {
    return getValueOrCreateConstantIndexOp(rewriter, loc, bound);
  };

  SmallVector<scf::ForOp> loopNest;
  SmallVector<Value> inductionVars;
  for (const Range &range : iterationDomain) {
    Value lowerBound = toIndexValue(range.offset);
    Value upperBound = toIndexValue(range.size);
    Value step = toIndexValue(range.stride);
    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
                                             ValueRange{});
    loopNest.push_back(forOp);
    inductionVars.push_back(forOp.getInductionVar());
    // The next loop (or the scalar body) is nested inside this one.
    rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
  }

  if (failed(op.generateScalarImplementation(rewriter, loc, inductionVars)))
    return failure();
  return loopNest;
}