| //===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This pass allocates SME tiles at the 'func.func' op level for ArmSME |
| // operations. It does this using a 16-bit tile mask that has a bit for each |
| // 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule. |
| // |
| // The 128-bit tiles overlap with other element tiles as follows (see section |
| // B2.3.2 of SME spec [1]): |
| // |
| // Tile Overlaps |
| // --------------------------------------------------------------------------- |
| // ZA0.B ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q, |
| // ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q |
| // ZA0.H ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q |
| // ZA1.H ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q |
| // ZA0.S ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q |
| // ZA1.S ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q |
| // ZA2.S ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q |
| // ZA3.S ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q |
| // ZA0.D ZA0.Q, ZA8.Q |
| // ZA1.D ZA1.Q, ZA9.Q |
| // ZA2.D ZA2.Q, ZA10.Q |
| // ZA3.D ZA3.Q, ZA11.Q |
| // ZA4.D ZA4.Q, ZA12.Q |
| // ZA5.D ZA5.Q, ZA13.Q |
| // ZA6.D ZA6.Q, ZA14.Q |
| // ZA7.D ZA7.Q, ZA15.Q |
| // |
| // The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use' |
| // that is initialized during the first tile allocation within a function and |
| // updated on each subsequent allocation. |
| // |
| // [1] https://developer.arm.com/documentation/ddi0616/aa |
| // |
| //===----------------------------------------------------------------------===// |
| |
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
| |
| #define DEBUG_TYPE "allocate-arm-sme-tiles" |
| |
| namespace mlir { |
| namespace arm_sme { |
| #define GEN_PASS_DEF_TILEALLOCATION |
| #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" |
| } // namespace arm_sme |
| } // namespace mlir |
| |
| using namespace mlir; |
| using namespace mlir::arm_sme; |
| |
| namespace { |
| |
| static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use"); |
| static constexpr StringLiteral |
| kNextInMemoryTileIdAttr("arm_sme.next_in_memory_tile_id"); |
| |
/// Bitmask over the sixteen 128-bit ZA granules. Bit 15 stands for ZA0.Q and
/// bit 0 for ZA15.Q; the mask of every wider element tile is the union of the
/// masks of the narrower tiles it overlaps.
enum class TileMask : unsigned {
  // clang-format off
  // 128-bit element tiles: exactly one granule each.
  kZA0Q  = 1u << 15, // 1000 0000 0000 0000
  kZA1Q  = 1u << 14, // 0100 0000 0000 0000
  kZA2Q  = 1u << 13, // 0010 0000 0000 0000
  kZA3Q  = 1u << 12, // 0001 0000 0000 0000
  kZA4Q  = 1u << 11, // 0000 1000 0000 0000
  kZA5Q  = 1u << 10, // 0000 0100 0000 0000
  kZA6Q  = 1u << 9,  // 0000 0010 0000 0000
  kZA7Q  = 1u << 8,  // 0000 0001 0000 0000
  kZA8Q  = 1u << 7,  // 0000 0000 1000 0000
  kZA9Q  = 1u << 6,  // 0000 0000 0100 0000
  kZA10Q = 1u << 5,  // 0000 0000 0010 0000
  kZA11Q = 1u << 4,  // 0000 0000 0001 0000
  kZA12Q = 1u << 3,  // 0000 0000 0000 1000
  kZA13Q = 1u << 2,  // 0000 0000 0000 0100
  kZA14Q = 1u << 1,  // 0000 0000 0000 0010
  kZA15Q = 1u << 0,  // 0000 0000 0000 0001

  // 64-bit element tiles: two granules each.
  kZA0D = kZA0Q | kZA8Q,  // 1000 0000 1000 0000
  kZA1D = kZA1Q | kZA9Q,  // 0100 0000 0100 0000
  kZA2D = kZA2Q | kZA10Q, // 0010 0000 0010 0000
  kZA3D = kZA3Q | kZA11Q, // 0001 0000 0001 0000
  kZA4D = kZA4Q | kZA12Q, // 0000 1000 0000 1000
  kZA5D = kZA5Q | kZA13Q, // 0000 0100 0000 0100
  kZA6D = kZA6Q | kZA14Q, // 0000 0010 0000 0010
  kZA7D = kZA7Q | kZA15Q, // 0000 0001 0000 0001

  // 32-bit element tiles: four granules each.
  kZA0S = kZA0D | kZA4D, // 1000 1000 1000 1000
  kZA1S = kZA1D | kZA5D, // 0100 0100 0100 0100
  kZA2S = kZA2D | kZA6D, // 0010 0010 0010 0010
  kZA3S = kZA3D | kZA7D, // 0001 0001 0001 0001

  // 16-bit element tiles: eight granules each.
  kZA0H = kZA0S | kZA2S, // 1010 1010 1010 1010
  kZA1H = kZA1S | kZA3S, // 0101 0101 0101 0101

  // The single 8-bit element tile covers all of ZA.
  kZA0B = kZA0H | kZA1H, // 1111 1111 1111 1111

  kNone = 0u,            // 0000 0000 0000 0000
  // clang-format on

  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
};
| |
| /// Returns the set of masks relevant for the given type. |
| static ArrayRef<TileMask> getMasks(ArmSMETileType type) { |
| static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B}; |
| static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H}; |
| static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S, |
| TileMask::kZA2S, TileMask::kZA3S}; |
| static constexpr std::array ZA_D_MASKS = { |
| TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D, |
| TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D}; |
| static constexpr std::array ZA_Q_MASKS = { |
| TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q, |
| TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q, |
| TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q, |
| TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q}; |
| switch (type) { |
| case ArmSMETileType::ZAB: |
| return ZA_B_MASKS; |
| case ArmSMETileType::ZAH: |
| return ZA_H_MASKS; |
| case ArmSMETileType::ZAS: |
| return ZA_S_MASKS; |
| case ArmSMETileType::ZAD: |
| return ZA_D_MASKS; |
| case ArmSMETileType::ZAQ: |
| return ZA_Q_MASKS; |
| } |
| } |
| |
| /// Allocates and returns a tile ID. Returns an error if there are no tiles |
| /// left. |
| static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType, |
| TileMask &tilesInUse) { |
| auto masks = getMasks(tileType); |
| for (auto [tileId, tileMask] : llvm::enumerate(masks)) { |
| if ((tilesInUse & tileMask) == TileMask::kNone) { |
| tilesInUse |= tileMask; |
| return tileId; |
| } |
| } |
| return failure(); |
| } |
| |
| /// Collects transitive uses of a root value through control flow. This can |
| /// handle basic SCF constructs, along with control flow (br and cond_br). |
| /// Simple loops work at the SCF level, while more complex control flow can be |
| /// dealt with after lowering to CF. This is used to implement basic tile |
| /// allocation. |
| static void findDependantOps(Value rootValue, |
| SetVector<Operation *> &dependantOps) { |
| auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) { |
| for (auto [idx, value] : llvm::enumerate(inputValues)) { |
| if (value == rootValue) |
| findDependantOps(exitValues[idx], dependantOps); |
| } |
| }; |
| for (Operation *user : rootValue.getUsers()) { |
| if (dependantOps.contains(user)) |
| continue; |
| dependantOps.insert(user); |
| TypeSwitch<Operation *>(user) |
| .Case<cf::BranchOp>([&](auto branchOp) { |
| // (CF) Follow branch. |
| traverseCorrespondingValues(branchOp.getDestOperands(), |
| branchOp.getDest()->getArguments()); |
| }) |
| .Case<cf::CondBranchOp>([&](auto condBranchOp) { |
| // (CF) Follow true branch. |
| traverseCorrespondingValues( |
| condBranchOp.getTrueOperands(), |
| condBranchOp.getTrueDest()->getArguments()); |
| // (CF) Follow false branch. |
| traverseCorrespondingValues( |
| condBranchOp.getFalseOperands(), |
| condBranchOp.getFalseDest()->getArguments()); |
| }) |
| .Case<LoopLikeOpInterface>([&](auto loopOp) { |
| // (SCF) Follow iter_args of (basic) loops (e.g. for loops). |
| traverseCorrespondingValues(loopOp.getInits(), |
| loopOp.getRegionIterArgs()); |
| }) |
| .Case<scf::YieldOp>([&](auto yieldOp) { |
| // (SCF) Follow yields of (basic) control flow (e.g. for loops). |
| auto parent = user->getParentOp(); |
| traverseCorrespondingValues(user->getOperands(), |
| parent->getResults()); |
| }) |
| .Default([&](auto) { |
| // Otherwise, assume users of _any_ result are dependant. |
| for (Value result : user->getResults()) |
| findDependantOps(result, dependantOps); |
| }); |
| } |
| } |
| struct AssignTileIDsPattern |
| : public OpInterfaceRewritePattern<ArmSMETileOpInterface> { |
| using OpInterfaceRewritePattern::OpInterfaceRewritePattern; |
| LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp, |
| PatternRewriter &rewriter) const override { |
| if (tileOp.getTileId()) |
| return failure(); |
| |
| auto func = tileOp->getParentOfType<FunctionOpInterface>(); |
| auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) { |
| if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>( |
| func->getDiscardableAttr(name))) |
| return unsigned(attr.getInt()); |
| return defaultVal; |
| }; |
| auto setDiscardableIntAttr = [&](StringRef name, auto value) { |
| rewriter.modifyOpInPlace(tileOp, [&] { |
| func->setDiscardableAttr(name, |
| rewriter.getI32IntegerAttr((unsigned)value)); |
| }); |
| }; |
| |
| std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType(); |
| if (!tileType) |
| return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile"); |
| |
| TileMask tilesInUse = |
| static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr)); |
| auto tileId = allocateTileId(*tileType, tilesInUse); |
| bool tileIsInMemory = failed(tileId); |
| if (tileIsInMemory) { |
| // If we could not find a real tile ID, use an in-memory tile ID (ID >= |
| // 16). A later pass will insert the necessary spills and reloads. |
| tileId = |
| getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase); |
| tileOp->emitWarning( |
| "failed to allocate SME virtual tile to operation, all tile " |
| "operations will go through memory, expect degraded performance"); |
| } |
| |
| // Set all operations dependent on `tileOp` to use the same tile ID. |
| // This is a naive tile allocation scheme, but works for common cases. For |
| // example, as this only allocates tile IDs to existing ops, it can't solve |
| // cases like this (%tileA and %tileB come from different root operations): |
| // |
| // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> { |
| // scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32> |
| // } else { |
| // scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32> |
| // } |
| // |
| // This case would require allocating a new tile for the result of the |
| // scf.if, and moving the contents of %tileA or %tileB to result tile (based |
| // on the %some_cond). |
| // Find all the ops that (transitively) depend on this tile. |
| SetVector<Operation *> dependantOps; |
| findDependantOps(tileOp->getResult(0), dependantOps); |
| auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId); |
| for (auto *op : dependantOps) { |
| if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) { |
| auto currentTileId = dependantTileOp.getTileId(); |
| if (currentTileId && unsigned(currentTileId.getInt()) != tileId) |
| return dependantTileOp.emitOpError( |
| "already assigned different SME virtual tile!"); |
| } |
| } |
| |
| // Rewrite IR. |
| if (!tileIsInMemory) |
| setDiscardableIntAttr(kTilesInUseAttr, tilesInUse); |
| else |
| setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1); |
| rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); }); |
| for (auto *op : dependantOps) { |
| if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) { |
| rewriter.modifyOpInPlace( |
| dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); }); |
| } |
| } |
| |
| return success(); |
| } |
| }; |
| |
| struct TileAllocationPass |
| : public arm_sme::impl::TileAllocationBase<TileAllocationPass> { |
| void runOnOperation() override { |
| RewritePatternSet patterns(&getContext()); |
| patterns.add<AssignTileIDsPattern>(patterns.getContext()); |
| GreedyRewriteConfig config; |
| // Setting useTopDownTraversal ensures tiles are allocated in program |
| // order. |
| config.useTopDownTraversal = true; |
| if (mlir::failed(mlir::applyPatternsAndFoldGreedily( |
| getOperation(), std::move(patterns), config))) { |
| signalPassFailure(); |
| } |
| } |
| }; |
| } // namespace |
| |
| std::unique_ptr<Pass> mlir::arm_sme::createTileAllocationPass() { |
| return std::make_unique<TileAllocationPass>(); |
| } |