| //===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| //===----------------------------------------------------------------------===// |
| |
| #include "llvm/Transforms/Scalar/LoopTermFold.h" |
| #include "llvm/ADT/Statistic.h" |
| #include "llvm/Analysis/LoopAnalysisManager.h" |
| #include "llvm/Analysis/LoopInfo.h" |
| #include "llvm/Analysis/LoopPass.h" |
| #include "llvm/Analysis/MemorySSA.h" |
| #include "llvm/Analysis/MemorySSAUpdater.h" |
| #include "llvm/Analysis/ScalarEvolution.h" |
| #include "llvm/Analysis/ScalarEvolutionExpressions.h" |
| #include "llvm/Analysis/TargetLibraryInfo.h" |
| #include "llvm/Analysis/TargetTransformInfo.h" |
| #include "llvm/Analysis/ValueTracking.h" |
| #include "llvm/Config/llvm-config.h" |
| #include "llvm/IR/BasicBlock.h" |
| #include "llvm/IR/Dominators.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/InstrTypes.h" |
| #include "llvm/IR/Instruction.h" |
| #include "llvm/IR/Instructions.h" |
| #include "llvm/IR/Type.h" |
| #include "llvm/IR/Value.h" |
| #include "llvm/InitializePasses.h" |
| #include "llvm/Pass.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include "llvm/Transforms/Scalar.h" |
| #include "llvm/Transforms/Utils.h" |
| #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
| #include "llvm/Transforms/Utils/Local.h" |
| #include "llvm/Transforms/Utils/LoopUtils.h" |
| #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" |
| #include <cassert> |
| #include <optional> |
| |
| using namespace llvm; |
| |
| #define DEBUG_TYPE "loop-term-fold" |
| |
| STATISTIC(NumTermFold, |
| "Number of terminating condition fold recognized and performed"); |
| |
| static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>> |
| canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, |
| const LoopInfo &LI, const TargetTransformInfo &TTI) { |
| if (!L->isInnermost()) { |
| LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n"); |
| return std::nullopt; |
| } |
| // Only inspect on simple loop structure |
| if (!L->isLoopSimplifyForm()) { |
| LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n"); |
| return std::nullopt; |
| } |
| |
| if (!SE.hasLoopInvariantBackedgeTakenCount(L)) { |
| LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n"); |
| return std::nullopt; |
| } |
| |
| BasicBlock *LoopLatch = L->getLoopLatch(); |
| BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator()); |
| if (!BI || BI->isUnconditional()) |
| return std::nullopt; |
| auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition()); |
| if (!TermCond) { |
| LLVM_DEBUG( |
| dbgs() << "Cannot fold on branching condition that is not an ICmpInst"); |
| return std::nullopt; |
| } |
| if (!TermCond->hasOneUse()) { |
| LLVM_DEBUG( |
| dbgs() |
| << "Cannot replace terminating condition with more than one use\n"); |
| return std::nullopt; |
| } |
| |
| BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0)); |
| Value *RHS = TermCond->getOperand(1); |
| if (!LHS || !L->isLoopInvariant(RHS)) |
| // We could pattern match the inverse form of the icmp, but that is |
| // non-canonical, and this pass is running *very* late in the pipeline. |
| return std::nullopt; |
| |
| // Find the IV used by the current exit condition. |
| PHINode *ToFold; |
| Value *ToFoldStart, *ToFoldStep; |
| if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep)) |
| return std::nullopt; |
| |
| // Ensure the simple recurrence is a part of the current loop. |
| if (ToFold->getParent() != L->getHeader()) |
| return std::nullopt; |
| |
| // If that IV isn't dead after we rewrite the exit condition in terms of |
| // another IV, there's no point in doing the transform. |
| if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond)) |
| return std::nullopt; |
| |
| // Inserting instructions in the preheader has a runtime cost, scale |
| // the allowed cost with the loops trip count as best we can. |
| const unsigned ExpansionBudget = [&]() { |
| unsigned Budget = 2 * SCEVCheapExpansionBudget; |
| if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L)) |
| return std::min(Budget, SmallTC); |
| if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L)) |
| return std::min(Budget, *SmallTC); |
| // Unknown trip count, assume long running by default. |
| return Budget; |
| }(); |
| |
| const SCEV *BECount = SE.getBackedgeTakenCount(L); |
| const DataLayout &DL = L->getHeader()->getDataLayout(); |
| SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); |
| |
| PHINode *ToHelpFold = nullptr; |
| const SCEV *TermValueS = nullptr; |
| bool MustDropPoison = false; |
| auto InsertPt = L->getLoopPreheader()->getTerminator(); |
| for (PHINode &PN : L->getHeader()->phis()) { |
| if (ToFold == &PN) |
| continue; |
| |
| if (!SE.isSCEVable(PN.getType())) { |
| LLVM_DEBUG(dbgs() << "IV of phi '" << PN |
| << "' is not SCEV-able, not qualified for the " |
| "terminating condition folding.\n"); |
| continue; |
| } |
| const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); |
| // Only speculate on affine AddRec |
| if (!AddRec || !AddRec->isAffine()) { |
| LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN |
| << "' is not an affine add recursion, not qualified " |
| "for the terminating condition folding.\n"); |
| continue; |
| } |
| |
| // Check that we can compute the value of AddRec on the exiting iteration |
| // without soundness problems. evaluateAtIteration internally needs |
| // to multiply the stride of the iteration number - which may wrap around. |
| // The issue here is subtle because computing the result accounting for |
| // wrap is insufficient. In order to use the result in an exit test, we |
| // must also know that AddRec doesn't take the same value on any previous |
| // iteration. The simplest case to consider is a candidate IV which is |
| // narrower than the trip count (and thus original IV), but this can |
| // also happen due to non-unit strides on the candidate IVs. |
| if (!AddRec->hasNoSelfWrap() || |
| !SE.isKnownNonZero(AddRec->getStepRecurrence(SE))) |
| continue; |
| |
| const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE); |
| const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE); |
| if (!Expander.isSafeToExpand(TermValueSLocal)) { |
| LLVM_DEBUG( |
| dbgs() << "Is not safe to expand terminating value for phi node" << PN |
| << "\n"); |
| continue; |
| } |
| |
| if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI, |
| InsertPt)) { |
| LLVM_DEBUG( |
| dbgs() << "Is too expensive to expand terminating value for phi node" |
| << PN << "\n"); |
| continue; |
| } |
| |
| // The candidate IV may have been otherwise dead and poison from the |
| // very first iteration. If we can't disprove that, we can't use the IV. |
| if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) { |
| LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n"); |
| continue; |
| } |
| |
| // The candidate IV may become poison on the last iteration. If this |
| // value is not branched on, this is a well defined program. We're |
| // about to add a new use to this IV, and we have to ensure we don't |
| // insert UB which didn't previously exist. |
| bool MustDropPoisonLocal = false; |
| Instruction *PostIncV = |
| cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch)); |
| if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(), |
| &DT)) { |
| LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN |
| << "\n"); |
| |
| // If this is a complex recurrance with multiple instructions computing |
| // the backedge value, we might need to strip poison flags from all of |
| // them. |
| if (PostIncV->getOperand(0) != &PN) |
| continue; |
| |
| // In order to perform the transform, we need to drop the poison |
| // generating flags on this instruction (if any). |
| MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags(); |
| } |
| |
| // We pick the last legal alternate IV. We could expore choosing an optimal |
| // alternate IV if we had a decent heuristic to do so. |
| ToHelpFold = &PN; |
| TermValueS = TermValueSLocal; |
| MustDropPoison = MustDropPoisonLocal; |
| } |
| |
| LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs() |
| << "Cannot find other AddRec IV to help folding\n";); |
| |
| LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs() |
| << "\nFound loop that can fold terminating condition\n" |
| << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n" |
| << " TermCond: " << *TermCond << "\n" |
| << " BrandInst: " << *BI << "\n" |
| << " ToFold: " << *ToFold << "\n" |
| << " ToHelpFold: " << *ToHelpFold << "\n"); |
| |
| if (!ToFold || !ToHelpFold) |
| return std::nullopt; |
| return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison); |
| } |
| |
| static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT, |
| LoopInfo &LI, const TargetTransformInfo &TTI, |
| TargetLibraryInfo &TLI, MemorySSA *MSSA) { |
| std::unique_ptr<MemorySSAUpdater> MSSAU; |
| if (MSSA) |
| MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); |
| |
| auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI); |
| if (!Opt) |
| return false; |
| |
| auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt; |
| |
| NumTermFold++; |
| |
| BasicBlock *LoopPreheader = L->getLoopPreheader(); |
| BasicBlock *LoopLatch = L->getLoopLatch(); |
| |
| (void)ToFold; |
| LLVM_DEBUG(dbgs() << "To fold phi-node:\n" |
| << *ToFold << "\n" |
| << "New term-cond phi-node:\n" |
| << *ToHelpFold << "\n"); |
| |
| Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader); |
| (void)StartValue; |
| Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch); |
| |
| // See comment in canFoldTermCondOfLoop on why this is sufficient. |
| if (MustDrop) |
| cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags(); |
| |
| // SCEVExpander for both use in preheader and latch |
| const DataLayout &DL = L->getHeader()->getDataLayout(); |
| SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); |
| |
| assert(Expander.isSafeToExpand(TermValueS) && |
| "Terminating value was checked safe in canFoldTerminatingCondition"); |
| |
| // Create new terminating value at loop preheader |
| Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(), |
| LoopPreheader->getTerminator()); |
| |
| LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n" |
| << *StartValue << "\n" |
| << "Terminating value of new term-cond phi-node:\n" |
| << *TermValue << "\n"); |
| |
| // Create new terminating condition at loop latch |
| BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator()); |
| ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition()); |
| IRBuilder<> LatchBuilder(LoopLatch->getTerminator()); |
| Value *NewTermCond = |
| LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue, |
| "lsr_fold_term_cond.replaced_term_cond"); |
| // Swap successors to exit loop body if IV equals to new TermValue |
| if (BI->getSuccessor(0) == L->getHeader()) |
| BI->swapSuccessors(); |
| |
| LLVM_DEBUG(dbgs() << "Old term-cond:\n" |
| << *OldTermCond << "\n" |
| << "New term-cond:\n" |
| << *NewTermCond << "\n"); |
| |
| BI->setCondition(NewTermCond); |
| |
| Expander.clear(); |
| OldTermCond->eraseFromParent(); |
| DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); |
| return true; |
| } |
| |
| namespace { |
| |
| class LoopTermFold : public LoopPass { |
| public: |
| static char ID; // Pass ID, replacement for typeid |
| |
| LoopTermFold(); |
| |
| private: |
| bool runOnLoop(Loop *L, LPPassManager &LPM) override; |
| void getAnalysisUsage(AnalysisUsage &AU) const override; |
| }; |
| |
| } // end anonymous namespace |
| |
| LoopTermFold::LoopTermFold() : LoopPass(ID) { |
| initializeLoopTermFoldPass(*PassRegistry::getPassRegistry()); |
| } |
| |
| void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const { |
| AU.addRequired<LoopInfoWrapperPass>(); |
| AU.addPreserved<LoopInfoWrapperPass>(); |
| AU.addPreservedID(LoopSimplifyID); |
| AU.addRequiredID(LoopSimplifyID); |
| AU.addRequired<DominatorTreeWrapperPass>(); |
| AU.addPreserved<DominatorTreeWrapperPass>(); |
| AU.addRequired<ScalarEvolutionWrapperPass>(); |
| AU.addPreserved<ScalarEvolutionWrapperPass>(); |
| AU.addRequired<TargetLibraryInfoWrapperPass>(); |
| AU.addRequired<TargetTransformInfoWrapperPass>(); |
| AU.addPreserved<MemorySSAWrapperPass>(); |
| } |
| |
| bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { |
| if (skipLoop(L)) |
| return false; |
| |
| auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); |
| auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); |
| auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); |
| const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI( |
| *L->getHeader()->getParent()); |
| auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( |
| *L->getHeader()->getParent()); |
| auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); |
| MemorySSA *MSSA = nullptr; |
| if (MSSAAnalysis) |
| MSSA = &MSSAAnalysis->getMSSA(); |
| return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA); |
| } |
| |
| PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM, |
| LoopStandardAnalysisResults &AR, |
| LPMUpdater &) { |
| if (!RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA)) |
| return PreservedAnalyses::all(); |
| |
| auto PA = getLoopPassPreservedAnalyses(); |
| if (AR.MSSA) |
| PA.preserve<MemorySSAAnalysis>(); |
| return PA; |
| } |
| |
| char LoopTermFold::ID = 0; |
| |
| INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding", |
| false, false) |
| INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) |
| INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) |
| INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) |
| INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) |
| INITIALIZE_PASS_DEPENDENCY(LoopSimplify) |
| INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding", |
| false, false) |
| |
| Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); } |