| //===- Tracker.cpp --------------------------------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/SandboxIR/Tracker.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/IR/BasicBlock.h" |
| #include "llvm/IR/Instruction.h" |
| #include "llvm/IR/StructuralHash.h" |
| #include "llvm/SandboxIR/Instruction.h" |
| |
| using namespace llvm::sandboxir; |
| |
| #ifndef NDEBUG |
| |
| std::string IRSnapshotChecker::dumpIR(const llvm::Function &F) const { |
| std::string Result; |
| raw_string_ostream SS(Result); |
| F.print(SS, /*AssemblyAnnotationWriter=*/nullptr); |
| return Result; |
| } |
| |
| IRSnapshotChecker::ContextSnapshot IRSnapshotChecker::takeSnapshot() const { |
| ContextSnapshot Result; |
| for (const auto &Entry : Ctx.LLVMModuleToModuleMap) |
| for (const auto &F : *Entry.first) { |
| FunctionSnapshot Snapshot; |
| Snapshot.Hash = StructuralHash(F, /*DetailedHash=*/true); |
| Snapshot.TextualIR = dumpIR(F); |
| Result[&F] = Snapshot; |
| } |
| return Result; |
| } |
| |
| bool IRSnapshotChecker::diff(const ContextSnapshot &Orig, |
| const ContextSnapshot &Curr) const { |
| bool DifferenceFound = false; |
| for (const auto &[F, OrigFS] : Orig) { |
| auto CurrFSIt = Curr.find(F); |
| if (CurrFSIt == Curr.end()) { |
| DifferenceFound = true; |
| dbgs() << "Function " << F->getName() << " not found in current IR.\n"; |
| dbgs() << OrigFS.TextualIR << "\n"; |
| continue; |
| } |
| const FunctionSnapshot &CurrFS = CurrFSIt->second; |
| if (OrigFS.Hash != CurrFS.Hash) { |
| DifferenceFound = true; |
| dbgs() << "Found IR difference in Function " << F->getName() << "\n"; |
| dbgs() << "Original:\n" << OrigFS.TextualIR << "\n"; |
| dbgs() << "Current:\n" << CurrFS.TextualIR << "\n"; |
| } |
| } |
| // Check that Curr doesn't contain any new functions. |
| for (const auto &[F, CurrFS] : Curr) { |
| if (!Orig.contains(F)) { |
| DifferenceFound = true; |
| dbgs() << "Function " << F->getName() |
| << " found in current IR but not in original snapshot.\n"; |
| dbgs() << CurrFS.TextualIR << "\n"; |
| } |
| } |
| return DifferenceFound; |
| } |
| |
| void IRSnapshotChecker::save() { OrigContextSnapshot = takeSnapshot(); } |
| |
| void IRSnapshotChecker::expectNoDiff() { |
| ContextSnapshot CurrContextSnapshot = takeSnapshot(); |
| if (diff(OrigContextSnapshot, CurrContextSnapshot)) { |
| llvm_unreachable( |
| "Original and current IR differ! Probably a checkpointing bug."); |
| } |
| } |
| |
| void UseSet::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| |
| void UseSwap::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif // NDEBUG |
| |
| PHIRemoveIncoming::PHIRemoveIncoming(PHINode *PHI, unsigned RemovedIdx) |
| : PHI(PHI), RemovedIdx(RemovedIdx) { |
| RemovedV = PHI->getIncomingValue(RemovedIdx); |
| RemovedBB = PHI->getIncomingBlock(RemovedIdx); |
| } |
| |
| void PHIRemoveIncoming::revert(Tracker &Tracker) { |
| // Special case: if the PHI is now empty, as we don't need to care about the |
| // order of the incoming values. |
| unsigned NumIncoming = PHI->getNumIncomingValues(); |
| if (NumIncoming == 0) { |
| PHI->addIncoming(RemovedV, RemovedBB); |
| return; |
| } |
| // Shift all incoming values by one starting from the end until `Idx`. |
| // Start by adding a copy of the last incoming values. |
| unsigned LastIdx = NumIncoming - 1; |
| PHI->addIncoming(PHI->getIncomingValue(LastIdx), |
| PHI->getIncomingBlock(LastIdx)); |
| for (unsigned Idx = LastIdx; Idx > RemovedIdx; --Idx) { |
| auto *PrevV = PHI->getIncomingValue(Idx - 1); |
| auto *PrevBB = PHI->getIncomingBlock(Idx - 1); |
| PHI->setIncomingValue(Idx, PrevV); |
| PHI->setIncomingBlock(Idx, PrevBB); |
| } |
| PHI->setIncomingValue(RemovedIdx, RemovedV); |
| PHI->setIncomingBlock(RemovedIdx, RemovedBB); |
| } |
| |
| #ifndef NDEBUG |
| void PHIRemoveIncoming::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif // NDEBUG |
| |
| PHIAddIncoming::PHIAddIncoming(PHINode *PHI) |
| : PHI(PHI), Idx(PHI->getNumIncomingValues()) {} |
| |
| void PHIAddIncoming::revert(Tracker &Tracker) { PHI->removeIncomingValue(Idx); } |
| |
| #ifndef NDEBUG |
| void PHIAddIncoming::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif // NDEBUG |
| |
| Tracker::~Tracker() { |
| assert(Changes.empty() && "You must accept or revert changes!"); |
| } |
| |
| EraseFromParent::EraseFromParent(std::unique_ptr<sandboxir::Value> &&ErasedIPtr) |
| : ErasedIPtr(std::move(ErasedIPtr)) { |
| auto *I = cast<Instruction>(this->ErasedIPtr.get()); |
| auto LLVMInstrs = I->getLLVMInstrs(); |
| // Iterate in reverse program order. |
| for (auto *LLVMI : reverse(LLVMInstrs)) { |
| SmallVector<llvm::Value *> Operands; |
| Operands.reserve(LLVMI->getNumOperands()); |
| for (auto [OpNum, Use] : enumerate(LLVMI->operands())) |
| Operands.push_back(Use.get()); |
| InstrData.push_back({Operands, LLVMI}); |
| } |
| assert(is_sorted(InstrData, |
| [](const auto &D0, const auto &D1) { |
| return D0.LLVMI->comesBefore(D1.LLVMI); |
| }) && |
| "Expected reverse program order!"); |
| auto *BotLLVMI = cast<llvm::Instruction>(I->Val); |
| if (BotLLVMI->getNextNode() != nullptr) |
| NextLLVMIOrBB = BotLLVMI->getNextNode(); |
| else |
| NextLLVMIOrBB = BotLLVMI->getParent(); |
| } |
| |
| void EraseFromParent::accept() { |
| for (const auto &IData : InstrData) |
| IData.LLVMI->deleteValue(); |
| } |
| |
| void EraseFromParent::revert(Tracker &Tracker) { |
| // Place the bottom-most instruction first. |
| auto [Operands, BotLLVMI] = InstrData[0]; |
| if (auto *NextLLVMI = dyn_cast<llvm::Instruction *>(NextLLVMIOrBB)) { |
| BotLLVMI->insertBefore(NextLLVMI->getIterator()); |
| } else { |
| auto *LLVMBB = cast<llvm::BasicBlock *>(NextLLVMIOrBB); |
| BotLLVMI->insertInto(LLVMBB, LLVMBB->end()); |
| } |
| for (auto [OpNum, Op] : enumerate(Operands)) |
| BotLLVMI->setOperand(OpNum, Op); |
| |
| // Go over the rest of the instructions and stack them on top. |
| for (auto [Operands, LLVMI] : drop_begin(InstrData)) { |
| LLVMI->insertBefore(BotLLVMI->getIterator()); |
| for (auto [OpNum, Op] : enumerate(Operands)) |
| LLVMI->setOperand(OpNum, Op); |
| BotLLVMI = LLVMI; |
| } |
| Tracker.getContext().registerValue(std::move(ErasedIPtr)); |
| } |
| |
| #ifndef NDEBUG |
| void EraseFromParent::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif // NDEBUG |
| |
| RemoveFromParent::RemoveFromParent(Instruction *RemovedI) : RemovedI(RemovedI) { |
| if (auto *NextI = RemovedI->getNextNode()) |
| NextInstrOrBB = NextI; |
| else |
| NextInstrOrBB = RemovedI->getParent(); |
| } |
| |
| void RemoveFromParent::revert(Tracker &Tracker) { |
| if (auto *NextI = dyn_cast<Instruction *>(NextInstrOrBB)) { |
| RemovedI->insertBefore(NextI); |
| } else { |
| auto *BB = cast<BasicBlock *>(NextInstrOrBB); |
| RemovedI->insertInto(BB, BB->end()); |
| } |
| } |
| |
| #ifndef NDEBUG |
| void RemoveFromParent::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif |
| |
| CatchSwitchAddHandler::CatchSwitchAddHandler(CatchSwitchInst *CSI) |
| : CSI(CSI), HandlerIdx(CSI->getNumHandlers()) {} |
| |
| void CatchSwitchAddHandler::revert(Tracker &Tracker) { |
| // TODO: This should ideally use sandboxir::CatchSwitchInst::removeHandler() |
| // once it gets implemented. |
| auto *LLVMCSI = cast<llvm::CatchSwitchInst>(CSI->Val); |
| LLVMCSI->removeHandler(LLVMCSI->handler_begin() + HandlerIdx); |
| } |
| |
| SwitchRemoveCase::SwitchRemoveCase(SwitchInst *Switch) : Switch(Switch) { |
| for (const auto &C : Switch->cases()) |
| Cases.push_back({C.getCaseValue(), C.getCaseSuccessor()}); |
| } |
| |
| void SwitchRemoveCase::revert(Tracker &Tracker) { |
| // SwitchInst::removeCase doesn't provide any guarantees about the order of |
| // cases after removal. In order to preserve the original ordering, we save |
| // all of them and, when reverting, clear them all then insert them in the |
| // desired order. This still relies on the fact that `addCase` will insert |
| // them at the end, but it is documented to invalidate `case_end()` so it's |
| // probably okay. |
| unsigned NumCases = Switch->getNumCases(); |
| for (unsigned I = 0; I < NumCases; ++I) |
| Switch->removeCase(Switch->case_begin()); |
| for (auto &Case : Cases) |
| Switch->addCase(Case.Val, Case.Dest); |
| } |
| |
| #ifndef NDEBUG |
| void SwitchRemoveCase::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif // NDEBUG |
| |
| void SwitchAddCase::revert(Tracker &Tracker) { |
| auto It = Switch->findCaseValue(Val); |
| Switch->removeCase(It); |
| } |
| |
| #ifndef NDEBUG |
| void SwitchAddCase::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif // NDEBUG |
| |
| MoveInstr::MoveInstr(Instruction *MovedI) : MovedI(MovedI) { |
| if (auto *NextI = MovedI->getNextNode()) |
| NextInstrOrBB = NextI; |
| else |
| NextInstrOrBB = MovedI->getParent(); |
| } |
| |
| void MoveInstr::revert(Tracker &Tracker) { |
| if (auto *NextI = dyn_cast<Instruction *>(NextInstrOrBB)) { |
| MovedI->moveBefore(NextI); |
| } else { |
| auto *BB = cast<BasicBlock *>(NextInstrOrBB); |
| MovedI->moveBefore(*BB, BB->end()); |
| } |
| } |
| |
| #ifndef NDEBUG |
| void MoveInstr::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif |
| |
| void InsertIntoBB::revert(Tracker &Tracker) { InsertedI->removeFromParent(); } |
| |
| InsertIntoBB::InsertIntoBB(Instruction *InsertedI) : InsertedI(InsertedI) {} |
| |
| #ifndef NDEBUG |
| void InsertIntoBB::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif |
| |
| void CreateAndInsertInst::revert(Tracker &Tracker) { NewI->eraseFromParent(); } |
| |
| #ifndef NDEBUG |
| void CreateAndInsertInst::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif |
| |
| ShuffleVectorSetMask::ShuffleVectorSetMask(ShuffleVectorInst *SVI) |
| : SVI(SVI), PrevMask(SVI->getShuffleMask()) {} |
| |
| void ShuffleVectorSetMask::revert(Tracker &Tracker) { |
| SVI->setShuffleMask(PrevMask); |
| } |
| |
| #ifndef NDEBUG |
| void ShuffleVectorSetMask::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif |
| |
| CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {} |
| |
| void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); } |
| #ifndef NDEBUG |
| void CmpSwapOperands::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif |
| |
| void Tracker::save() { |
| State = TrackerState::Record; |
| #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS) |
| SnapshotChecker.save(); |
| #endif |
| } |
| |
| void Tracker::revert() { |
| assert(State == TrackerState::Record && "Forgot to save()!"); |
| State = TrackerState::Reverting; |
| for (auto &Change : reverse(Changes)) |
| Change->revert(*this); |
| Changes.clear(); |
| #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS) |
| SnapshotChecker.expectNoDiff(); |
| #endif |
| State = TrackerState::Disabled; |
| } |
| |
| void Tracker::accept() { |
| assert(State == TrackerState::Record && "Forgot to save()!"); |
| State = TrackerState::Disabled; |
| for (auto &Change : Changes) |
| Change->accept(); |
| Changes.clear(); |
| } |
| |
| #ifndef NDEBUG |
| void Tracker::dump(raw_ostream &OS) const { |
| for (auto [Idx, ChangePtr] : enumerate(Changes)) { |
| OS << Idx << ". "; |
| ChangePtr->dump(OS); |
| OS << "\n"; |
| } |
| } |
| void Tracker::dump() const { |
| dump(dbgs()); |
| dbgs() << "\n"; |
| } |
| #endif // NDEBUG |