| //===- AsyncRuntimeRefCounting.cpp - Async Runtime Ref Counting -----------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements automatic reference counting for Async runtime |
| // operations and types. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "PassDetail.h" |
| #include "mlir/Analysis/Liveness.h" |
| #include "mlir/Dialect/Async/IR/Async.h" |
| #include "mlir/Dialect/Async/Passes.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/ImplicitLocOpBuilder.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "llvm/ADT/SmallSet.h" |
| |
| using namespace mlir; |
| using namespace mlir::async; |
| |
| #define DEBUG_TYPE "async-runtime-ref-counting" |
| |
| //===----------------------------------------------------------------------===// |
| // Utility functions shared by reference counting passes. |
| //===----------------------------------------------------------------------===// |
| |
| // Drop the reference count immediately if the value has no uses. |
| static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) { |
| if (!value.getUses().empty()) |
| return failure(); |
| |
| OpBuilder b(value.getContext()); |
| |
| // Set insertion point after the operation producing a value, or at the |
| // beginning of the block if the value defined by the block argument. |
| if (Operation *op = value.getDefiningOp()) |
| b.setInsertionPointAfter(op); |
| else |
| b.setInsertionPointToStart(value.getParentBlock()); |
| |
| b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1)); |
| return success(); |
| } |
| |
| // Calls `addRefCounting` for every reference counted value defined by the |
| // operation `op` (block arguments and values defined in nested regions). |
| static LogicalResult walkReferenceCountedValues( |
| Operation *op, llvm::function_ref<LogicalResult(Value)> addRefCounting) { |
| // Check that we do not have high level async operations in the IR because |
| // otherwise reference counting will produce incorrect results after high |
| // level async operations will be lowered to `async.runtime` |
| WalkResult checkNoAsyncWalk = op->walk([&](Operation *op) -> WalkResult { |
| if (!isa<ExecuteOp, AwaitOp, AwaitAllOp, YieldOp>(op)) |
| return WalkResult::advance(); |
| |
| return op->emitError() |
| << "async operations must be lowered to async runtime operations"; |
| }); |
| |
| if (checkNoAsyncWalk.wasInterrupted()) |
| return failure(); |
| |
| // Add reference counting to block arguments. |
| WalkResult blockWalk = op->walk([&](Block *block) -> WalkResult { |
| for (BlockArgument arg : block->getArguments()) |
| if (isRefCounted(arg.getType())) |
| if (failed(addRefCounting(arg))) |
| return WalkResult::interrupt(); |
| |
| return WalkResult::advance(); |
| }); |
| |
| if (blockWalk.wasInterrupted()) |
| return failure(); |
| |
| // Add reference counting to operation results. |
| WalkResult opWalk = op->walk([&](Operation *op) -> WalkResult { |
| for (unsigned i = 0; i < op->getNumResults(); ++i) |
| if (isRefCounted(op->getResultTypes()[i])) |
| if (failed(addRefCounting(op->getResult(i)))) |
| return WalkResult::interrupt(); |
| |
| return WalkResult::advance(); |
| }); |
| |
| if (opWalk.wasInterrupted()) |
| return failure(); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Automatic reference counting based on the liveness analysis. |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| class AsyncRuntimeRefCountingPass |
| : public AsyncRuntimeRefCountingBase<AsyncRuntimeRefCountingPass> { |
| public: |
| AsyncRuntimeRefCountingPass() = default; |
| void runOnOperation() override; |
| |
| private: |
| /// Adds an automatic reference counting to the `value`. |
| /// |
| /// All values (token, group or value) are semantically created with a |
| /// reference count of +1 and it is the responsibility of the async value user |
| /// to place the `add_ref` and `drop_ref` operations to ensure that the value |
| /// is destroyed after the last use. |
| /// |
| /// The function returns failure if it can't deduce the locations where |
| /// to place the reference counting operations. |
| /// |
| /// Async values "semantically created" when: |
| /// 1. Operation returns async result (e.g. `async.runtime.create`) |
| /// 2. Async value passed in as a block argument (or function argument, |
| /// because function arguments are just entry block arguments) |
| /// |
| /// Passing async value as a function argument (or block argument) does not |
| /// really mean that a new async value is created, it only means that the |
| /// caller of a function transfered ownership of `+1` reference to the callee. |
| /// It is convenient to think that from the callee perspective async value was |
| /// "created" with `+1` reference by the block argument. |
| /// |
| /// Automatic reference counting algorithm outline: |
| /// |
| /// #1 Insert `drop_ref` operations after last use of the `value`. |
| /// #2 Insert `add_ref` operations before functions calls with reference |
| /// counted `value` operand (newly created `+1` reference will be |
| /// transferred to the callee). |
| /// #3 Verify that divergent control flow does not lead to leaked reference |
| /// counted objects. |
| /// |
| /// Async runtime reference counting optimization pass will optimize away |
| /// some of the redundant `add_ref` and `drop_ref` operations inserted by this |
| /// strategy (see `async-runtime-ref-counting-opt`). |
| LogicalResult addAutomaticRefCounting(Value value); |
| |
| /// (#1) Adds the `drop_ref` operation after the last use of the `value` |
| /// relying on the liveness analysis. |
| /// |
| /// If the `value` is in the block `liveIn` set and it is not in the block |
| /// `liveOut` set, it means that it "dies" in the block. We find the last |
| /// use of the value in such block and: |
| /// |
| /// 1. If the last user is a `ReturnLike` operation we do nothing, because |
| /// it forwards the ownership to the caller. |
| /// 2. Otherwise we add a `drop_ref` operation immediately after the last |
| /// use. |
| LogicalResult addDropRefAfterLastUse(Value value); |
| |
| /// (#2) Adds the `add_ref` operation before the function call taking `value` |
| /// operand to ensure that the value passed to the function entry block |
| /// has a `+1` reference count. |
| LogicalResult addAddRefBeforeFunctionCall(Value value); |
| |
| /// (#3) Adds the `drop_ref` operation to account for successor blocks with |
| /// divergent `liveIn` property: `value` is not in the `liveIn` set of all |
| /// successor blocks. |
| /// |
| /// Example: |
| /// |
| /// ^entry: |
| /// %token = async.runtime.create : !async.token |
| /// cond_br %cond, ^bb1, ^bb2 |
| /// ^bb1: |
| /// async.runtime.await %token |
| /// async.runtime.drop_ref %token |
| /// br ^bb2 |
| /// ^bb2: |
| /// return |
| /// |
| /// In this example ^bb2 does not have `value` in the `liveIn` set, so we have |
| /// to branch into a special "reference counting block" from the ^entry that |
| /// will have a `drop_ref` operation, and then branch into the ^bb2. |
| /// |
| /// After transformation: |
| /// |
| /// ^entry: |
| /// %token = async.runtime.create : !async.token |
| /// cond_br %cond, ^bb1, ^reference_counting |
| /// ^bb1: |
| /// async.runtime.await %token |
| /// async.runtime.drop_ref %token |
| /// br ^bb2 |
| /// ^reference_counting: |
| /// async.runtime.drop_ref %token |
| /// br ^bb2 |
| /// ^bb2: |
| /// return |
| /// |
| /// An exception to this rule are blocks with `async.coro.suspend` terminator, |
| /// because in Async to LLVM lowering it is guaranteed that the control flow |
| /// will jump into the resume block, and then follow into the cleanup and |
| /// suspend blocks. |
| /// |
| /// Example: |
| /// |
| /// ^entry(%value: !async.value<f32>): |
| /// async.runtime.await_and_resume %value, %hdl : !async.value<f32> |
| /// async.coro.suspend %ret, ^suspend, ^resume, ^cleanup |
| /// ^resume: |
| /// %0 = async.runtime.load %value |
| /// br ^cleanup |
| /// ^cleanup: |
| /// ... |
| /// ^suspend: |
| /// ... |
| /// |
| /// Although cleanup and suspend blocks do not have the `value` in the |
| /// `liveIn` set, it is guaranteed that execution will eventually continue in |
| /// the resume block (we never explicitly destroy coroutines). |
| LogicalResult addDropRefInDivergentLivenessSuccessor(Value value); |
| }; |
| |
| } // namespace |
| |
| LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { |
| OpBuilder builder(value.getContext()); |
| Location loc = value.getLoc(); |
| |
| // Use liveness analysis to find the placement of `drop_ref`operation. |
| auto &liveness = getAnalysis<Liveness>(); |
| |
| // We analyse only the blocks of the region that defines the `value`, and do |
| // not check nested blocks attached to operations. |
| // |
| // By analyzing only the `definingRegion` CFG we potentially loose an |
| // opportunity to drop the reference count earlier and can extend the lifetime |
| // of reference counted value longer then it is really required. |
| // |
| // We also assume that all nested regions finish their execution before the |
| // completion of the owner operation. The only exception to this rule is |
| // `async.execute` operation, and we verify that they are lowered to the |
| // `async.runtime` operations before adding automatic reference counting. |
| Region *definingRegion = value.getParentRegion(); |
| |
| // Last users of the `value` inside all blocks where the value dies. |
| llvm::SmallSet<Operation *, 4> lastUsers; |
| |
| // Find blocks in the `definingRegion` that have users of the `value` (if |
| // there are multiple users in the block, which one will be selected is |
| // undefined). User operation might be not the actual user of the value, but |
| // the operation in the block that has a "real user" in one of the attached |
| // regions. |
| llvm::DenseMap<Block *, Operation *> usersInTheBlocks; |
| |
| for (Operation *user : value.getUsers()) { |
| Block *userBlock = user->getBlock(); |
| Block *ancestor = definingRegion->findAncestorBlockInRegion(*userBlock); |
| usersInTheBlocks[ancestor] = ancestor->findAncestorOpInBlock(*user); |
| assert(ancestor && "ancestor block must be not null"); |
| assert(usersInTheBlocks[ancestor] && "ancestor op must be not null"); |
| } |
| |
| // Find blocks where the `value` dies: the value is in `liveIn` set and not |
| // in the `liveOut` set. We place `drop_ref` immediately after the last use |
| // of the `value` in such regions (after handling few special cases). |
| // |
| // We do not traverse all the blocks in the `definingRegion`, because the |
| // `value` can be in the live in set only if it has users in the block, or it |
| // is defined in the block. |
| // |
| // Values with zero users (only definition) handled explicitly above. |
| for (auto &blockAndUser : usersInTheBlocks) { |
| Block *block = blockAndUser.getFirst(); |
| Operation *userInTheBlock = blockAndUser.getSecond(); |
| |
| const LivenessBlockInfo *blockLiveness = liveness.getLiveness(block); |
| |
| // Value must be in the live input set or defined in the block. |
| assert(blockLiveness->isLiveIn(value) || |
| blockLiveness->getBlock() == value.getParentBlock()); |
| |
| // If value is in the live out set, it means it doesn't "die" in the block. |
| if (blockLiveness->isLiveOut(value)) |
| continue; |
| |
| // At this point we proved that `value` dies in the `block`. Find the last |
| // use of the `value` inside the `block`, this is where it "dies". |
| Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock); |
| assert(lastUsers.count(lastUser) == 0 && "last users must be unique"); |
| lastUsers.insert(lastUser); |
| } |
| |
| // Process all the last users of the `value` inside each block where the value |
| // dies. |
| for (Operation *lastUser : lastUsers) { |
| // Return like operations forward reference count. |
| if (lastUser->hasTrait<OpTrait::ReturnLike>()) |
| continue; |
| |
| // We can't currently handle other types of terminators. |
| if (lastUser->hasTrait<OpTrait::IsTerminator>()) |
| return lastUser->emitError() << "async reference counting can't handle " |
| "terminators that are not ReturnLike"; |
| |
| // Add a drop_ref immediately after the last user. |
| builder.setInsertionPointAfter(lastUser); |
| builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1)); |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult |
| AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { |
| OpBuilder builder(value.getContext()); |
| Location loc = value.getLoc(); |
| |
| for (Operation *user : value.getUsers()) { |
| if (!isa<CallOp>(user)) |
| continue; |
| |
| // Add a reference before the function call to pass the value at `+1` |
| // reference to the function entry block. |
| builder.setInsertionPoint(user); |
| builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1)); |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult |
| AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor( |
| Value value) { |
| using BlockSet = llvm::SmallPtrSet<Block *, 4>; |
| |
| OpBuilder builder(value.getContext()); |
| |
| // If a block has successors with different `liveIn` property of the `value`, |
| // record block successors that do not thave the `value` in the `liveIn` set. |
| llvm::SmallDenseMap<Block *, BlockSet> divergentLivenessBlocks; |
| |
| // Use liveness analysis to find the placement of `drop_ref`operation. |
| auto &liveness = getAnalysis<Liveness>(); |
| |
| // Because we only add `drop_ref` operations to the region that defines the |
| // `value` we can only process CFG for the same region. |
| Region *definingRegion = value.getParentRegion(); |
| |
| // Collect blocks with successors with mismatching `liveIn` sets. |
| for (Block &block : definingRegion->getBlocks()) { |
| const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); |
| |
| // Skip the block if value is not in the `liveOut` set. |
| if (!blockLiveness || !blockLiveness->isLiveOut(value)) |
| continue; |
| |
| BlockSet liveInSuccessors; // `value` is in `liveIn` set |
| BlockSet noLiveInSuccessors; // `value` is not in the `liveIn` set |
| |
| // Collect successors that do not have `value` in the `liveIn` set. |
| for (Block *successor : block.getSuccessors()) { |
| const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); |
| if (succLiveness && succLiveness->isLiveIn(value)) |
| liveInSuccessors.insert(successor); |
| else |
| noLiveInSuccessors.insert(successor); |
| } |
| |
| // Block has successors with different `liveIn` property of the `value`. |
| if (!liveInSuccessors.empty() && !noLiveInSuccessors.empty()) |
| divergentLivenessBlocks.try_emplace(&block, noLiveInSuccessors); |
| } |
| |
| // Try to insert `dropRef` operations to handle blocks with divergent liveness |
| // in successors blocks. |
| for (auto kv : divergentLivenessBlocks) { |
| Block *block = kv.getFirst(); |
| BlockSet &successors = kv.getSecond(); |
| |
| // Coroutine suspension is a special case terminator for wich we do not |
| // need to create additional reference counting (see details above). |
| Operation *terminator = block->getTerminator(); |
| if (isa<CoroSuspendOp>(terminator)) |
| continue; |
| |
| // We only support successor blocks with empty block argument list. |
| auto hasArgs = [](Block *block) { return !block->getArguments().empty(); }; |
| if (llvm::any_of(successors, hasArgs)) |
| return terminator->emitOpError() |
| << "successor have different `liveIn` property of the reference " |
| "counted value"; |
| |
| // Make sure that `dropRef` operation is called when branched into the |
| // successor block without `value` in the `liveIn` set. |
| for (Block *successor : successors) { |
| // If successor has a unique predecessor, it is safe to create `dropRef` |
| // operations directly in the successor block. |
| // |
| // Otherwise we need to create a special block for reference counting |
| // operations, and branch from it to the original successor block. |
| Block *refCountingBlock = nullptr; |
| |
| if (successor->getUniquePredecessor() == block) { |
| refCountingBlock = successor; |
| } else { |
| refCountingBlock = &successor->getParent()->emplaceBlock(); |
| refCountingBlock->moveBefore(successor); |
| OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock); |
| builder.create<BranchOp>(value.getLoc(), successor); |
| } |
| |
| OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock); |
| builder.create<RuntimeDropRefOp>(value.getLoc(), value, |
| builder.getI64IntegerAttr(1)); |
| |
| // No need to update the terminator operation. |
| if (successor == refCountingBlock) |
| continue; |
| |
| // Update terminator `successor` block to `refCountingBlock`. |
| for (auto pair : llvm::enumerate(terminator->getSuccessors())) |
| if (pair.value() == successor) |
| terminator->setSuccessor(refCountingBlock, pair.index()); |
| } |
| } |
| |
| return success(); |
| } |
| |
| LogicalResult |
| AsyncRuntimeRefCountingPass::addAutomaticRefCounting(Value value) { |
| // Short-circuit reference counting for values without uses. |
| if (succeeded(dropRefIfNoUses(value))) |
| return success(); |
| |
| // Add `drop_ref` operations based on the liveness analysis. |
| if (failed(addDropRefAfterLastUse(value))) |
| return failure(); |
| |
| // Add `add_ref` operations before function calls. |
| if (failed(addAddRefBeforeFunctionCall(value))) |
| return failure(); |
| |
| // Add `drop_ref` operations to successors with divergent `value` liveness. |
| if (failed(addDropRefInDivergentLivenessSuccessor(value))) |
| return failure(); |
| |
| return success(); |
| } |
| |
| void AsyncRuntimeRefCountingPass::runOnOperation() { |
| auto functor = [&](Value value) { return addAutomaticRefCounting(value); }; |
| if (failed(walkReferenceCountedValues(getOperation(), functor))) |
| signalPassFailure(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Reference counting based on the user defined policy. |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| class AsyncRuntimePolicyBasedRefCountingPass |
| : public AsyncRuntimePolicyBasedRefCountingBase< |
| AsyncRuntimePolicyBasedRefCountingPass> { |
| public: |
| AsyncRuntimePolicyBasedRefCountingPass() { initializeDefaultPolicy(); } |
| |
| void runOnOperation() override; |
| |
| private: |
| // Adds a reference counting operations for all uses of the `value` according |
| // to the reference counting policy. |
| LogicalResult addRefCounting(Value value); |
| |
| void initializeDefaultPolicy(); |
| |
| llvm::SmallVector<std::function<FailureOr<int>(OpOperand &)>> policy; |
| }; |
| |
| } // namespace |
| |
| LogicalResult |
| AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) { |
| // Short-circuit reference counting for values without uses. |
| if (succeeded(dropRefIfNoUses(value))) |
| return success(); |
| |
| OpBuilder b(value.getContext()); |
| |
| // Consult the user defined policy for every value use. |
| for (OpOperand &operand : value.getUses()) { |
| Location loc = operand.getOwner()->getLoc(); |
| |
| for (auto &func : policy) { |
| FailureOr<int> refCount = func(operand); |
| if (failed(refCount)) |
| return failure(); |
| |
| int cnt = refCount.getValue(); |
| |
| // Create `add_ref` operation before the operand owner. |
| if (cnt > 0) { |
| b.setInsertionPoint(operand.getOwner()); |
| b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt)); |
| } |
| |
| // Create `drop_ref` operation after the operand owner. |
| if (cnt < 0) { |
| b.setInsertionPointAfter(operand.getOwner()); |
| b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt)); |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| void AsyncRuntimePolicyBasedRefCountingPass::initializeDefaultPolicy() { |
| policy.push_back([](OpOperand &operand) -> FailureOr<int> { |
| Operation *op = operand.getOwner(); |
| Type type = operand.get().getType(); |
| |
| bool isToken = type.isa<TokenType>(); |
| bool isGroup = type.isa<GroupType>(); |
| bool isValue = type.isa<ValueType>(); |
| |
| // Drop reference after async token or group error check (coro await). |
| if (auto await = dyn_cast<RuntimeIsErrorOp>(op)) |
| return (isToken || isGroup) ? -1 : 0; |
| |
| // Drop reference after async value load. |
| if (auto load = dyn_cast<RuntimeLoadOp>(op)) |
| return isValue ? -1 : 0; |
| |
| // Drop reference after async token added to the group. |
| if (auto add = dyn_cast<RuntimeAddToGroupOp>(op)) |
| return isToken ? -1 : 0; |
| |
| return 0; |
| }); |
| } |
| |
| void AsyncRuntimePolicyBasedRefCountingPass::runOnOperation() { |
| auto functor = [&](Value value) { return addRefCounting(value); }; |
| if (failed(walkReferenceCountedValues(getOperation(), functor))) |
| signalPassFailure(); |
| } |
| |
| //----------------------------------------------------------------------------// |
| |
| std::unique_ptr<Pass> mlir::createAsyncRuntimeRefCountingPass() { |
| return std::make_unique<AsyncRuntimeRefCountingPass>(); |
| } |
| |
| std::unique_ptr<Pass> mlir::createAsyncRuntimePolicyBasedRefCountingPass() { |
| return std::make_unique<AsyncRuntimePolicyBasedRefCountingPass>(); |
| } |