blob: c9bb3395a9949184f24b14558599409e14114edd [file] [log] [blame]
//===- SuspendCrossingInfo.cpp - Utility for suspend crossing values ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// The SuspendCrossingInfo maintains data that allows to answer a question
// whether given two BasicBlocks A and B there is a path from A to B that
// passes through a suspend point. Note, SuspendCrossingInfo is invalidated
// by changes to the CFG including adding/removing BBs due to its use of BB
// ptrs in the BlockToIndexMapping.
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Coroutines/SuspendCrossingInfo.h"
#include "llvm/IR/ModuleSlotTracker.h"
// The "coro-suspend-crossing" flag is very noisy. There is another debug type,
// "coro-frame", which results in leaner debug spew.
#define DEBUG_TYPE "coro-suspend-crossing"
namespace llvm {
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
static void dumpBasicBlockLabel(const BasicBlock *BB, ModuleSlotTracker &MST) {
if (BB->hasName()) {
dbgs() << BB->getName();
return;
}
dbgs() << MST.getLocalSlot(BB);
}
LLVM_DUMP_METHOD void
SuspendCrossingInfo::dump(StringRef Label, BitVector const &BV,
const ReversePostOrderTraversal<Function *> &RPOT,
ModuleSlotTracker &MST) const {
dbgs() << Label << ":";
for (const BasicBlock *BB : RPOT) {
auto BBNo = Mapping.blockToIndex(BB);
if (BV[BBNo]) {
dbgs() << " ";
dumpBasicBlockLabel(BB, MST);
}
}
dbgs() << "\n";
}
LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const {
if (Block.empty())
return;
BasicBlock *const B = Mapping.indexToBlock(0);
Function *F = B->getParent();
ModuleSlotTracker MST(F->getParent());
MST.incorporateFunction(*F);
ReversePostOrderTraversal<Function *> RPOT(F);
for (const BasicBlock *BB : RPOT) {
auto BBNo = Mapping.blockToIndex(BB);
dumpBasicBlockLabel(BB, MST);
dbgs() << ":\n";
dump(" Consumes", Block[BBNo].Consumes, RPOT, MST);
dump(" Kills", Block[BBNo].Kills, RPOT, MST);
}
dbgs() << "\n";
}
#endif
bool SuspendCrossingInfo::hasPathCrossingSuspendPoint(BasicBlock *From,
BasicBlock *To) const {
size_t const FromIndex = Mapping.blockToIndex(From);
size_t const ToIndex = Mapping.blockToIndex(To);
bool const Result = Block[ToIndex].Kills[FromIndex];
LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName()
<< " crosses suspend point\n");
return Result;
}
bool SuspendCrossingInfo::hasPathOrLoopCrossingSuspendPoint(
BasicBlock *From, BasicBlock *To) const {
size_t const FromIndex = Mapping.blockToIndex(From);
size_t const ToIndex = Mapping.blockToIndex(To);
bool Result = Block[ToIndex].Kills[FromIndex] ||
(From == To && Block[ToIndex].KillLoop);
LLVM_DEBUG(if (Result) dbgs() << From->getName() << " => " << To->getName()
<< " crosses suspend point (path or loop)\n");
return Result;
}
template <bool Initialize>
bool SuspendCrossingInfo::computeBlockData(
const ReversePostOrderTraversal<Function *> &RPOT) {
bool Changed = false;
for (const BasicBlock *BB : RPOT) {
auto BBNo = Mapping.blockToIndex(BB);
auto &B = Block[BBNo];
// We don't need to count the predecessors when initialization.
if constexpr (!Initialize)
// If all the predecessors of the current Block don't change,
// the BlockData for the current block must not change too.
if (all_of(predecessors(B), [this](BasicBlock *BB) {
return !Block[Mapping.blockToIndex(BB)].Changed;
})) {
B.Changed = false;
continue;
}
// Saved Consumes and Kills bitsets so that it is easy to see
// if anything changed after propagation.
auto SavedConsumes = B.Consumes;
auto SavedKills = B.Kills;
for (BasicBlock *PI : predecessors(B)) {
auto PrevNo = Mapping.blockToIndex(PI);
auto &P = Block[PrevNo];
// Propagate Kills and Consumes from predecessors into B.
B.Consumes |= P.Consumes;
B.Kills |= P.Kills;
// If block P is a suspend block, it should propagate kills into block
// B for every block P consumes.
if (P.Suspend)
B.Kills |= P.Consumes;
}
if (B.Suspend) {
// If block B is a suspend block, it should kill all of the blocks it
// consumes.
B.Kills |= B.Consumes;
} else if (B.End) {
// If block B is an end block, it should not propagate kills as the
// blocks following coro.end() are reached during initial invocation
// of the coroutine while all the data are still available on the
// stack or in the registers.
B.Kills.reset();
} else {
// This is reached when B block it not Suspend nor coro.end and it
// need to make sure that it is not in the kill set.
B.KillLoop |= B.Kills[BBNo];
B.Kills.reset(BBNo);
}
if constexpr (!Initialize) {
B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes);
Changed |= B.Changed;
}
}
return Changed;
}
SuspendCrossingInfo::SuspendCrossingInfo(
Function &F, const SmallVectorImpl<AnyCoroSuspendInst *> &CoroSuspends,
const SmallVectorImpl<AnyCoroEndInst *> &CoroEnds)
: Mapping(F) {
const size_t N = Mapping.size();
Block.resize(N);
// Initialize every block so that it consumes itself
for (size_t I = 0; I < N; ++I) {
auto &B = Block[I];
B.Consumes.resize(N);
B.Kills.resize(N);
B.Consumes.set(I);
B.Changed = true;
}
// Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
// the code beyond coro.end is reachable during initial invocation of the
// coroutine.
for (auto *CE : CoroEnds) {
// Verify CoroEnd was normalized
assert(CE->getParent()->getFirstInsertionPt() == CE->getIterator() &&
CE->getParent()->size() <= 2 && "CoroEnd must be in its own BB");
getBlockData(CE->getParent()).End = true;
}
// Mark all suspend blocks and indicate that they kill everything they
// consume. Note, that crossing coro.save also requires a spill, as any code
// between coro.save and coro.suspend may resume the coroutine and all of the
// state needs to be saved by that time.
auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) {
BasicBlock *SuspendBlock = BarrierInst->getParent();
auto &B = getBlockData(SuspendBlock);
B.Suspend = true;
B.Kills |= B.Consumes;
};
for (auto *CSI : CoroSuspends) {
// Verify CoroSuspend was normalized
assert(CSI->getParent()->getFirstInsertionPt() == CSI->getIterator() &&
CSI->getParent()->size() <= 2 &&
"CoroSuspend must be in its own BB");
markSuspendBlock(CSI);
if (auto *Save = CSI->getCoroSave())
markSuspendBlock(Save);
}
// It is considered to be faster to use RPO traversal for forward-edges
// dataflow analysis.
ReversePostOrderTraversal<Function *> RPOT(&F);
computeBlockData</*Initialize=*/true>(RPOT);
while (computeBlockData</*Initialize*/ false>(RPOT))
;
LLVM_DEBUG(dump());
}
} // namespace llvm