|  | //===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===// | 
|  | // | 
|  | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|  | // See https://llvm.org/LICENSE.txt for license information. | 
|  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  | // | 
|  | //===----------------------------------------------------------------------===// | 
|  | //===----------------------------------------------------------------------===// | 
|  |  | 
|  | #include "llvm/Transforms/Scalar/LoopTermFold.h" | 
|  | #include "llvm/ADT/Statistic.h" | 
|  | #include "llvm/Analysis/LoopAnalysisManager.h" | 
|  | #include "llvm/Analysis/LoopInfo.h" | 
|  | #include "llvm/Analysis/LoopPass.h" | 
|  | #include "llvm/Analysis/MemorySSA.h" | 
|  | #include "llvm/Analysis/MemorySSAUpdater.h" | 
|  | #include "llvm/Analysis/ScalarEvolution.h" | 
|  | #include "llvm/Analysis/ScalarEvolutionExpressions.h" | 
|  | #include "llvm/Analysis/TargetLibraryInfo.h" | 
|  | #include "llvm/Analysis/TargetTransformInfo.h" | 
|  | #include "llvm/Analysis/ValueTracking.h" | 
|  | #include "llvm/Config/llvm-config.h" | 
|  | #include "llvm/IR/BasicBlock.h" | 
|  | #include "llvm/IR/Dominators.h" | 
|  | #include "llvm/IR/IRBuilder.h" | 
|  | #include "llvm/IR/InstrTypes.h" | 
|  | #include "llvm/IR/Instruction.h" | 
|  | #include "llvm/IR/Instructions.h" | 
|  | #include "llvm/IR/Type.h" | 
|  | #include "llvm/IR/Value.h" | 
|  | #include "llvm/InitializePasses.h" | 
|  | #include "llvm/Pass.h" | 
|  | #include "llvm/Support/Debug.h" | 
|  | #include "llvm/Support/raw_ostream.h" | 
|  | #include "llvm/Transforms/Scalar.h" | 
|  | #include "llvm/Transforms/Utils.h" | 
|  | #include "llvm/Transforms/Utils/BasicBlockUtils.h" | 
|  | #include "llvm/Transforms/Utils/Local.h" | 
|  | #include "llvm/Transforms/Utils/LoopUtils.h" | 
|  | #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" | 
|  | #include <cassert> | 
|  | #include <optional> | 
|  |  | 
|  | using namespace llvm; | 
|  |  | 
|  | #define DEBUG_TYPE "loop-term-fold" | 
|  |  | 
|  | STATISTIC(NumTermFold, | 
|  | "Number of terminating condition fold recognized and performed"); | 
|  |  | 
|  | static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>> | 
|  | canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, | 
|  | const LoopInfo &LI, const TargetTransformInfo &TTI) { | 
|  | if (!L->isInnermost()) { | 
|  | LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n"); | 
|  | return std::nullopt; | 
|  | } | 
|  | // Only inspect on simple loop structure | 
|  | if (!L->isLoopSimplifyForm()) { | 
|  | LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n"); | 
|  | return std::nullopt; | 
|  | } | 
|  |  | 
|  | if (!SE.hasLoopInvariantBackedgeTakenCount(L)) { | 
|  | LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n"); | 
|  | return std::nullopt; | 
|  | } | 
|  |  | 
|  | BasicBlock *LoopLatch = L->getLoopLatch(); | 
|  | BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator()); | 
|  | if (!BI || BI->isUnconditional()) | 
|  | return std::nullopt; | 
|  | auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition()); | 
|  | if (!TermCond) { | 
|  | LLVM_DEBUG( | 
|  | dbgs() << "Cannot fold on branching condition that is not an ICmpInst"); | 
|  | return std::nullopt; | 
|  | } | 
|  | if (!TermCond->hasOneUse()) { | 
|  | LLVM_DEBUG( | 
|  | dbgs() | 
|  | << "Cannot replace terminating condition with more than one use\n"); | 
|  | return std::nullopt; | 
|  | } | 
|  |  | 
|  | BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0)); | 
|  | Value *RHS = TermCond->getOperand(1); | 
|  | if (!LHS || !L->isLoopInvariant(RHS)) | 
|  | // We could pattern match the inverse form of the icmp, but that is | 
|  | // non-canonical, and this pass is running *very* late in the pipeline. | 
|  | return std::nullopt; | 
|  |  | 
|  | // Find the IV used by the current exit condition. | 
|  | PHINode *ToFold; | 
|  | Value *ToFoldStart, *ToFoldStep; | 
|  | if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep)) | 
|  | return std::nullopt; | 
|  |  | 
|  | // Ensure the simple recurrence is a part of the current loop. | 
|  | if (ToFold->getParent() != L->getHeader()) | 
|  | return std::nullopt; | 
|  |  | 
|  | // If that IV isn't dead after we rewrite the exit condition in terms of | 
|  | // another IV, there's no point in doing the transform. | 
|  | if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond)) | 
|  | return std::nullopt; | 
|  |  | 
|  | // Inserting instructions in the preheader has a runtime cost, scale | 
|  | // the allowed cost with the loops trip count as best we can. | 
|  | const unsigned ExpansionBudget = [&]() { | 
|  | unsigned Budget = 2 * SCEVCheapExpansionBudget; | 
|  | if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L)) | 
|  | return std::min(Budget, SmallTC); | 
|  | if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L)) | 
|  | return std::min(Budget, *SmallTC); | 
|  | // Unknown trip count, assume long running by default. | 
|  | return Budget; | 
|  | }(); | 
|  |  | 
|  | const SCEV *BECount = SE.getBackedgeTakenCount(L); | 
|  | const DataLayout &DL = L->getHeader()->getDataLayout(); | 
|  | SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); | 
|  |  | 
|  | PHINode *ToHelpFold = nullptr; | 
|  | const SCEV *TermValueS = nullptr; | 
|  | bool MustDropPoison = false; | 
|  | auto InsertPt = L->getLoopPreheader()->getTerminator(); | 
|  | for (PHINode &PN : L->getHeader()->phis()) { | 
|  | if (ToFold == &PN) | 
|  | continue; | 
|  |  | 
|  | if (!SE.isSCEVable(PN.getType())) { | 
|  | LLVM_DEBUG(dbgs() << "IV of phi '" << PN | 
|  | << "' is not SCEV-able, not qualified for the " | 
|  | "terminating condition folding.\n"); | 
|  | continue; | 
|  | } | 
|  | const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); | 
|  | // Only speculate on affine AddRec | 
|  | if (!AddRec || !AddRec->isAffine()) { | 
|  | LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN | 
|  | << "' is not an affine add recursion, not qualified " | 
|  | "for the terminating condition folding.\n"); | 
|  | continue; | 
|  | } | 
|  |  | 
|  | // Check that we can compute the value of AddRec on the exiting iteration | 
|  | // without soundness problems.  evaluateAtIteration internally needs | 
|  | // to multiply the stride of the iteration number - which may wrap around. | 
|  | // The issue here is subtle because computing the result accounting for | 
|  | // wrap is insufficient. In order to use the result in an exit test, we | 
|  | // must also know that AddRec doesn't take the same value on any previous | 
|  | // iteration. The simplest case to consider is a candidate IV which is | 
|  | // narrower than the trip count (and thus original IV), but this can | 
|  | // also happen due to non-unit strides on the candidate IVs. | 
|  | if (!AddRec->hasNoSelfWrap() || | 
|  | !SE.isKnownNonZero(AddRec->getStepRecurrence(SE))) | 
|  | continue; | 
|  |  | 
|  | const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE); | 
|  | const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE); | 
|  | if (!Expander.isSafeToExpand(TermValueSLocal)) { | 
|  | LLVM_DEBUG( | 
|  | dbgs() << "Is not safe to expand terminating value for phi node" << PN | 
|  | << "\n"); | 
|  | continue; | 
|  | } | 
|  |  | 
|  | if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI, | 
|  | InsertPt)) { | 
|  | LLVM_DEBUG( | 
|  | dbgs() << "Is too expensive to expand terminating value for phi node" | 
|  | << PN << "\n"); | 
|  | continue; | 
|  | } | 
|  |  | 
|  | // The candidate IV may have been otherwise dead and poison from the | 
|  | // very first iteration.  If we can't disprove that, we can't use the IV. | 
|  | if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) { | 
|  | LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n"); | 
|  | continue; | 
|  | } | 
|  |  | 
|  | // The candidate IV may become poison on the last iteration.  If this | 
|  | // value is not branched on, this is a well defined program.  We're | 
|  | // about to add a new use to this IV, and we have to ensure we don't | 
|  | // insert UB which didn't previously exist. | 
|  | bool MustDropPoisonLocal = false; | 
|  | Instruction *PostIncV = | 
|  | cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch)); | 
|  | if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(), | 
|  | &DT)) { | 
|  | LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN | 
|  | << "\n"); | 
|  |  | 
|  | // If this is a complex recurrance with multiple instructions computing | 
|  | // the backedge value, we might need to strip poison flags from all of | 
|  | // them. | 
|  | if (PostIncV->getOperand(0) != &PN) | 
|  | continue; | 
|  |  | 
|  | // In order to perform the transform, we need to drop the poison | 
|  | // generating flags on this instruction (if any). | 
|  | MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags(); | 
|  | } | 
|  |  | 
|  | // We pick the last legal alternate IV.  We could expore choosing an optimal | 
|  | // alternate IV if we had a decent heuristic to do so. | 
|  | ToHelpFold = &PN; | 
|  | TermValueS = TermValueSLocal; | 
|  | MustDropPoison = MustDropPoisonLocal; | 
|  | } | 
|  |  | 
|  | LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs() | 
|  | << "Cannot find other AddRec IV to help folding\n";); | 
|  |  | 
|  | LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs() | 
|  | << "\nFound loop that can fold terminating condition\n" | 
|  | << "  BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n" | 
|  | << "  TermCond: " << *TermCond << "\n" | 
|  | << "  BrandInst: " << *BI << "\n" | 
|  | << "  ToFold: " << *ToFold << "\n" | 
|  | << "  ToHelpFold: " << *ToHelpFold << "\n"); | 
|  |  | 
|  | if (!ToFold || !ToHelpFold) | 
|  | return std::nullopt; | 
|  | return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison); | 
|  | } | 
|  |  | 
|  | static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT, | 
|  | LoopInfo &LI, const TargetTransformInfo &TTI, | 
|  | TargetLibraryInfo &TLI, MemorySSA *MSSA) { | 
|  | std::unique_ptr<MemorySSAUpdater> MSSAU; | 
|  | if (MSSA) | 
|  | MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); | 
|  |  | 
|  | auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI); | 
|  | if (!Opt) | 
|  | return false; | 
|  |  | 
|  | auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt; | 
|  |  | 
|  | NumTermFold++; | 
|  |  | 
|  | BasicBlock *LoopPreheader = L->getLoopPreheader(); | 
|  | BasicBlock *LoopLatch = L->getLoopLatch(); | 
|  |  | 
|  | (void)ToFold; | 
|  | LLVM_DEBUG(dbgs() << "To fold phi-node:\n" | 
|  | << *ToFold << "\n" | 
|  | << "New term-cond phi-node:\n" | 
|  | << *ToHelpFold << "\n"); | 
|  |  | 
|  | Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader); | 
|  | (void)StartValue; | 
|  | Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch); | 
|  |  | 
|  | // See comment in canFoldTermCondOfLoop on why this is sufficient. | 
|  | if (MustDrop) | 
|  | cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags(); | 
|  |  | 
|  | // SCEVExpander for both use in preheader and latch | 
|  | const DataLayout &DL = L->getHeader()->getDataLayout(); | 
|  | SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); | 
|  |  | 
|  | assert(Expander.isSafeToExpand(TermValueS) && | 
|  | "Terminating value was checked safe in canFoldTerminatingCondition"); | 
|  |  | 
|  | // Create new terminating value at loop preheader | 
|  | Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(), | 
|  | LoopPreheader->getTerminator()); | 
|  |  | 
|  | LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n" | 
|  | << *StartValue << "\n" | 
|  | << "Terminating value of new term-cond phi-node:\n" | 
|  | << *TermValue << "\n"); | 
|  |  | 
|  | // Create new terminating condition at loop latch | 
|  | BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator()); | 
|  | ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition()); | 
|  | IRBuilder<> LatchBuilder(LoopLatch->getTerminator()); | 
|  | Value *NewTermCond = | 
|  | LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue, | 
|  | "lsr_fold_term_cond.replaced_term_cond"); | 
|  | // Swap successors to exit loop body if IV equals to new TermValue | 
|  | if (BI->getSuccessor(0) == L->getHeader()) | 
|  | BI->swapSuccessors(); | 
|  |  | 
|  | LLVM_DEBUG(dbgs() << "Old term-cond:\n" | 
|  | << *OldTermCond << "\n" | 
|  | << "New term-cond:\n" | 
|  | << *NewTermCond << "\n"); | 
|  |  | 
|  | BI->setCondition(NewTermCond); | 
|  |  | 
|  | Expander.clear(); | 
|  | OldTermCond->eraseFromParent(); | 
|  | DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); | 
|  | return true; | 
|  | } | 
|  |  | 
|  | namespace { | 
|  |  | 
|  | class LoopTermFold : public LoopPass { | 
|  | public: | 
|  | static char ID; // Pass ID, replacement for typeid | 
|  |  | 
|  | LoopTermFold(); | 
|  |  | 
|  | private: | 
|  | bool runOnLoop(Loop *L, LPPassManager &LPM) override; | 
|  | void getAnalysisUsage(AnalysisUsage &AU) const override; | 
|  | }; | 
|  |  | 
|  | } // end anonymous namespace | 
|  |  | 
|  | LoopTermFold::LoopTermFold() : LoopPass(ID) { | 
|  | initializeLoopTermFoldPass(*PassRegistry::getPassRegistry()); | 
|  | } | 
|  |  | 
|  | void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const { | 
|  | AU.addRequired<LoopInfoWrapperPass>(); | 
|  | AU.addPreserved<LoopInfoWrapperPass>(); | 
|  | AU.addPreservedID(LoopSimplifyID); | 
|  | AU.addRequiredID(LoopSimplifyID); | 
|  | AU.addRequired<DominatorTreeWrapperPass>(); | 
|  | AU.addPreserved<DominatorTreeWrapperPass>(); | 
|  | AU.addRequired<ScalarEvolutionWrapperPass>(); | 
|  | AU.addPreserved<ScalarEvolutionWrapperPass>(); | 
|  | AU.addRequired<TargetLibraryInfoWrapperPass>(); | 
|  | AU.addRequired<TargetTransformInfoWrapperPass>(); | 
|  | AU.addPreserved<MemorySSAWrapperPass>(); | 
|  | } | 
|  |  | 
|  | bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { | 
|  | if (skipLoop(L)) | 
|  | return false; | 
|  |  | 
|  | auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); | 
|  | auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); | 
|  | auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); | 
|  | const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI( | 
|  | *L->getHeader()->getParent()); | 
|  | auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( | 
|  | *L->getHeader()->getParent()); | 
|  | auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); | 
|  | MemorySSA *MSSA = nullptr; | 
|  | if (MSSAAnalysis) | 
|  | MSSA = &MSSAAnalysis->getMSSA(); | 
|  | return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA); | 
|  | } | 
|  |  | 
|  | PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM, | 
|  | LoopStandardAnalysisResults &AR, | 
|  | LPMUpdater &) { | 
|  | if (!RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA)) | 
|  | return PreservedAnalyses::all(); | 
|  |  | 
|  | auto PA = getLoopPassPreservedAnalyses(); | 
|  | if (AR.MSSA) | 
|  | PA.preserve<MemorySSAAnalysis>(); | 
|  | return PA; | 
|  | } | 
|  |  | 
|  | char LoopTermFold::ID = 0; | 
|  |  | 
|  | INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding", | 
|  | false, false) | 
|  | INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) | 
|  | INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) | 
|  | INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) | 
|  | INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) | 
|  | INITIALIZE_PASS_DEPENDENCY(LoopSimplify) | 
|  | INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding", | 
|  | false, false) | 
|  |  | 
|  | Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); } |