blob: d11af1e10e38fa18f990dbea1490386852e09408 [file] [log] [blame]
//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
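//
// This pass rewrites the latch exit test of an innermost loop when that test
// compares an induction variable that is otherwise (almost) dead. The new test
// is an equality comparison of a different induction variable against its
// final value, which is computed in the preheader. The original IV then becomes
// dead, and its header phi can be deleted.
//
//===----------------------------------------------------------------------===//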
#include "llvm/Transforms/Scalar/LoopTermFold.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <cassert>
#include <optional>
using namespace llvm;
// Tag used by LLVM_DEBUG output and STATISTIC counters in this file
// (e.g. -debug-only=loop-term-fold).
#define DEBUG_TYPE "loop-term-fold"
// Incremented once for each loop whose exit condition RunTermFold rewrites.
STATISTIC(NumTermFold,
"Number of terminating condition fold recognized and performed");
/// Check whether the exit test of loop \p L can be rewritten in terms of a
/// different induction variable, so that the IV it currently tests dies.
///
/// The following must all hold:
///   - \p L is innermost and in loop-simplify form.
///   - \p L has a loop-invariant backedge-taken count.
///   - The latch ends in a conditional branch on a single-use `icmp`.
///   - The `icmp` LHS is a simple recurrence of a header phi, and its RHS is
///     loop invariant.
///   - That phi is "almost dead" (see isAlmostDeadIV).
///   - Some other affine header IV has a value on the exiting iteration that
///     can be expanded in the preheader safely and cheaply.
///
/// \param DT  used for the poison-safety checks on candidate IVs.
/// \param LI  currently unused.
/// \param TTI used to cost the preheader expansion of the terminating value.
/// \returns std::nullopt if the fold is not possible. Otherwise returns a
///   tuple of:
///   - the IV to fold away;
///   - the alternate IV;
///   - the SCEV of the alternate IV's post-increment value on the exiting
///     iteration;
///   - whether the alternate IV's increment must have its poison-generating
///     flags dropped before the new use is added.
static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
                      const LoopInfo &LI, const TargetTransformInfo &TTI) {
  if (!L->isInnermost()) {
    LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
    return std::nullopt;
  }
  // Only inspect on simple loop structure
  if (!L->isLoopSimplifyForm()) {
    LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
    return std::nullopt;
  }
  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
    LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
    return std::nullopt;
  }
  BasicBlock *LoopLatch = L->getLoopLatch();
  BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
  if (!BI || BI->isUnconditional())
    return std::nullopt;
  auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
  if (!TermCond) {
    LLVM_DEBUG(dbgs() << "Cannot fold on branching condition that is not an "
                         "ICmpInst\n");
    return std::nullopt;
  }
  if (!TermCond->hasOneUse()) {
    LLVM_DEBUG(
        dbgs()
        << "Cannot replace terminating condition with more than one use\n");
    return std::nullopt;
  }
  BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
  Value *RHS = TermCond->getOperand(1);
  if (!LHS || !L->isLoopInvariant(RHS))
    // We could pattern match the inverse form of the icmp, but that is
    // non-canonical, and this pass is running *very* late in the pipeline.
    return std::nullopt;

  // Find the IV used by the current exit condition. On success
  // matchSimpleRecurrence sets ToFold, so it is non-null from here on.
  PHINode *ToFold;
  Value *ToFoldStart, *ToFoldStep;
  if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
    return std::nullopt;
  // Ensure the simple recurrence is a part of the current loop.
  if (ToFold->getParent() != L->getHeader())
    return std::nullopt;
  // If that IV isn't dead after we rewrite the exit condition in terms of
  // another IV, there's no point in doing the transform.
  if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
    return std::nullopt;

  // Inserting instructions in the preheader has a runtime cost, so scale
  // the allowed cost with the loop's trip count as best we can.
  const unsigned ExpansionBudget = [&]() {
    unsigned Budget = 2 * SCEVCheapExpansionBudget;
    if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
      return std::min(Budget, SmallTC);
    if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
      return std::min(Budget, *SmallTC);
    // Unknown trip count, assume long running by default.
    return Budget;
  }();

  const SCEV *BECount = SE.getBackedgeTakenCount(L);
  const DataLayout &DL = L->getHeader()->getDataLayout();
  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");

  PHINode *ToHelpFold = nullptr;
  const SCEV *TermValueS = nullptr;
  bool MustDropPoison = false;
  auto InsertPt = L->getLoopPreheader()->getTerminator();
  for (PHINode &PN : L->getHeader()->phis()) {
    if (ToFold == &PN)
      continue;

    if (!SE.isSCEVable(PN.getType())) {
      LLVM_DEBUG(dbgs() << "IV of phi '" << PN
                        << "' is not SCEV-able, not qualified for the "
                           "terminating condition folding.\n");
      continue;
    }
    const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
    // Only speculate on affine AddRec
    if (!AddRec || !AddRec->isAffine()) {
      LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
                        << "' is not an affine add recursion, not qualified "
                           "for the terminating condition folding.\n");
      continue;
    }

    // Check that we can compute the value of AddRec on the exiting iteration
    // without soundness problems. evaluateAtIteration internally needs
    // to multiply the stride by the iteration number - which may wrap around.
    // The issue here is subtle because computing the result accounting for
    // wrap is insufficient. In order to use the result in an exit test, we
    // must also know that AddRec doesn't take the same value on any previous
    // iteration. The simplest case to consider is a candidate IV which is
    // narrower than the trip count (and thus original IV), but this can
    // also happen due to non-unit strides on the candidate IVs.
    if (!AddRec->hasNoSelfWrap() ||
        !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
      continue;

    const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
    const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
    if (!Expander.isSafeToExpand(TermValueSLocal)) {
      LLVM_DEBUG(
          dbgs() << "Is not safe to expand terminating value for phi node" << PN
                 << "\n");
      continue;
    }

    if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI,
                                     InsertPt)) {
      LLVM_DEBUG(
          dbgs() << "Is too expensive to expand terminating value for phi node"
                 << PN << "\n");
      continue;
    }

    // The candidate IV may have been otherwise dead and poison from the
    // very first iteration. If we can't disprove that, we can't use the IV.
    if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
      LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");
      continue;
    }

    // The candidate IV may become poison on the last iteration. If this
    // value is not branched on, this is a well defined program. We're
    // about to add a new use to this IV, and we have to ensure we don't
    // insert UB which didn't previously exist.
    bool MustDropPoisonLocal = false;
    Instruction *PostIncV =
        cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
    if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
                                       &DT)) {
      LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN
                        << "\n");

      // If this is a complex recurrence with multiple instructions computing
      // the backedge value, we might need to strip poison flags from all of
      // them.
      if (PostIncV->getOperand(0) != &PN)
        continue;

      // In order to perform the transform, we need to drop the poison
      // generating flags on this instruction (if any).
      MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
    }

    // We pick the last legal alternate IV. We could explore choosing an
    // optimal alternate IV if we had a decent heuristic to do so.
    ToHelpFold = &PN;
    TermValueS = TermValueSLocal;
    MustDropPoison = MustDropPoisonLocal;
  }

  if (!ToHelpFold) {
    LLVM_DEBUG(dbgs() << "Cannot find other AddRec IV to help folding\n");
    return std::nullopt;
  }

  LLVM_DEBUG(dbgs() << "\nFound loop that can fold terminating condition\n"
                    << "  BECount (SCEV): " << *BECount << "\n"
                    << "  TermCond: " << *TermCond << "\n"
                    << "  BranchInst: " << *BI << "\n"
                    << "  ToFold: " << *ToFold << "\n"
                    << "  ToHelpFold: " << *ToHelpFold << "\n");

  return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
}
/// Fold the exit condition of \p L if canFoldTermCondOfLoop allows it.
///
/// Transformation steps:
///   - Expand the alternate IV's terminating value in the preheader.
///   - Replace the latch branch condition with
///     `icmp eq <alternate IV latch value>, <terminating value>`.
///   - Order the branch successors so that equality leaves the loop.
///   - Erase the old compare and delete header phis that became dead.
///
/// \returns true iff the IR was changed.
static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
                        LoopInfo &LI, const TargetTransformInfo &TTI,
                        TargetLibraryInfo &TLI, MemorySSA *MSSA) {
  auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
  if (!Opt)
    return false;

  auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;

  NumTermFold++;

  // Only allocate the MemorySSA updater once we know the IR will change;
  // the legality check above fails for most loops.
  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);

  BasicBlock *LoopPreheader = L->getLoopPreheader();
  BasicBlock *LoopLatch = L->getLoopLatch();

  (void)ToFold; // Only referenced in debug output.
  LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
                    << *ToFold << "\n"
                    << "New term-cond phi-node:\n"
                    << *ToHelpFold << "\n");

  Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
  (void)StartValue; // Only referenced in debug output.
  Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);

  // See comment in canFoldTermCondOfLoop on why this is sufficient.
  if (MustDrop)
    cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();

  // SCEVExpander for both use in preheader and latch
  const DataLayout &DL = L->getHeader()->getDataLayout();
  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");

  assert(Expander.isSafeToExpand(TermValueS) &&
         "Terminating value was checked safe in canFoldTerminatingCondition");

  // Create new terminating value at loop preheader
  Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
                                            LoopPreheader->getTerminator());

  LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
                    << *StartValue << "\n"
                    << "Terminating value of new term-cond phi-node:\n"
                    << *TermValue << "\n");

  // Create new terminating condition at loop latch
  BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
  ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
  IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
  Value *NewTermCond =
      LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
                              "lsr_fold_term_cond.replaced_term_cond");
  // Swap successors to exit loop body if IV equals to new TermValue
  if (BI->getSuccessor(0) == L->getHeader())
    BI->swapSuccessors();

  LLVM_DEBUG(dbgs() << "Old term-cond:\n"
                    << *OldTermCond << "\n"
                    << "New term-cond:\n"
                    << *NewTermCond << "\n");

  BI->setCondition(NewTermCond);

  Expander.clear();
  // The branch was the compare's only user (checked in canFoldTermCondOfLoop),
  // so it is unused now.
  OldTermCond->eraseFromParent();
  DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());

  return true;
}
namespace {
/// Legacy pass-manager wrapper that runs RunTermFold on each loop.
class LoopTermFold : public LoopPass {
  // Members before the first access specifier are private by default.
  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

public:
  static char ID; // Pass identification, replacement for typeid.

  LoopTermFold();
};
} // end anonymous namespace
/// Construct the legacy pass and register it with the global pass registry.
LoopTermFold::LoopTermFold() : LoopPass(ID) {
  PassRegistry &Registry = *PassRegistry::getPassRegistry();
  initializeLoopTermFoldPass(Registry);
}
/// Declare the analyses this pass needs and the ones it keeps valid.
void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
  // Analyses that must be available before runOnLoop.
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addRequiredID(LoopSimplifyID);
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequired<ScalarEvolutionWrapperPass>();
  AU.addRequired<TargetLibraryInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();

  // Analyses that remain valid after this pass runs.
  AU.addPreserved<LoopInfoWrapperPass>();
  AU.addPreservedID(LoopSimplifyID);
  AU.addPreserved<DominatorTreeWrapperPass>();
  AU.addPreserved<ScalarEvolutionWrapperPass>();
  AU.addPreserved<MemorySSAWrapperPass>();
}
/// Legacy pass-manager entry point. Gathers the analyses RunTermFold needs for
/// \p L and runs it, unless the loop is skipped (e.g. optnone / opt-bisect).
bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
  if (skipLoop(L))
    return false;

  auto &F = *L->getHeader()->getParent();

  ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  const TargetTransformInfo &TTI =
      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  TargetLibraryInfo &TLI =
      getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

  // MemorySSA is optional: use it only if it is already available.
  MemorySSA *MSSA = nullptr;
  if (auto *MSSAWrapper = getAnalysisIfAvailable<MemorySSAWrapperPass>())
    MSSA = &MSSAWrapper->getMSSA();

  return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);
}
/// New pass-manager entry point.
///
/// If nothing changed, every analysis is preserved. Otherwise the standard
/// loop-pass set is preserved, plus MemorySSA when it was provided.
PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM,
                                        LoopStandardAnalysisResults &AR,
                                        LPMUpdater &) {
  const bool Changed =
      RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA);
  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA = getLoopPassPreservedAnalyses();
  if (AR.MSSA != nullptr)
    PA.preserve<MemorySSAAnalysis>();
  return PA;
}
char LoopTermFold::ID = 0;

// Register the legacy pass. Every analysis declared as required in
// LoopTermFold::getAnalysisUsage is listed as a dependency, so it is
// initialized before this pass. TargetLibraryInfoWrapperPass was previously
// missing from this list.
INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
                    false, false)

/// Factory for the legacy-pass-manager version of this pass.
Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); }