blob: d925c198526793c1faf7a9cba1eb4c6bb64c8f1c [file] [log] [blame]
//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transform allocates SME tiles at the 'func.func' op level for ArmSME
// operations. It roughly implements a linear scan register allocator, similar
// to the one outlined in [1], but with simplifications and assumptions made for
// our use case. Note that this is a greedy allocator (so it may not always find
// an optimal allocation of tiles).
//
// The allocator operates at the CF dialect level. It is the responsibility of
// users to ensure the IR has been lowered to CF before invoking the tile
// allocator.
//
// The 128-bit tiles overlap with other element tiles as follows (see section
// B2.3.2 of SME spec [2]):
//
//   Tile    Overlaps
//   ---------------------------------------------------------------------------
//   ZA0.B   ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
//           ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
//   ZA0.H   ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
//   ZA1.H   ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
//   ZA0.S   ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
//   ZA1.S   ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
//   ZA2.S   ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
//   ZA3.S   ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
//   ZA0.D   ZA0.Q, ZA8.Q
//   ZA1.D   ZA1.Q, ZA9.Q
//   ZA2.D   ZA2.Q, ZA10.Q
//   ZA3.D   ZA3.Q, ZA11.Q
//   ZA4.D   ZA4.Q, ZA12.Q
//   ZA5.D   ZA5.Q, ZA13.Q
//   ZA6.D   ZA6.Q, ZA14.Q
//   ZA7.D   ZA7.Q, ZA15.Q
//
// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
// Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
// https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
// [2] https://developer.arm.com/documentation/ddi0616/aa
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Liveness.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir::arm_sme {
#define GEN_PASS_DEF_TESTTILEALLOCATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme
using namespace mlir;
using namespace mlir::arm_sme;
namespace {
/// Bitmask of the 128-bit (`ZA<n>.Q`) tiles covered by each SME virtual tile.
/// Bit 15 stands for ZA0.Q and bit 0 for ZA15.Q, so two tiles overlap exactly
/// when their masks share a set bit.
enum class TileMask : unsigned {
  // clang-format off
  // ZA<n>.B
  kZA0B  = 0xffff, // 1111 1111 1111 1111

  // ZA<n>.H
  kZA0H  = 0xaaaa, // 1010 1010 1010 1010
  kZA1H  = 0x5555, // 0101 0101 0101 0101

  // ZA<n>.S
  kZA0S  = 0x8888, // 1000 1000 1000 1000
  kZA1S  = 0x4444, // 0100 0100 0100 0100
  kZA2S  = 0x2222, // 0010 0010 0010 0010
  kZA3S  = 0x1111, // 0001 0001 0001 0001

  // ZA<n>.D
  kZA0D  = 0x8080, // 1000 0000 1000 0000
  kZA1D  = 0x4040, // 0100 0000 0100 0000
  kZA2D  = 0x2020, // 0010 0000 0010 0000
  kZA3D  = 0x1010, // 0001 0000 0001 0000
  kZA4D  = 0x0808, // 0000 1000 0000 1000
  kZA5D  = 0x0404, // 0000 0100 0000 0100
  kZA6D  = 0x0202, // 0000 0010 0000 0010
  kZA7D  = 0x0101, // 0000 0001 0000 0001

  // ZA<n>.Q
  kZA0Q  = 0x8000, // 1000 0000 0000 0000
  kZA1Q  = 0x4000, // 0100 0000 0000 0000
  kZA2Q  = 0x2000, // 0010 0000 0000 0000
  kZA3Q  = 0x1000, // 0001 0000 0000 0000
  kZA4Q  = 0x0800, // 0000 1000 0000 0000
  kZA5Q  = 0x0400, // 0000 0100 0000 0000
  kZA6Q  = 0x0200, // 0000 0010 0000 0000
  kZA7Q  = 0x0100, // 0000 0001 0000 0000
  kZA8Q  = 0x0080, // 0000 0000 1000 0000
  kZA9Q  = 0x0040, // 0000 0000 0100 0000
  kZA10Q = 0x0020, // 0000 0000 0010 0000
  kZA11Q = 0x0010, // 0000 0000 0001 0000
  kZA12Q = 0x0008, // 0000 0000 0000 1000
  kZA13Q = 0x0004, // 0000 0000 0000 0100
  kZA14Q = 0x0002, // 0000 0000 0000 0010
  kZA15Q = 0x0001, // 0000 0000 0000 0001

  // No tile.
  kNone  = 0x0000, // 0000 0000 0000 0000
  // clang-format on

  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
};
/// Returns the masks of every tile of the given `type`. The entry at index `i`
/// is the mask of the tile with ID `i` for that type.
static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
  using Mask = TileMask;

  static constexpr std::array kMasksZAB = {Mask::kZA0B};
  static constexpr std::array kMasksZAH = {Mask::kZA0H, Mask::kZA1H};
  static constexpr std::array kMasksZAS = {Mask::kZA0S, Mask::kZA1S,
                                           Mask::kZA2S, Mask::kZA3S};
  static constexpr std::array kMasksZAD = {
      Mask::kZA0D, Mask::kZA1D, Mask::kZA2D, Mask::kZA3D,
      Mask::kZA4D, Mask::kZA5D, Mask::kZA6D, Mask::kZA7D};
  static constexpr std::array kMasksZAQ = {
      Mask::kZA0Q,  Mask::kZA1Q,  Mask::kZA2Q,  Mask::kZA3Q,
      Mask::kZA4Q,  Mask::kZA5Q,  Mask::kZA6Q,  Mask::kZA7Q,
      Mask::kZA8Q,  Mask::kZA9Q,  Mask::kZA10Q, Mask::kZA11Q,
      Mask::kZA12Q, Mask::kZA13Q, Mask::kZA14Q, Mask::kZA15Q};

  switch (type) {
  case ArmSMETileType::ZAB:
    return kMasksZAB;
  case ArmSMETileType::ZAH:
    return kMasksZAH;
  case ArmSMETileType::ZAS:
    return kMasksZAS;
  case ArmSMETileType::ZAD:
    return kMasksZAD;
  case ArmSMETileType::ZAQ:
    return kMasksZAQ;
  }
  llvm_unreachable("unknown type in getMasks");
}
/// Tracks which SME tiles are currently in use (as a `TileMask` bitset) and
/// hands out tile IDs. Also hands out IDs for in-memory tiles, which are
/// numbered from `kInMemoryTileIdBase` upwards.
class TileAllocator {
public:
  /// Allocates and returns a tile ID. Fails if there are no tiles left.
  /// The first tile of `tileType` whose mask does not intersect any tile
  /// currently in use is chosen.
  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
    auto masks = getMasks(tileType);
    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
      if ((tilesInUse & tileMask) == TileMask::kNone) {
        tilesInUse |= tileMask;
        return tileId;
      }
    }
    return failure();
  }

  /// Acquires a specific tile ID. Asserts the tile is initially free.
  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == TileMask::kNone &&
           "cannot acquire allocated tile!");
    tilesInUse |= tileMask;
  }

  /// Releases a previously allocated tile ID. Asserts the tile is currently
  /// allocated.
  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
    TileMask tileMask = getMasks(tileType)[tileId];
    assert((tilesInUse & tileMask) == tileMask &&
           "cannot release unallocated tile!");
    // Clear the bits rather than toggling them: XOR would mark the tile as
    // in use again if it is released while (partially) free, which would go
    // unnoticed in builds without assertions.
    tilesInUse &= ~tileMask;
  }

  /// Allocates an in-memory tile ID.
  unsigned allocateInMemoryTileId() {
    // Note: We never release in-memory tile IDs. We could, which may allow
    // reusing an allocation, but as we _never_ want to spill an SME tile this
    // is not optimized.
    return nextInMemoryTileId++;
  }

private:
  /// Union of the masks of all tiles currently handed out.
  TileMask tilesInUse = TileMask::kNone;
  /// Next ID returned by `allocateInMemoryTileId()`.
  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
};
/// Routes both edges of every `cf.cond_br` that has an SME tile operand
/// through a fresh intermediate block. The destination operands move onto an
/// unconditional `cf.br` in the new block, which prevents spurious liveness
/// overlaps due to copies at branches.
///
/// BEFORE:
/// ```mlir
/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
/// ```
///
/// AFTER:
/// ```mlir
///   cf.cond_br %cond, ^bb1_copy, ^bb2_copy
/// ^bb1_copy:
///   cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
/// ^bb2_copy:
///   cf.br ^bb2
/// ```
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
  auto hasTileOperand = [](cf::CondBranchOp branch) {
    for (Value operand : branch->getOperands()) {
      if (isValidSMETileVectorType(operand.getType()))
        return true;
    }
    return false;
  };

  // Collect first: the rewrite below adds blocks to the function.
  SmallVector<cf::CondBranchOp> branchesToSplit;
  function.walk([&](cf::CondBranchOp branch) {
    if (hasTileOperand(branch))
      branchesToSplit.push_back(branch);
  });

  for (cf::CondBranchOp branch : branchesToSplit) {
    Location loc = branch.getLoc();
    Block *parent = branch->getBlock();
    Block *trueHop = rewriter.splitBlock(parent, parent->end());
    Block *falseHop = rewriter.splitBlock(parent, parent->end());

    // The forwarding branches must be created while the conditional branch
    // still carries its destination operands.
    rewriter.setInsertionPointToEnd(trueHop);
    cf::BranchOp::create(rewriter, loc, branch.getTrueDest(),
                         branch.getTrueDestOperands());
    rewriter.setInsertionPointToEnd(falseHop);
    cf::BranchOp::create(rewriter, loc, branch.getFalseDest(),
                         branch.getFalseDestOperands());

    rewriter.modifyOpInPlace(branch, [&] {
      branch.getFalseDestOperandsMutable().clear();
      branch.getTrueDestOperandsMutable().clear();
      branch.setSuccessor(trueHop, 0);
      branch.setSuccessor(falseHop, 1);
    });
  }
}
/// Replaces every SME tile operand of a `cf.br` terminator with a fresh
/// `arm_sme.copy_tile` of that operand, inserted just before the branch.
///
/// BEFORE:
/// ```mlir
/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
/// ```
///
/// AFTER:
/// ```mlir
/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
/// ```
void insertCopiesAtBranches(IRRewriter &rewriter,
                            FunctionOpInterface function) {
  for (Block &block : function.getBlocks()) {
    Operation *branch = block.getTerminator();
    if (!isa<cf::BranchOp>(branch))
      continue;

    rewriter.setInsertionPoint(branch);
    for (OpOperand &use : branch->getOpOperands()) {
      Value tile = use.get();
      if (!isValidSMETileVectorType(tile.getType()))
        continue;
      auto tileCopy = CopyTileOp::create(rewriter, branch->getLoc(), tile);
      rewriter.modifyOpInPlace(branch, [&] { use.assign(tileCopy); });
    }
  }
}
/// Gets the IR of `function` ready for tile allocation in two steps:
///
///  1. `splitCondBranches`: gives each edge of a tile-carrying `cf.cond_br` its
///     own block, so the copies for the true and false paths cannot overlap
///     (see `tile-allocation-copies.mlir` and `tile-allocation-liveness.mlir`
///     for examples).
///  2. `insertCopiesAtBranches`: copies tile operands at `cf.br` operations.
///     The copies break up live ranges and ensure the semantics of the program
///     are preserved when moving out of SSA.
void preprocessForTileAllocation(IRRewriter &rewriter,
                                 FunctionOpInterface function) {
  // The split has to happen first so the copies land in the new blocks.
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}
/// A live range for a (collection of) tile values. A live range is built up of
/// non-overlapping intervals [start, end) which represent parts of the program
/// where a value in the range needs to be live (i.e. in an SME virtual tile).
/// Note that as the intervals are non-overlapping all values within a live
/// range can be allocated to the same SME virtual tile.
struct LiveRange {
using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
llvm::IntervalMapHalfOpenInfo<unsigned>>;
using Allocator = RangeSet::Allocator;
// Dummy value for the IntervalMap. Only the keys matter (the intervals).
static constexpr uint8_t kValidLiveRange = 0xff;
LiveRange(Allocator &allocator)
: ranges(std::make_unique<RangeSet>(allocator)) {}
/// Returns true if this range overlaps with `otherRange`.
bool overlaps(LiveRange const &otherRange) const {
return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
*otherRange.ranges)
.valid();
}
/// Returns true if this range is active at `point` in the program.
bool overlaps(uint64_t point) const {
return ranges->lookup(point) == kValidLiveRange;
}
/// Unions this live range with `otherRange`, aborts if the ranges overlap.
void unionWith(LiveRange const &otherRange) {
for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
++it)
ranges->insert(it.start(), it.stop(), kValidLiveRange);
values.set_union(otherRange.values);
}
/// Inserts an interval [start, end) for `value` into this range.
void insert(Value value, unsigned start, unsigned end) {
values.insert(value);
if (start != end)
ranges->insert(start, end, kValidLiveRange);
}
bool empty() const { return ranges->empty(); }
unsigned start() const { return ranges->start(); }
unsigned end() const { return ranges->stop(); }
bool operator<(LiveRange const &other) const {
return start() < other.start();
}
ArmSMETileType getTileType() const {
return *getSMETileType(cast<VectorType>(values[0].getType()));
}
/// The values contained in this live range.
SetVector<Value> values;
/// A set of (non-overlapping) intervals that mark where any value in `values`
/// is live.
std::unique_ptr<RangeSet> ranges;
/// The tile ID (or none) assigned to this live range.
std::optional<unsigned> tileId;
};
/// Assigns an index to every top-level operation of `function` so live ranges
/// can be expressed as index intervals. Operations are numbered consecutively
/// within blocks, blocks are visited in the order produced by
/// `getBlocksSortedByDominance`, and one index is skipped at the start of each
/// block so block arguments get a number of their own. This is only correct if
/// no ArmSME tile operation sits inside a nested region, i.e. the IR has been
/// lowered to CF (which is asserted).
DenseMap<Operation *, unsigned>
generateOperationNumbering(FunctionOpInterface function) {
  DenseMap<Operation *, unsigned> operationToIndexMap;
  unsigned nextIndex = 0;

  for (Block *block : getBlocksSortedByDominance(function.getFunctionBody())) {
    // Reserve an index for the block arguments.
    ++nextIndex;
    for (Operation &op : block->getOperations()) {
#ifndef NDEBUG
      op.walk([&](ArmSMETileOpInterface nestedOp) {
        assert(&op == nestedOp.getOperation() &&
               "ArmSME tile allocation does not support nested regions");
      });
#endif
      operationToIndexMap.try_emplace(&op, nextIndex);
      ++nextIndex;
    }
  }
  return operationToIndexMap;
}
/// Builds one live range per SME tile value of `function` from the MLIR
/// liveness analysis. Intervals are expressed using the indices in
/// `operationToIndexMap` (which must not be empty).
DenseMap<Value, LiveRange>
gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                     LiveRange::Allocator &liveRangeAllocator,
                     Liveness &liveness, FunctionOpInterface function) {
  assert(!operationToIndexMap.empty() && "expected operation numbering");
  DenseMap<Value, LiveRange> liveRanges;

  // Adds the part of `tile`'s lifetime that lies within one block to its live
  // range, creating the range on first sight (live-ins extend a range that
  // already exists). `firstOp` is the definition of `tile`, or the first
  // operation of the block when `liveAtBlockEntry` is set. Values that are
  // not SME tiles are ignored.
  auto recordBlockInterval = [&](Value tile, Operation *firstOp,
                                 LivenessBlockInfo const &blockInfo,
                                 bool liveAtBlockEntry) {
    if (!isValidSMETileVectorType(tile.getType()))
      return;

    LiveRange &range =
        liveRanges.try_emplace(tile, liveRangeAllocator).first->second;
    Operation *lastOp = blockInfo.getEndOperation(tile, firstOp);

    // The interval is [firstOp, lastOp). Values live at block entry start one
    // index earlier, at the number reserved for the block arguments.
    unsigned intervalStart = operationToIndexMap.at(firstOp);
    if (liveAtBlockEntry)
      --intervalStart;
    unsigned intervalEnd = operationToIndexMap.at(lastOp);
    range.insert(tile, intervalStart, intervalEnd);
  };

  for (Block &block : function.getBlocks()) {
    LivenessBlockInfo const *blockInfo = liveness.getLiveness(&block);

    // Block arguments.
    for (Value argument : block.getArguments())
      recordBlockInterval(argument, &block.front(), *blockInfo,
                          /*liveAtBlockEntry=*/true);

    // Values live on entry to the block.
    for (Value liveIn : blockInfo->in())
      recordBlockInterval(liveIn, &block.front(), *blockInfo,
                          /*liveAtBlockEntry=*/true);

    // Values defined inside the block.
    for (Operation &op : block) {
      for (Value result : op.getResults())
        recordBlockInterval(result, &op, *blockInfo,
                            /*liveAtBlockEntry=*/false);
    }
  }
  return liveRanges;
}
/// Invokes `callback` with each value that predecessors of `blockArg`'s block
/// pass for `blockArg`. Only `cf.br` and `cf.cond_br` terminators are
/// inspected; any other terminator is skipped.
static void forEachPredecessorTileValue(BlockArgument blockArg,
                                        function_ref<void(Value)> callback) {
  Block *target = blockArg.getOwner();
  unsigned argIndex = blockArg.getArgNumber();

  for (Block *pred : target->getPredecessors()) {
    Operation *terminator = pred->getTerminator();

    if (auto branch = dyn_cast<cf::BranchOp>(terminator)) {
      callback(branch.getDestOperands()[argIndex]);
      continue;
    }

    if (auto condBranch = dyn_cast<cf::CondBranchOp>(terminator)) {
      // Both edges may lead to `target`; report the operand of each one.
      if (condBranch.getFalseDest() == target)
        callback(condBranch.getFalseDestOperands()[argIndex]);
      if (condBranch.getTrueDest() == target)
        callback(condBranch.getTrueDestOperands()[argIndex]);
    }
  }
}
/// Merges live ranges where doing so avoids unnecessary tile moves: block
/// arguments are merged with the values passed by their predecessors, and
/// results of ArmSME tile ops with their tile operands. Ranges are only merged
/// when they do not overlap. Returns the resulting non-empty live ranges
/// (pointers into `initialLiveRanges`), sorted by starting point.
SmallVector<LiveRange *>
coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
  // The live range each value currently belongs to.
  DenseMap<Value, LiveRange *> rangeOfValue;
  for (auto &entry : initialLiveRanges)
    rangeOfValue.insert({entry.first, &entry.second});

  // Folds the live range of `b` into the live range of `a`, unless they are
  // already the same range or they overlap. Afterwards all values of `b`'s
  // range point at `a`'s range.
  auto tryMerge = [&](Value a, Value b) {
    LiveRange *into = rangeOfValue.at(a);
    LiveRange *from = rangeOfValue.at(b);
    if (into == from || into->overlaps(*from))
      return;
    into->unionWith(*from);
    for (Value moved : from->values)
      rangeOfValue[moved] = into;
  };

  // Rule: a block argument shares a range with its predecessor values.
  auto mergeWithPredecessors = [&](Value value) {
    if (auto blockArg = dyn_cast<BlockArgument>(value)) {
      forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
        tryMerge(blockArg, predecessorTile);
      });
    }
  };

  // Rule: the result of an ArmSME tile op shares a range with its tile
  // operands.
  auto mergeWithTileOperands = [&](Value value) {
    auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
    if (!tileOp)
      return;
    for (Value operand : tileOp->getOperands()) {
      if (!isValidSMETileVectorType(operand.getType()))
        continue;
      tryMerge(value, operand);
    }
  };

  // Unify as many live ranges as possible; block arguments go first.
  for (auto &entry : initialLiveRanges)
    mergeWithPredecessors(entry.first);
  for (auto &entry : initialLiveRanges)
    mergeWithTileOperands(entry.first);

  // Several values may now share a range: keep each non-empty range once.
  SetVector<LiveRange *> uniqueRanges;
  for (auto &entry : rangeOfValue) {
    LiveRange *range = entry.second;
    if (range->empty())
      continue;
    uniqueRanges.insert(range);
  }

  // Order by starting point, as expected by the tile allocation.
  auto sortedRanges = uniqueRanges.takeVector();
  llvm::sort(sortedRanges, [](LiveRange *lhs, LiveRange *rhs) {
    return lhs->start() < rhs->start();
  });
  return std::move(sortedRanges);
}
/// Picks the live range to spill: either one of `overlappingRanges` or the new
/// live range `newRange`. `overlappingRanges` must not be empty.
///
/// Preference order:
///  1. `newRange`, if spilling it is trivial.
///  2. The first overlapping range that is trivial to spill.
///  3. The overlapping range that ends last (with a compatible tile type),
///     unless `newRange` ranks after it, in which case `newRange` is spilled.
template <typename OverlappingRangesIterator>
LiveRange *
chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
                           LiveRange *newRange) {
  // A spill is trivial if the range can hold the new range's tile type and
  // consists of a single value defined by a trivially cloneable tile op
  // (usually free to spill).
  auto isTriviallySpillable = [&](LiveRange &candidate) {
    if (!isTileTypeGreaterOrEqual(candidate.getTileType(),
                                  newRange->getTileType()))
      return false;
    if (candidate.values.size() != 1)
      return false;
    return isTriviallyCloneableTileOp(
        candidate.values[0].getDefiningOp<ArmSMETileOpInterface>());
  };

  if (isTriviallySpillable(*newRange))
    return newRange;
  for (LiveRange &candidate : overlappingRanges) {
    if (isTriviallySpillable(candidate))
      return &candidate;
  }

  // Ranks `lhs` before `rhs` if it has a smaller tile type or ends earlier.
  auto ranksBefore = [](LiveRange &lhs, LiveRange &rhs) {
    if (!isTileTypeGreaterOrEqual(lhs.getTileType(), rhs.getTileType()))
      return true;
    return lhs.end() < rhs.end();
  };

  LiveRange &lastEnding = *llvm::max_element(overlappingRanges, ranksBefore);
  if (ranksBefore(lastEnding, *newRange))
    return newRange;
  return &lastEnding;
}
/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
///
/// The ranges in `liveRangesSortedByStartPoint` are visited in the given order
/// and the start of each range is used as the current program point, so the
/// input is expected to be sorted by starting point. After this function every
/// visited live range has a `tileId`: either an ID returned by
/// `TileAllocator::allocateTileId`, or (if the range was chosen to be spilled)
/// an ID returned by `TileAllocator::allocateInMemoryTileId`.
void allocateTilesToLiveRanges(
ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
TileAllocator tileAllocator;
// `activeRanges` = Live ranges that need to be in a tile at the
// `currentPoint` in the program.
SetVector<LiveRange *> activeRanges;
// `inactiveRanges` = Live ranges that _do not_ need to be in a tile
// at the `currentPoint` in the program but could become active again later.
// An inactive section of a live range can be seen as a 'hole' in the live
// range, where it is possible to reuse the live range's tile ID _before_ it
// has ended. By identifying 'holes', the allocator can reuse tiles more
// often, which helps avoid costly tile spills.
SetVector<LiveRange *> inactiveRanges;
for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
// The program point at which `nextRange` needs a tile.
auto currentPoint = nextRange->start();
// 1. Update the `activeRanges` at `currentPoint`.
activeRanges.remove_if([&](LiveRange *activeRange) {
// Check for live ranges that have expired.
if (activeRange->end() <= currentPoint) {
tileAllocator.releaseTileId(activeRange->getTileType(),
*activeRange->tileId);
return true;
}
// Check for live ranges that have become inactive.
// The tile is released for the duration of the hole, but the range keeps
// its `tileId` so the tile can be re-acquired in step 2 (or reserved in
// step 3).
if (!activeRange->overlaps(currentPoint)) {
tileAllocator.releaseTileId(activeRange->getTileType(),
*activeRange->tileId);
inactiveRanges.insert(activeRange);
return true;
}
return false;
});
// 2. Update the `inactiveRanges` at `currentPoint`.
inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
// Check for live ranges that have expired.
// (Nothing to release: the tile was already released when the range
// became inactive.)
if (inactiveRange->end() <= currentPoint) {
return true;
}
// Check for live ranges that have become active.
if (inactiveRange->overlaps(currentPoint)) {
tileAllocator.acquireTileId(inactiveRange->getTileType(),
*inactiveRange->tileId);
activeRanges.insert(inactiveRange);
return true;
}
return false;
});
// 3. Collect inactive live ranges that overlap with the new live range.
// Note: The overlap checks in steps 1 and 2 only look at the `currentPoint`
// whereas this checks if there is an overlap at any future point too.
SmallVector<LiveRange *> overlappingInactiveRanges;
for (LiveRange *inactiveRange : inactiveRanges) {
if (inactiveRange->overlaps(*nextRange)) {
// We need to reserve the tile IDs of overlapping inactive ranges to
// prevent two (overlapping) live ranges from getting the same tile ID.
// The reservation is temporary; it is undone in step 6.
tileAllocator.acquireTileId(inactiveRange->getTileType(),
*inactiveRange->tileId);
overlappingInactiveRanges.push_back(inactiveRange);
}
}
// 4. Allocate a tile ID to `nextRange`.
auto rangeTileType = nextRange->getTileType();
auto tileId = tileAllocator.allocateTileId(rangeTileType);
if (succeeded(tileId)) {
nextRange->tileId = *tileId;
} else {
// No free tile: one of the overlapping ranges, or `nextRange` itself, has
// to be spilled.
// Create an iterator over all overlapping live ranges.
auto allOverlappingRanges = llvm::concat<LiveRange>(
llvm::make_pointee_range(activeRanges.getArrayRef()),
llvm::make_pointee_range(overlappingInactiveRanges));
// Choose an overlapping live range to spill.
LiveRange *rangeToSpill =
chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
if (rangeToSpill != nextRange) {
// Spill an (in)active live range (so release its tile ID first).
tileAllocator.releaseTileId(rangeToSpill->getTileType(),
*rangeToSpill->tileId);
// This will always succeed after a spill (of an active live range).
nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
// Remove the live range from the active/inactive sets.
if (!activeRanges.remove(rangeToSpill)) {
bool removed = inactiveRanges.remove(rangeToSpill);
assert(removed && "expected a range to be removed!");
(void)removed;
}
}
// The spilled range (possibly `nextRange` itself) gets an in-memory tile
// ID, replacing any tile ID it held before.
rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
}
// 5. Insert the live range into the active ranges.
// A range that was itself spilled (in-memory tile ID) does not become active.
// NOTE(review): this compares the `std::optional` itself, whereas step 6
// dereferences it; `tileId` has been assigned on every path through step 4.
if (nextRange->tileId < kInMemoryTileIdBase)
activeRanges.insert(nextRange);
// 6. Release tiles reserved for inactive live ranges (in step 3).
for (LiveRange *range : overlappingInactiveRanges) {
// Skip a range that was spilled in step 4: its tile was already released
// there and its `tileId` is now an in-memory ID.
if (*range->tileId < kInMemoryTileIdBase)
tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
}
}
}
/// Stamps `tileIdAttr` onto the ArmSME tile op defining `value` (if any) and
/// onto every ArmSME tile op using `value` that has no tile result.
void assignTileIdToValue(IRRewriter &rewriter, Value value,
                         IntegerAttr tileIdAttr) {
  auto stampTileId = [&](ArmSMETileOpInterface op) {
    rewriter.modifyOpInPlace(op, [&] { op.setTileId(tileIdAttr); });
  };

  if (auto definingOp = value.getDefiningOp<ArmSMETileOpInterface>())
    stampTileId(definingOp);

  for (Operation *user : value.getUsers()) {
    auto userOp = dyn_cast<ArmSMETileOpInterface>(user);
    // Users that produce a tile get their ID via their own result value; only
    // users without a tile result need the ID of their operand.
    if (userOp && !hasTileResult(userOp))
      stampTileId(userOp);
  }
}
/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
///
/// For every value of every live range in `allocatedLiveRanges` (each of which
/// must have a `tileId`), this:
///  1. sets the tile ID on the ops associated with the value,
///  2. folds `arm_sme.copy_tile` ops whose operand is on the same tile,
///  3. checks block arguments are on the same tile as their predecessor values,
///  4. clones trivially cloneable operands that ended up on a different tile.
///
/// Returns failure (after emitting an error) if a block argument and one of its
/// predecessor values are on different tiles, or if a tile operand is on a
/// different tile than its user and is not trivially cloneable. Note that
/// `function` is not referenced in the body.
LogicalResult assignTileIdsAndResolveTrivialConflicts(
IRRewriter &rewriter, FunctionOpInterface function,
ArrayRef<LiveRange *> allocatedLiveRanges) {
for (LiveRange const *liveRange : allocatedLiveRanges) {
auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
/// Returns true if `value` is defined by a tile op that already carries this
/// live range's tile ID, or if `value` belongs to this live range.
auto isAllocatedToSameTile = [&](Value value) {
if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
tileOp && tileOp.getTileId() == tileIdAttr)
return true;
return liveRange->values.contains(value);
};
/// Eliminates copies where the operand has the same tile ID.
/// Returns success if `value` was a copy and its uses were redirected to the
/// copied tile. The copy op itself is not erased here.
auto foldRedundantCopies = [&](Value value) -> LogicalResult {
auto copyOp = value.getDefiningOp<CopyTileOp>();
if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
return failure();
rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
return success();
};
/// Validates each predecessor to a tile block argument has been assigned
/// the same tile ID.
auto validateBlockArguments = [&](Value value) {
auto blockArg = dyn_cast<BlockArgument>(value);
if (!blockArg) {
// Not a block argument (nothing to validate).
return success();
}
bool tileMismatch = false;
forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
// Only report the first mismatch.
if (tileMismatch)
return;
if (!isAllocatedToSameTile(predecessorTile)) {
blockArg.getOwner()->getParentOp()->emitOpError(
"block argument not allocated to the same SME virtial tile as "
"predecessors");
tileMismatch = true;
}
});
return success(/*isSuccess=*/!tileMismatch);
};
/// Attempts to resolve (trivial) tile ID conflicts.
/// A conflict exists when the tile operand of the op defining `value` is not
/// allocated to this live range's tile.
auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
// `tileOp` is null if `value` is not defined by an ArmSME tile op (e.g. a
// block argument).
// NOTE(review): this relies on `getTileOpOperand` accepting a null op and
// returning null -- confirm.
auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
OpOperand *tileOperand = getTileOpOperand(tileOp);
if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
// Operand already allocated to the correct tile.
// No conflict to resolve.
return success();
}
auto operandTileOp =
tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
if (!isTriviallyCloneableTileOp(operandTileOp)) {
auto error =
tileOp.emitOpError("tile operand allocated to different SME "
"virtial tile (move required)");
error.attachNote(tileOperand->get().getLoc())
<< "tile operand is: " << tileOperand->get();
return error;
}
// Cloning prevents a move/spill (though may require recomputation).
// The clone is placed right before `tileOp` and given `tileOp`'s tile ID.
rewriter.setInsertionPoint(tileOp);
auto clonedOp = operandTileOp.clone();
rewriter.modifyOpInPlace(clonedOp,
[&] { clonedOp.setTileId(tileOp.getTileId()); });
rewriter.insert(clonedOp);
if (isa<CopyTileOp>(tileOp)) {
// A copy of the clone would be redundant: use the clone directly.
rewriter.replaceAllUsesWith(tileOp->getResult(0),
clonedOp->getResult(0));
} else {
rewriter.modifyOpInPlace(
tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
}
return success();
};
for (Value value : liveRange->values) {
// 1. Assign the tile ID to the value.
assignTileIdToValue(rewriter, value, tileIdAttr);
// 2. Attempt to eliminate redundant tile copies.
// (A folded copy needs no further processing.)
if (succeeded(foldRedundantCopies(value)))
continue;
// 3. Validate tile block arguments.
if (failed(validateBlockArguments(value)))
return failure();
// 4. Attempt to resolve (trivial) tile ID conflicts.
if (failed(resolveTrivialTileConflicts(value)))
return failure();
}
}
return success();
}
/// Debug helper: prints to stderr one line per operation of `function`, with
/// one column per live range in `liveRanges` showing whether that range
/// starts (S), ends (E), or is live (|) at the operation, followed by the
/// operation name.
void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                    ArrayRef<LiveRange const *> liveRanges,
                    FunctionOpInterface function) {
  // Computes the column character of `range` at operation index `opIndex`.
  auto markerAt = [](LiveRange const &range, unsigned opIndex) {
    char marker = ' ';
    for (auto interval = range.ranges->begin(), last = range.ranges->end();
         interval != last; ++interval) {
      if (interval.start() == opIndex) {
        // An interval ending and another starting at the same index reads as
        // continuously live.
        marker = (marker == 'E') ? '|' : 'S';
      } else if (interval.stop() == opIndex) {
        marker = (marker == 'S') ? '|' : 'E';
      } else if (opIndex >= interval.start() && opIndex < interval.stop()) {
        marker = '|';
      }
    }
    return marker;
  };

  llvm::raw_ostream &os = llvm::errs();
  os << "SME Tile Liveness: @" << function.getName()
     << "\nKey:\nS - Start\nE - End\n| - Live\n";

  size_t blockIdx = 0;
  for (Block &block : function.getBlocks()) {
    os << "^bb" << blockIdx << ":\n";
    for (Operation &op : block.getOperations()) {
      unsigned opIndex = operationToIndexMap.at(&op);
      for (LiveRange const *range : liveRanges)
        os << markerAt(*range, opIndex);
      os << ' ' << op.getName() << '\n';
    }
    ++blockIdx;
  }
  os << "==========\n";
}
/// Test pass driving the tile allocator on a function. With `preprocessOnly`
/// set, only the IR preprocessing is run; otherwise the full allocation runs
/// (dumping live ranges if `dumpTileLiveRanges` is set) and the pass fails if
/// the allocation fails.
struct TestTileAllocationPass
    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
  using TestTileAllocationBase::TestTileAllocationBase;

  void runOnOperation() override {
    FunctionOpInterface function = getOperation();

    if (preprocessOnly) {
      IRRewriter rewriter(function);
      preprocessForTileAllocation(rewriter, function);
      return;
    }

    LogicalResult result =
        arm_sme::allocateSMETiles(function, dumpTileLiveRanges);
    if (failed(result))
      signalPassFailure();
  }
};
} // namespace
/// Allocates SME tiles for all ArmSME tile values in `function` and writes the
/// chosen tile IDs back to the operations. If `dumpRanges` is set, the initial
/// and the coalesced live ranges are printed to stderr. Succeeds trivially for
/// functions without a body or without tile values; fails if assigning the
/// tile IDs back to the IR fails.
LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                              bool dumpRanges) {
  // TODO: Also return early if the function contains no ArmSME ops?
  if (function.empty())
    return success();

  LiveRange::Allocator liveRangeAllocator;
  IRRewriter rewriter(function.getContext());

  // Step 1: preprocess the IR for tile allocation.
  preprocessForTileAllocation(rewriter, function);

  // Step 2: gather a live range for each ArmSME tile value in the function.
  Liveness liveness(function);
  auto operationToIndexMap = generateOperationNumbering(function);
  auto initialLiveRanges = gatherTileLiveRanges(
      operationToIndexMap, liveRangeAllocator, liveness, function);
  if (initialLiveRanges.empty())
    return success();

  if (dumpRanges) {
    // Printing wants the non-empty ranges, ordered by starting point.
    SmallVector<LiveRange const *> printableRanges;
    for (auto const &entry : initialLiveRanges) {
      if (!entry.second.empty())
        printableRanges.push_back(&entry.second);
    }
    llvm::sort(printableRanges, [](LiveRange const *lhs, LiveRange const *rhs) {
      return lhs->start() < rhs->start();
    });
    llvm::errs() << "\n========== Initial Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, printableRanges, function);
  }

  // Step 3: coalesce (non-overlapping) live ranges where that helps the
  // allocation, e.g. unify the result of an operation with its operands.
  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
  if (dumpRanges) {
    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
  }

  // Step 4: allocate tile IDs to the live ranges.
  allocateTilesToLiveRanges(coalescedLiveRanges);

  // Step 5: assign the tile IDs back to the ArmSME operations.
  LogicalResult assigned = assignTileIdsAndResolveTrivialConflicts(
      rewriter, function, coalescedLiveRanges);
  if (failed(assigned))
    return failure();

  // Step 6: erase trivially dead tile operations (e.g. a ZeroOp with no
  // users), so the LLVM conversion does not needlessly insert spills.
  eraseTriviallyDeadTileOps(rewriter, function);
  return success();
}