blob: 36434cf2d2ae25a6307bd265d76ec4bbe4bc507b [file] [log] [blame]
//===- Hoisting.cpp - Linalg hoisting transformations ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements functions concerned with hoisting invariant operations
// in the context of Linalg transformations.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
#include "llvm/Support/Debug.h"
using llvm::dbgs;
#define DEBUG_TYPE "linalg-hoisting"
#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
using namespace mlir;
using namespace mlir::linalg;
/// Rebuild `loop` as a new `scf.for` that differs from the original only at
/// iter_arg position `index`: the init operand at that position becomes
/// `newInitOperand` and the value yielded at that position becomes
/// `newYieldValue`. The body of `loop` is moved (not cloned) into the new
/// loop, and `loop` itself is replaced by the results of the new loop.
///
/// `index` must be a valid position into the init operands of `loop`.
/// Returns the newly created loop.
static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
                                            scf::ForOp loop,
                                            Value newInitOperand,
                                            unsigned index,
                                            Value newYieldValue) {
  // Restore the caller's insertion point on exit; the new loop is created
  // right before the loop it replaces.
  OpBuilder::InsertionGuard insertionGuard(rewriter);
  rewriter.setInsertionPoint(loop.getOperation());

  // Start from the current init operands and swap in the replacement.
  auto updatedInits = llvm::to_vector(loop.getInits());
  assert(index < updatedInits.size());
  updatedInits[index] = newInitOperand;

  // The body builder intentionally does nothing: the body is populated below
  // by merging in the blocks of the original loop.
  scf::ForOp newLoop = scf::ForOp::create(
      rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
      loop.getStep(), updatedInits,
      [](OpBuilder &, Location, Value, ValueRange) {}, loop.getUnsignedCmp());

  // Patch the terminator of the original body so that it yields the new value
  // at `index`; it is carried over to the new loop together with the body.
  Block *originalBody = loop.getBody();
  auto terminator = cast<scf::YieldOp>(originalBody->getTerminator());
  terminator.setOperand(index, newYieldValue);

  // Transfer the original body into the new loop, remapping the old block
  // arguments onto the block arguments of the new loop.
  Block *replacementBody = newLoop.getBody();
  rewriter.mergeBlocks(originalBody, replacementBody,
                       replacementBody->getArguments());

  // Finally substitute the results of the new loop for the original loop.
  rewriter.replaceOp(loop.getOperation(), newLoop->getResults());
  return newLoop;
}
// Hoist out a pair of corresponding vector.extract+vector.broadcast
// operations. This function transforms a loop like this:
//  %res = scf.for _ = _ to _ step _ iter_args(%iarg = %v) -> (t1) {
//   %e = vector.extract %iarg : t1 to t2
//   %u = "some_use"(%e) : (t2) -> t2
//   %b = vector.broadcast %u : t2 to t1
//   scf.yield %b : t1
//  }
// into the following:
//  %e = vector.extract %v: t1 to t2
//  %res' = scf.for _ = _ to _ step _ iter_args(%iarg = %e) -> (t2) {
//   %u' = "some_use"(%iarg) : (t2) -> t2
//   scf.yield %u' : t2
//  }
//  %res = vector.broadcast %res' : t2 to t1
//
// The rewrite is applied repeatedly under `root` until no candidate is left.
// A vector.extract is a candidate only if all of the following hold:
//  - its parent op is an scf.for;
//  - the extracted vector is an iter_arg of that loop with exactly one use;
//  - the value yielded for that iter_arg is produced by a vector.broadcast
//    whose result has exactly one use;
//  - the broadcast source type equals the extract result type;
//  - every dynamic extract position is defined outside of the loop.
void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
                                                  Operation *root) {
  bool changed = true;
  while (changed) {
    changed = false;
    // First move loop invariant ops outside of their loop. This needs to be
    // done before as we cannot move ops without interrupting the function walk.
    root->walk(
        [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });

    root->walk([&](vector::ExtractOp extractOp) {
      LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                        << *extractOp.getOperation() << "\n");

      auto loop = dyn_cast<scf::ForOp>(extractOp->getParentOp());
      if (!loop)
        return WalkResult::advance();

      // Check that the vector to extract from is a BlockArgument.
      auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
      if (!blockArg)
        return WalkResult::advance();

      // Check that the blockArg is an iter_arg of the loop.
      OpOperand *initArg = loop.getTiedLoopInit(blockArg);
      if (!initArg)
        return WalkResult::advance();

      // If the iter_arg does not have only one use, it won't be possible to
      // hoist the extractOp out.
      if (!blockArg.hasOneUse())
        return WalkResult::advance();

      // Position of the iter_arg among the loop's init operands / results.
      unsigned index = blockArg.getArgNumber() - loop.getNumInductionVars();

      // Check that the loop yields a broadcast that has just one use.
      // The yielded value may have no defining op at all (e.g. when it is a
      // block argument). `getDefiningOp<OpTy>()` returns a null op in that
      // case, whereas a plain `dyn_cast` on the null `Operation *` would
      // assert.
      auto broadcast = loop.getTiedLoopYieldedValue(blockArg)
                           ->get()
                           .getDefiningOp<vector::BroadcastOp>();
      if (!broadcast || !broadcast.getResult().hasOneUse())
        return WalkResult::advance();

      LLVM_DEBUG(DBGS() << "Candidate broadcast: " << broadcast << "\n");

      Type broadcastInputType = broadcast.getSourceType();
      if (broadcastInputType != extractOp.getType())
        return WalkResult::advance();

      // The position of the extract must be defined outside of the loop if
      // it is dynamic.
      for (auto operand : extractOp.getDynamicPosition())
        if (!loop.isDefinedOutsideOfLoop(operand))
          return WalkResult::advance();

      // Make the extract read from the loop's init value instead of the
      // iter_arg. The op being mutated is `extractOp`, so that is the op the
      // rewriter must be notified about.
      rewriter.modifyOpInPlace(extractOp, [&] {
        extractOp.getVectorMutable().assign(initArg->get());
      });
      // Hoist the extract before the loop and sink the broadcast after it.
      loop.moveOutOfLoop(extractOp);
      rewriter.moveOpAfter(broadcast, loop);

      // Rebuild the loop so that it carries the extracted (smaller) value and
      // yields the broadcast's source directly.
      scf::ForOp newLoop = replaceWithDifferentYield(
          rewriter, loop, extractOp.getResult(), index, broadcast.getSource());

      LLVM_DEBUG(DBGS() << "New loop: " << newLoop << "\n");

      // Users of the loop result must now see the broadcast result. Only
      // afterwards is the broadcast rewired to consume the new loop result;
      // doing it in this order keeps the broadcast from being redirected to
      // its own result.
      rewriter.replaceAllUsesWith(newLoop.getResult(index), broadcast);
      rewriter.modifyOpInPlace(
          broadcast, [&] { broadcast.setOperand(newLoop.getResult(index)); });

      // The loop was replaced, which invalidates the ongoing walk: interrupt
      // and restart from `root`.
      changed = true;
      return WalkResult::interrupt();
    });
  }
}
/// Returns true if no operation nested under `loop` may conflict with
/// `transferRead` according to the following conservative approximation:
/// starting from the base of `transferRead` (looking through the chain of
/// view-like ops that define it), all users are visited transitively through
/// view-like ops. Any visited user inside `loop` that is neither memory-effect
/// free nor a vector.transfer_read makes this function return false.
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
                                LoopLikeOpInterface loop) {
  // Skip view-like ops and retrieve the actual source value.
  Value underlying = transferRead.getBase();
  for (;;) {
    auto definingView = underlying.getDefiningOp<ViewLikeOpInterface>();
    if (!definingView || definingView.getViewDest() != underlying)
      break;
    underlying = definingView.getViewSource();
  }

  // Worklist-driven traversal over all (transitive) users of the source.
  llvm::SmallVector<Operation *, 32> worklist(underlying.getUsers().begin(),
                                              underlying.getUsers().end());
  llvm::SmallDenseSet<Operation *, 32> visited;
  while (!worklist.empty()) {
    Operation *current = worklist.pop_back_val();

    // Each user is inspected at most once.
    const bool firstVisit = visited.insert(current).second;
    if (!firstVisit)
      continue;

    // View-like users are transparent: inspect the users of the view instead.
    if (auto viewUser = dyn_cast<ViewLikeOpInterface>(current)) {
      Value viewResult = viewUser.getViewDest();
      users_append:
      worklist.append(viewResult.getUsers().begin(),
                      viewResult.getUsers().end());
      continue;
    }

    // Ops without memory effects and other transfer reads never conflict.
    const bool harmless =
        isMemoryEffectFree(current) || isa<vector::TransferReadOp>(current);
    if (harmless)
      continue;

    // Anything else is only a problem when it lives inside the loop.
    if (loop->isAncestor(current))
      return false;
  }
  return true;
}
void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
bool verifyNonZeroTrip) {
bool changed = true;
while (changed) {
changed = false;
// First move loop invariant ops outside of their loop. This needs to be
// done before as we cannot move ops without interrupting the function walk.
root->walk(
[&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
// Find all loops that are certain to have non zero trip count. Any loops
// that are not part of this set cannot be hoisted from, since hoisting from
// a potentially zero trip count loop may cause a vector transfer to be
// executed when it shouldn't be.
llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
if (verifyNonZeroTrip) {
root->walk([&](LoopLikeOpInterface loopLike) {
std::optional<SmallVector<OpFoldResult>> lbs =
loopLike.getLoopLowerBounds();
std::optional<SmallVector<OpFoldResult>> ubs =
loopLike.getLoopUpperBounds();
// If loop bounds cannot be found, assume possibly zero trip count.
if (!lbs || !ubs)
return;
// Otherwise, use ValueBounds to find the maximum lower bound and
// minimum upper bound. If the bounds are found, and maxLb is less
// than the minUb, then the loop will not have zero trip count.
for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
FailureOr<int64_t> maxLb =
ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::UB, lb,
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(maxLb))
return;
FailureOr<int64_t> minUb =
ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::LB, ub);
if (failed(minUb))
return;
if (minUb.value() <= maxLb.value())
return;
definiteNonZeroTripCountLoops.insert(loopLike);
}
});
}
root->walk([&](vector::TransferReadOp transferRead) {
if (!isa<MemRefType>(transferRead.getShapedType()))
return WalkResult::advance();
LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
<< *transferRead.getOperation() << "\n");
auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
<< "\n");
if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
return WalkResult::advance();
if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
<< "\n");
return WalkResult::advance();
}
LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
<< "\n");
SetVector<Operation *> forwardSlice;
getForwardSlice(transferRead.getOperation(), &forwardSlice);
// Look for the last TransferWriteOp in the forwardSlice of
// `transferRead` that operates on the same memref.
vector::TransferWriteOp transferWrite;
for (auto *sliceOp : llvm::reverse(forwardSlice)) {
auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
if (!candidateWrite ||
candidateWrite.getBase() != transferRead.getBase())
continue;
transferWrite = candidateWrite;
}
// All operands of the TransferRead must be defined outside of the loop.
for (auto operand : transferRead.getOperands())
if (!loop.isDefinedOutsideOfLoop(operand))
return WalkResult::advance();
// Only hoist transfer_read / transfer_write pairs and singleton
// transfer_reads for now.
if (!transferWrite) {
// Make sure there are no other accesses to the memref before
// hoisting transfer_read.
if (noAliasingUseInLoop(transferRead, loop))
loop.moveOutOfLoop(transferRead);
return WalkResult::advance();
}
LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
<< "\n");
// Approximate aliasing by checking that:
// 1. indices, vector type and permutation map are the same (i.e., the
// transfer_read/transfer_write ops are matching),
// 2. source operands for transfer.{read|write} do not originate from
// nor have users that are Ops implementing ViewLikeOpInterface.
// 3. no other operations in the loop access the same memref except
// for transfer_read/transfer_write accessing statically disjoint
// slices.
// Check 1.
if (transferRead.getIndices() != transferWrite.getIndices() ||
transferRead.getVectorType() != transferWrite.getVectorType() ||
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
return WalkResult::advance();
// Check 2. Note, since both xfer Ops share the source, we only need to
// look at one of them.
auto base = transferRead.getBase();
auto *source = base.getDefiningOp();
if (source) {
// NOTE: We treat `memref.assume_alignment` as a special case.
//
// The idea is that it is safe to look past AssumeAlignmemtOp (i.e.
// MemRef _before_ alignment) iff:
// 1. It has exactly two uses (these have to be the xfer Ops
// being looked at).
// 2. The original MemRef has only one use (i.e.
// AssumeAlignmentOp).
//
// Relaxing these conditions will most likely require proper alias
// analysis.
if (auto assume = dyn_cast<memref::AssumeAlignmentOp>(source)) {
Value memPreAlignment = assume.getMemref();
auto numInLoopUses =
llvm::count_if(base.getUses(), [&loop](OpOperand &use) {
return loop->isAncestor(use.getOwner());
});
if (numInLoopUses && memPreAlignment.hasOneUse())
source = memPreAlignment.getDefiningOp();
}
if (isa_and_nonnull<ViewLikeOpInterface>(source))
return WalkResult::advance();
}
if (llvm::any_of(base.getUsers(), llvm::IsaPred<ViewLikeOpInterface>))
return WalkResult::advance();
// Check 3.
// TODO: may want to memoize this information for performance but it
// likely gets invalidated often.
DominanceInfo dom(loop);
if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
return WalkResult::advance();
for (auto &use : transferRead.getBase().getUses()) {
if (!loop->isAncestor(use.getOwner()))
continue;
if (use.getOwner() == transferRead.getOperation() ||
use.getOwner() == transferWrite.getOperation())
continue;
if (auto transferWriteUse =
dyn_cast<vector::TransferWriteOp>(use.getOwner())) {
if (!vector::isDisjointTransferSet(
cast<VectorTransferOpInterface>(*transferWrite),
cast<VectorTransferOpInterface>(*transferWriteUse),
/*testDynamicValueUsingBounds=*/true))
return WalkResult::advance();
} else if (auto transferReadUse =
dyn_cast<vector::TransferReadOp>(use.getOwner())) {
if (!vector::isDisjointTransferSet(
cast<VectorTransferOpInterface>(*transferWrite),
cast<VectorTransferOpInterface>(*transferReadUse),
/*testDynamicValueUsingBounds=*/true))
return WalkResult::advance();
} else {
// Unknown use, we cannot prove that it doesn't alias with the
// transferRead/transferWrite operations.
return WalkResult::advance();
}
}
// Hoist read before.
loop.moveOutOfLoop(transferRead);
// Hoist write after.
transferWrite->moveAfter(loop);
// Rewrite `loop` with new yields by cloning and erase the original
// loop.
IRRewriter rewriter(transferRead.getContext());
NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
ArrayRef<BlockArgument> newBBArgs) {
return SmallVector<Value>{transferWrite.getVector()};
};
auto maybeNewLoop = loop.replaceWithAdditionalYields(
rewriter, transferRead.getVector(),
/*replaceInitOperandUsesInLoop=*/true, yieldFn);
if (failed(maybeNewLoop))
return WalkResult::interrupt();
transferWrite.getValueToStoreMutable().assign(
maybeNewLoop->getOperation()->getResults().back());
changed = true;
// Need to interrupt and restart because erasing the loop messes up
// the walk.
return WalkResult::interrupt();
});
}
}