blob: 4ae2baca327c7afb935ce606013434371a9181f0 [file] [log] [blame]
#include "llvm/Transforms/Utils/LoopConstrainer.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
using namespace llvm;
/// Metadata kind attached to the latch terminator of every loop cloned by
/// LoopConstrainer::cloneLoop. LoopStructure::parseLoopStructure rejects any
/// loop whose latch terminator carries it, so clones are never re-processed.
/// The pointer itself is const: the tag is a fixed name and must never be
/// re-pointed at runtime.
static const char *const ClonedLoopTag = "loop_constrainer.loop.clone";
#define DEBUG_TYPE "loop-constrainer"
/// Given a loop with an deccreasing induction variable, is it possible to
/// safely calculate the bounds of a new loop using the given Predicate.
static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
const SCEV *Step, ICmpInst::Predicate Pred,
unsigned LatchBrExitIdx, Loop *L,
ScalarEvolution &SE) {
if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
return false;
if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
return false;
assert(SE.isKnownNegative(Step) && "expecting negative step");
LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
bool IsSigned = ICmpInst::isSigned(Pred);
// The predicate that we need to check that the induction variable lies
// within bounds.
ICmpInst::Predicate BoundPred =
IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
auto StartLG = SE.applyLoopGuards(Start, L);
auto BoundLG = SE.applyLoopGuards(BoundSCEV, L);
if (LatchBrExitIdx == 1)
return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG);
assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");
const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth)
: APInt::getMinValue(BitWidth);
const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);
const SCEV *MinusOne =
SE.getMinusSCEV(BoundLG, SE.getOne(BoundLG->getType()));
return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, MinusOne) &&
SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit);
}
/// Given a loop with an increasing induction variable, is it possible to
/// safely calculate the bounds of a new loop using the given Predicate.
static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
const SCEV *Step, ICmpInst::Predicate Pred,
unsigned LatchBrExitIdx, Loop *L,
ScalarEvolution &SE) {
if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
return false;
if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
return false;
LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
bool IsSigned = ICmpInst::isSigned(Pred);
// The predicate that we need to check that the induction variable lies
// within bounds.
ICmpInst::Predicate BoundPred =
IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
auto StartLG = SE.applyLoopGuards(Start, L);
auto BoundLG = SE.applyLoopGuards(BoundSCEV, L);
if (LatchBrExitIdx == 1)
return SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG, BoundLG);
assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth)
: APInt::getMaxValue(BitWidth);
const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);
return (SE.isLoopEntryGuardedByCond(L, BoundPred, StartLG,
SE.getAddExpr(BoundLG, Step)) &&
SE.isLoopEntryGuardedByCond(L, BoundPred, BoundLG, Limit));
}
/// Estimate the maximum number of times the latch of \p L is taken, using the
/// narrowest type available.
///
/// The symbolic-maximum exit count of the latch block itself is preferred.
/// If SCEV cannot compute it, fall back to the symbolic max backedge-taken
/// count of the whole loop. That count may be of a wider type than the latch
/// check, but it is still better than no estimate. The result may itself be
/// SCEVCouldNotCompute.
static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
                                                          const Loop &L) {
  const SCEV *LatchCount =
      SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
  return isa<SCEVCouldNotCompute>(LatchCount)
             ? SE.getSymbolicMaxBackedgeTakenCount(&L)
             : LatchCount;
}
/// Recognize \p L as a loop that LoopConstrainer knows how to transform.
///
/// Each requirement below, when violated, sets \p FailureReason to a static
/// string and returns std::nullopt. All rejections happen before any IR is
/// emitted. The requirements are:
///  - L is in LoopSimplify form and was not produced by
///    LoopConstrainer::cloneLoop (no ClonedLoopTag on the latch terminator).
///  - The latch is exiting, L has a preheader, and the latch ends in a
///    conditional branch on an integer icmp.
///  - A max latch-taken count can be computed.
///  - One icmp operand is an affine add recurrence of L with a constant step.
///    For eq/ne predicates it must also be proven not to wrap (signed).
///  - The predicate can be put in "iv.next < bound" form (increasing IV) or
///    "iv.next > bound" form (decreasing IV). For steps of +1/-1 this may
///    involve rewriting eq/ne.
///  - Unsigned predicates are accepted only if \p AllowUnsignedLatchCond.
///  - isSafeIncreasingBound / isSafeDecreasingBound accept the bounds.
///
/// On success, the function expands the IV's initial value (named
/// "indvar.start") before the preheader's terminator. When needed, it also
/// expands an adjusted exit bound there. It then returns a LoopStructure
/// tagged "main" and sets \p FailureReason to nullptr.
std::optional<LoopStructure>
LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
bool AllowUnsignedLatchCond,
const char *&FailureReason) {
if (!L.isLoopSimplifyForm()) {
FailureReason = "loop not in LoopSimplify form";
return std::nullopt;
}
BasicBlock *Latch = L.getLoopLatch();
assert(Latch && "Simplified loops only have one latch!");
// Loops cloned by LoopConstrainer::cloneLoop carry this tag on their latch
// terminator; never transform them again.
if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
FailureReason = "loop has already been cloned";
return std::nullopt;
}
if (!L.isLoopExiting(Latch)) {
FailureReason = "no loop latch";
return std::nullopt;
}
BasicBlock *Header = L.getHeader();
BasicBlock *Preheader = L.getLoopPreheader();
if (!Preheader) {
FailureReason = "no preheader";
return std::nullopt;
}
BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
if (!LatchBr || LatchBr->isUnconditional()) {
FailureReason = "latch terminator not conditional branch";
return std::nullopt;
}
// Index of the latch successor that leaves the loop. If successor 0 is the
// header, the exit is successor 1; otherwise it is successor 0.
unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;
ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
FailureReason = "latch terminator branch not conditional on integral icmp";
return std::nullopt;
}
// The count is only checked for computability here; its type is recorded
// in the result as ExitCountTy.
const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
FailureReason = "could not compute latch count";
return std::nullopt;
}
assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
ScalarEvolution::LoopInvariant &&
"loop variant exit count doesn't make sense!");
ICmpInst::Predicate Pred = ICI->getPredicate();
Value *LeftValue = ICI->getOperand(0);
const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());
Value *RightValue = ICI->getOperand(1);
const SCEV *RightSCEV = SE.getSCEV(RightValue);
// We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
if (isa<SCEVAddRecExpr>(RightSCEV)) {
std::swap(LeftSCEV, RightSCEV);
std::swap(LeftValue, RightValue);
Pred = ICmpInst::getSwappedPredicate(Pred);
} else {
FailureReason = "no add recurrences in the icmp";
return std::nullopt;
}
}
// Returns true if AR is known not to wrap in the signed sense. That holds
// either because SCEV already marks it NSW, or because sign-extending AR to
// twice its width yields an add recurrence whose start and step are exactly
// the sign-extended start and step of AR.
auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
if (AR->getNoWrapFlags(SCEV::FlagNSW))
return true;
IntegerType *Ty = cast<IntegerType>(AR->getType());
IntegerType *WideTy =
IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
const SCEVAddRecExpr *ExtendAfterOp =
dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
if (ExtendAfterOp) {
const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
const SCEV *ExtendedStep =
SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
if (NoSignedWrap)
return true;
}
// We may have proved this when computing the sign extension above.
return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
};
// `ICI` is interpreted as taking the backedge if the *next* value of the
// induction variable satisfies some constraint.
const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
if (IndVarBase->getLoop() != &L) {
FailureReason = "LHS in cmp is not an AddRec for this loop";
return std::nullopt;
}
if (!IndVarBase->isAffine()) {
FailureReason = "LHS in icmp not induction variable";
return std::nullopt;
}
const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
if (!isa<SCEVConstant>(StepRec)) {
FailureReason = "LHS in icmp not induction variable";
return std::nullopt;
}
ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();
if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
FailureReason = "LHS in icmp needs nsw for equality predicates";
return std::nullopt;
}
assert(!StepCI->isZero() && "Zero step?");
// With a zero step ruled out, a non-negative step is a positive one.
bool IsIncreasing = !StepCI->isNegative();
bool IsSignedPredicate;
// IndVarBase is the value compared in the latch, i.e. the IV *after* the
// increment. Subtracting one step from its start gives the IV's value on
// loop entry.
const SCEV *StartNext = IndVarBase->getStart();
const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
const SCEV *Step = SE.getSCEV(StepCI);
// When non-null, the exit bound is re-expanded from this SCEV in the
// preheader and used as LoopExitAt instead of RightValue.
const SCEV *FixedRightSCEV = nullptr;
// If RightValue resides within loop (but still being loop invariant),
// regenerate it as preheader.
if (auto *I = dyn_cast<Instruction>(RightValue))
if (L.contains(I->getParent()))
FixedRightSCEV = RightSCEV;
if (IsIncreasing) {
bool DecreasedRightValueByOne = false;
if (StepCI->isOne()) {
// Try to turn eq/ne predicates to those we can work with.
if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
// while (++i != len) { while (++i < len) {
// ... ---> ...
// } }
// If both parts are known non-negative, it is profitable to use
// unsigned comparison in increasing loop. This allows us to make the
// comparison check against "RightSCEV + 1" more optimistic.
if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
isKnownNonNegativeInLoop(RightSCEV, &L, SE))
Pred = ICmpInst::ICMP_ULT;
else
Pred = ICmpInst::ICMP_SLT;
else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
// while (true) { while (true) {
// if (++i == len) ---> if (++i > len - 1)
// break; break;
// ... ...
// } }
if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
Pred = ICmpInst::ICMP_UGT;
RightSCEV =
SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
DecreasedRightValueByOne = true;
} else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
Pred = ICmpInst::ICMP_SGT;
RightSCEV =
SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
DecreasedRightValueByOne = true;
}
}
}
// Accept only "stay while iv.next < bound" (LatchBrExitIdx == 1) or
// "exit when iv.next > bound" (LatchBrExitIdx == 0).
bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
bool FoundExpectedPred =
(LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);
if (!FoundExpectedPred) {
FailureReason = "expected icmp slt semantically, found something else";
return std::nullopt;
}
IsSignedPredicate = ICmpInst::isSigned(Pred);
if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
FailureReason = "unsigned latch conditions are explicitly prohibited";
return std::nullopt;
}
if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
LatchBrExitIdx, &L, SE)) {
FailureReason = "Unsafe loop bounds";
return std::nullopt;
}
if (LatchBrExitIdx == 0) {
// We need to increase the right value unless we have already decreased
// it virtually when we replaced EQ with SGT.
if (!DecreasedRightValueByOne)
FixedRightSCEV =
SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
} else {
assert(!DecreasedRightValueByOne &&
"Right value can be decreased only for LatchBrExitIdx == 0!");
}
} else {
bool IncreasedRightValueByOne = false;
if (StepCI->isMinusOne()) {
// Try to turn eq/ne predicates to those we can work with.
if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
// while (--i != len) { while (--i > len) {
// ... ---> ...
// } }
// We intentionally don't turn the predicate into UGT even if we know
// that both operands are non-negative, because it will only pessimize
// our check against "RightSCEV - 1".
Pred = ICmpInst::ICMP_SGT;
else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
// while (true) { while (true) {
// if (--i == len) ---> if (--i < len + 1)
// break; break;
// ... ...
// } }
if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
Pred = ICmpInst::ICMP_ULT;
RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
IncreasedRightValueByOne = true;
} else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
Pred = ICmpInst::ICMP_SLT;
RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
IncreasedRightValueByOne = true;
}
}
}
// Accept only "stay while iv.next > bound" (LatchBrExitIdx == 1) or
// "exit when iv.next < bound" (LatchBrExitIdx == 0).
bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
bool FoundExpectedPred =
(GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);
if (!FoundExpectedPred) {
FailureReason = "expected icmp sgt semantically, found something else";
return std::nullopt;
}
IsSignedPredicate =
Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;
if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
FailureReason = "unsigned latch conditions are explicitly prohibited";
return std::nullopt;
}
if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
LatchBrExitIdx, &L, SE)) {
FailureReason = "Unsafe bounds";
return std::nullopt;
}
if (LatchBrExitIdx == 0) {
// We need to decrease the right value unless we have already increased
// it virtually when we replaced EQ with SLT.
if (!IncreasedRightValueByOne)
FixedRightSCEV =
SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
} else {
assert(!IncreasedRightValueByOne &&
"Right value can be increased only for LatchBrExitIdx == 0!");
}
}
// All checks passed. From here on, code is expanded before the
// preheader's terminator.
BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
assert(!L.contains(LatchExit) && "expected an exit block!");
const DataLayout &DL = Preheader->getDataLayout();
SCEVExpander Expander(SE, DL, "loop-constrainer");
Instruction *Ins = Preheader->getTerminator();
if (FixedRightSCEV)
RightValue =
Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);
Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
IndVarStartV->setName("indvar.start");
LoopStructure Result;
Result.Tag = "main";
Result.Header = Header;
Result.Latch = Latch;
Result.LatchBr = LatchBr;
Result.LatchExit = LatchExit;
Result.LatchBrExitIdx = LatchBrExitIdx;
Result.IndVarStart = IndVarStartV;
Result.IndVarStep = StepCI;
Result.IndVarBase = LeftValue;
Result.IndVarIncreasing = IsIncreasing;
Result.LoopExitAt = RightValue;
Result.IsSignedPredicate = IsSignedPredicate;
Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());
FailureReason = nullptr;
return Result;
}
// Give L a brand-new LoopID that switches off unrolling, vectorization, LICM
// versioning and loop distribution. Whatever loop metadata L carried before is
// replaced wholesale, not merged. Callers must already have decided that
// optimizing L is not worthwhile.
static void DisableAllLoopOptsOnLoop(Loop &L) {
  LLVMContext &Ctx = L.getHeader()->getContext();
  Metadata *FalseMD =
      ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Ctx), 0));

  // A hint consisting only of its name, e.g. !{!"llvm.loop.unroll.disable"}.
  auto NameOnlyHint = [&Ctx](const char *Name) {
    return MDNode::get(Ctx, {MDString::get(Ctx, Name)});
  };
  // A hint whose value is explicitly i1 false.
  auto DisabledHint = [&Ctx, FalseMD](const char *Name) {
    return MDNode::get(Ctx, {MDString::get(Ctx, Name), FalseMD});
  };

  // Operand 0 starts as an empty placeholder and is then patched to refer to
  // the LoopID node itself.
  Metadata *LoopIDOps[] = {MDNode::get(Ctx, {}),
                           NameOnlyHint("llvm.loop.unroll.disable"),
                           DisabledHint("llvm.loop.vectorize.enable"),
                           NameOnlyHint("llvm.loop.licm_versioning.disable"),
                           DisabledHint("llvm.loop.distribute.enable")};
  MDNode *LoopID = MDNode::get(Ctx, LoopIDOps);
  LoopID->replaceOperandWith(0, LoopID);
  L.setLoopID(LoopID);
}
// Capture everything the transformation needs. The enclosing function and
// LLVM context are derived from L's header; the remaining members are copied
// from the corresponding arguments (T is stored as RangeTy, LS as
// MainLoopStructure).
LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
                                 function_ref<void(Loop *, bool)> LPMAddNewLoop,
                                 const LoopStructure &LS, ScalarEvolution &SE,
                                 DominatorTree &DT, Type *T, SubRanges SR)
    : F(*L.getHeader()->getParent()),
      Ctx(L.getHeader()->getContext()),
      SE(SE),
      DT(DT),
      LI(LI),
      LPMAddNewLoop(LPMAddNewLoop),
      OriginalLoop(L),
      RangeTy(T),
      MainLoopStructure(LS),
      SR(SR) {}
/// Clone every block of OriginalLoop into F, appending "." + \p Tag to block
/// names.
///
/// Effects on \p Result and the IR:
///  - Result.Map records old -> new mappings. Result.Blocks lists the new
///    blocks in the same order as OriginalLoop.getBlocks(); the second loop
///    below relies on this pairing by index.
///  - The cloned latch's terminator gets ClonedLoopTag metadata, which makes
///    LoopStructure::parseLoopStructure reject the clone.
///  - Result.Structure is MainLoopStructure with its values mapped through
///    Result.Map, and its Tag is set to \p Tag.
///  - Instructions in the cloned blocks are remapped to refer to the clones.
///  - PHIs in blocks outside the loop that are successors of loop blocks get
///    an extra incoming value from each cloned block. SCEV info for those
///    PHIs is forgotten.
void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
const char *Tag) const {
for (BasicBlock *BB : OriginalLoop.getBlocks()) {
BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
Result.Blocks.push_back(Clone);
Result.Map[BB] = Clone;
}
// Maps V to its clone if one exists, otherwise returns V unchanged (for
// values defined outside the loop).
auto GetClonedValue = [&Result](Value *V) {
assert(V && "null values not in domain!");
auto It = Result.Map.find(V);
if (It == Result.Map.end())
return V;
return static_cast<Value *>(It->second);
};
auto *ClonedLatch =
cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
MDNode::get(Ctx, {}));
Result.Structure = MainLoopStructure.map(GetClonedValue);
Result.Structure.Tag = Tag;
for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
BasicBlock *ClonedBB = Result.Blocks[i];
BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
for (Instruction &I : *ClonedBB)
RemapInstruction(&I, Result.Map,
RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
// Exit blocks will now have one more predecessor and their PHI nodes need
// to be edited to reflect that. No phi nodes need to be introduced because
// the loop is in LCSSA.
for (auto *SBB : successors(OriginalBB)) {
if (OriginalLoop.contains(SBB))
continue; // not an exit block
for (PHINode &PN : SBB->phis()) {
Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
SE.forgetValue(&PN);
}
}
}
}
/// Make loop \p LS stop early once its induction variable reaches
/// \p ExitSubloopAt, continuing at \p ContinuationBlock.
///
/// Two blocks are created right after LS.Latch: "<Tag>.exit.selector" and
/// "<Tag>.pseudo.exit". "Pred" below is < for an increasing IV and > for a
/// decreasing one, signed iff LS.IsSignedPredicate. Operands narrower than
/// RangeTy are sign- or zero-extended accordingly.
///  - \p Preheader: its (required) BranchInst terminator is replaced by a
///    branch to the header if Pred(IndVarStart, ExitSubloopAt) holds,
///    otherwise to pseudo.exit.
///  - Latch: its exit edge now goes to exit.selector, and its condition is
///    replaced by Pred(IndVarBase, ExitSubloopAt), negated when the exit is
///    successor 0. The original latch comparison is no longer used by the
///    branch.
///  - exit.selector: goes to pseudo.exit if Pred(IndVarBase, LS.LoopExitAt)
///    holds (original iterations remain), otherwise to LS.LatchExit.
///  - pseudo.exit: holds a ".copy" PHI per header PHI and an "indvar.end" PHI
///    with the values reached on exit, then branches to ContinuationBlock.
/// PHIs in LS.LatchExit now take their incoming values from exit.selector
/// instead of LS.Latch. Returns the new blocks and PHIs.
LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
BasicBlock *ContinuationBlock) const {
// We start with a loop with a single latch:
//
// +--------------------+
// | |
// | preheader |
// | |
// +--------+-----------+
// | ----------------\
// | / |
// +--------v----v------+ |
// | | |
// | header | |
// | | |
// +--------------------+ |
// |
// ..... |
// |
// +--------------------+ |
// | | |
// | latch >----------/
// | |
// +-------v------------+
// |
// |
// | +--------------------+
// | | |
// +---> original exit |
// | |
// +--------------------+
//
// We change the control flow to look like
//
//
// +--------------------+
// | |
// | preheader >-------------------------+
// | | |
// +--------v-----------+ |
// | /-------------+ |
// | / | |
// +--------v--v--------+ | |
// | | | |
// | header | | +--------+ |
// | | | | | |
// +--------------------+ | | +-----v-----v-----------+
// | | | |
// | | | .pseudo.exit |
// | | | |
// | | +-----------v-----------+
// | | |
// ..... | | |
// | | +--------v-------------+
// +--------------------+ | | | |
// | | | | | ContinuationBlock |
// | latch >------+ | | |
// | | | +----------------------+
// +---------v----------+ |
// | |
// | |
// | +---------------^-----+
// | | |
// +-----> .exit.selector |
// | |
// +----------v----------+
// |
// +--------------------+ |
// | | |
// | original exit <----+
// | |
// +--------------------+
RewrittenRangeInfo RRI;
BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
&F, BBInsertLocation);
RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
BBInsertLocation);
BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
bool Increasing = LS.IndVarIncreasing;
bool IsSignedPredicate = LS.IsSignedPredicate;
IRBuilder<> B(PreheaderJump);
// Widen V to RangeTy (sext for signed predicates, zext otherwise) at the
// builder's current insertion point; no-op if V already has that type.
auto NoopOrExt = [&](Value *V) {
if (V->getType() == RangeTy)
return V;
return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
: B.CreateZExt(V, RangeTy, "wide." + V->getName());
};
// EnterLoopCond - is it okay to start executing this `LS'?
Value *EnterLoopCond = nullptr;
// "Still inside the range" test: IV below the limit when increasing,
// above it when decreasing.
auto Pred =
Increasing
? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
: (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
Value *IndVarStart = NoopOrExt(LS.IndVarStart);
EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);
B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
PreheaderJump->eraseFromParent();
// Retarget the latch exit and rewrite its condition against the new limit.
LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
B.SetInsertPoint(LS.LatchBr);
Value *IndVarBase = NoopOrExt(LS.IndVarBase);
Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);
Value *CondForBranch = LS.LatchBrExitIdx == 1
? TakeBackedgeLoopCond
: B.CreateNot(TakeBackedgeLoopCond);
LS.LatchBr->setCondition(CondForBranch);
B.SetInsertPoint(RRI.ExitSelector);
// IterationsLeft - are there any more iterations left, given the original
// upper bound on the induction variable? If not, we branch to the "real"
// exit.
Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);
BranchInst *BranchToContinuation =
BranchInst::Create(ContinuationBlock, RRI.PseudoExit);
// We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
// each of the PHI nodes in the loop header. This feeds into the initial
// value of the same PHI nodes if/when we continue execution.
for (PHINode &PN : LS.Header->phis()) {
PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
BranchToContinuation->getIterator());
NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
RRI.ExitSelector);
RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
}
// IV value on reaching pseudo.exit: the start value if the loop was skipped,
// otherwise the (widened) incremented IV from the last iteration.
RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
BranchToContinuation->getIterator());
RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);
// The latch exit now has a branch from `RRI.ExitSelector' instead of
// `LS.Latch'. The PHI nodes need to be updated to reflect that.
LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);
return RRI;
}
// Make the header PHIs of LS take, on the edge from ContinuationBlock, the
// values the previous loop produced at its pseudo exit, and start LS's IV
// where that loop stopped (RRI.IndVarEnd).
// Assumes RRI.PHIValuesAtPseudoExit holds one value per header PHI of LS, in
// header order (the order changeIterationSpaceEnd builds it in).
void LoopConstrainer::rewriteIncomingValuesForPHIs(
    LoopStructure &LS, BasicBlock *ContinuationBlock,
    const LoopConstrainer::RewrittenRangeInfo &RRI) const {
  auto ExitValueIt = RRI.PHIValuesAtPseudoExit.begin();
  for (PHINode &HeaderPHI : LS.Header->phis()) {
    HeaderPHI.setIncomingValueForBlock(ContinuationBlock, *ExitValueIt);
    ++ExitValueIt;
  }

  LS.IndVarStart = RRI.IndVarEnd;
}
// Create a block named Tag, placed just before LS.Header in F, that
// unconditionally branches to the header. Header PHIs that received a value
// from OldPreheader now receive it from the new block. OldPreheader's own
// terminator is left unchanged; redirecting it is the caller's job.
BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
                                             BasicBlock *OldPreheader,
                                             const char *Tag) const {
  BasicBlock *NewPreheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
  BranchInst::Create(LS.Header, NewPreheader);

  LS.Header->replacePhiUsesWith(OldPreheader, NewPreheader);
  return NewPreheader;
}
// Register each block in BBs with the loop that encloses OriginalLoop.
// Does nothing when OriginalLoop is top-level.
void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
  if (Loop *Enclosing = OriginalLoop.getParentLoop()) {
    for (BasicBlock *NewBB : BBs)
      Enclosing->addBasicBlockToLoop(NewBB, LI);
  }
}
// Build, in LoopInfo, a loop nest mirroring Original whose blocks are the
// clones recorded in VM. The new loop becomes a child of Parent, or a
// top-level loop when Parent is null, and is reported via LPMAddNewLoop.
// Subloops are cloned recursively and reported with IsSubloop = true.
// Returns the new outermost loop.
Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
                                                 ValueToValueMapTy &VM,
                                                 bool IsSubloop) {
  Loop *Clone = LI.AllocateLoop();

  // Attach the clone to the loop tree and report it before populating it.
  if (Parent == nullptr)
    LI.addTopLevelLoop(Clone);
  else
    Parent->addChildLoop(Clone);
  LPMAddNewLoop(Clone, IsSubloop);

  // Only blocks whose innermost loop is Original belong directly to Clone;
  // blocks of nested loops are handled by the recursive calls below.
  for (BasicBlock *OrigBB : Original->blocks()) {
    if (LI.getLoopFor(OrigBB) != Original)
      continue;
    Clone->addBasicBlockToLoop(cast<BasicBlock>(VM[OrigBB]), LI);
  }

  for (Loop *OrigSubLoop : *Original)
    createClonedLoopStructure(OrigSubLoop, Clone, VM, /* IsSubloop */ true);

  return Clone;
}
/// Split OriginalLoop according to MainLoopStructure and SR.
///
/// Depending on the IV direction, SR.LowLimit / SR.HighLimit request a
/// pre-loop and/or a post-loop. Each requested loop is a clone of
/// OriginalLoop, tagged "preloop" / "postloop"; OriginalLoop itself becomes
/// the main loop. Control then flows
/// preheader -> [pre-loop] -> main loop -> [post-loop]. Each loop hands off
/// to the next through the blocks created by changeIterationSpaceEnd.
///
/// Returns false if an exit limit cannot be proven free of overflow or
/// cannot be safely expanded. This happens before any loop is cloned, but
/// the pre-loop exit limit may already have been expanded into the preheader
/// by then. Otherwise the function does the following and returns true:
///  - rewrites the loops;
///  - recalculates the dominator tree;
///  - registers the clones in LoopInfo (and via LPMAddNewLoop);
///  - restores LCSSA and LoopSimplify form;
///  - disables further loop optimizations on the clones.
bool LoopConstrainer::run() {
BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
assert(Preheader != nullptr && "precondition!");
OriginalPreheader = Preheader;
MainLoopPreheader = Preheader;
bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
bool Increasing = MainLoopStructure.IndVarIncreasing;
IntegerType *IVTy = cast<IntegerType>(RangeTy);
SCEVExpander Expander(SE, F.getDataLayout(), "loop-constrainer");
Instruction *InsertPt = OriginalPreheader->getTerminator();
// It would have been better to make `PreLoop' and `PostLoop'
// `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
// constructor.
ClonedLoop PreLoop, PostLoop;
bool NeedsPreLoop =
Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
bool NeedsPostLoop =
Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();
Value *ExitPreLoopAt = nullptr;
Value *ExitMainLoopAt = nullptr;
// Constant -1 of the IV type. For a decreasing IV the exit limits below are
// the requested limits minus one.
const SCEVConstant *MinusOneS =
cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
// Pre-loop exit limit: LowLimit for an increasing IV. For a decreasing IV it
// is HighLimit - 1, which requires proving HighLimit is not the minimum value.
if (NeedsPreLoop) {
const SCEV *ExitPreLoopAtSCEV = nullptr;
if (Increasing)
ExitPreLoopAtSCEV = *SR.LowLimit;
else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
IsSignedPredicate))
ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
else {
LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
<< "preloop exit limit. HighLimit = "
<< *(*SR.HighLimit) << "\n");
return false;
}
if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
<< " preloop exit limit " << *ExitPreLoopAtSCEV
<< " at block " << InsertPt->getParent()->getName()
<< "\n");
return false;
}
ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
ExitPreLoopAt->setName("exit.preloop.at");
}
// Main-loop exit limit: HighLimit for an increasing IV. For a decreasing IV
// it is LowLimit - 1, which requires proving LowLimit is not the minimum value.
if (NeedsPostLoop) {
const SCEV *ExitMainLoopAtSCEV = nullptr;
if (Increasing)
ExitMainLoopAtSCEV = *SR.HighLimit;
else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
IsSignedPredicate))
ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
else {
LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
<< "mainloop exit limit. LowLimit = "
<< *(*SR.LowLimit) << "\n");
return false;
}
if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
<< " main loop exit limit " << *ExitMainLoopAtSCEV
<< " at block " << InsertPt->getParent()->getName()
<< "\n");
return false;
}
ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
ExitMainLoopAt->setName("exit.mainloop.at");
}
// We clone these ahead of time so that we don't have to deal with changing
// and temporarily invalid IR as we transform the loops.
if (NeedsPreLoop)
cloneLoop(PreLoop, "preloop");
if (NeedsPostLoop)
cloneLoop(PostLoop, "postloop");
RewrittenRangeInfo PreLoopRRI;
// Route the original preheader into the pre-loop, give the main loop its own
// preheader, and make the pre-loop hand off to it at ExitPreLoopAt.
if (NeedsPreLoop) {
Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
PreLoop.Structure.Header);
MainLoopPreheader =
createPreheader(MainLoopStructure, Preheader, "mainloop");
PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
ExitPreLoopAt, MainLoopPreheader);
rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
PreLoopRRI);
}
BasicBlock *PostLoopPreheader = nullptr;
RewrittenRangeInfo PostLoopRRI;
// Give the post-loop its own preheader and make the main loop hand off to it
// at ExitMainLoopAt.
if (NeedsPostLoop) {
PostLoopPreheader =
createPreheader(PostLoop.Structure, Preheader, "postloop");
PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
ExitMainLoopAt, PostLoopPreheader);
rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
PostLoopRRI);
}
// Newly created glue blocks that are not inside any of the three loops; they
// belong to OriginalLoop's parent loop, if there is one.
BasicBlock *NewMainLoopPreheader =
MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit,
PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit,
PostLoopRRI.ExitSelector, NewMainLoopPreheader};
// Some of the above may be nullptr, filter them out before passing to
// addToParentLoopIfNeeded.
auto NewBlocksEnd =
std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);
addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));
DT.recalculate(F);
// We need to first add all the pre and post loop blocks into the loop
// structures (as part of createClonedLoopStructure), and then update the
// LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
// LI when LoopSimplifyForm is generated.
Loop *PreL = nullptr, *PostL = nullptr;
if (!PreLoop.Blocks.empty()) {
PreL = createClonedLoopStructure(&OriginalLoop,
OriginalLoop.getParentLoop(), PreLoop.Map,
/* IsSubLoop */ false);
}
if (!PostLoop.Blocks.empty()) {
PostL =
createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
PostLoop.Map, /* IsSubLoop */ false);
}
// This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
formLCSSARecursively(*L, DT, &LI, &SE);
simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
// Pre/post loops are slow paths, we do not need to perform any loop
// optimizations on them.
if (!IsOriginalLoop)
DisableAllLoopOptsOnLoop(*L);
};
if (PreL)
CanonicalizeLoop(PreL, false);
if (PostL)
CanonicalizeLoop(PostL, false);
CanonicalizeLoop(&OriginalLoop, true);
/// At this point:
/// - We've broken a "main loop" out of the loop in a way that the "main loop"
/// runs with the induction variable in a subset of [Begin, End).
/// - There is no overflow when computing "main loop" exit limit.
/// - Max latch taken count of the loop is limited.
/// It guarantees that induction variable will not overflow iterating in the
/// "main loop".
if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
if (IsSignedPredicate)
cast<BinaryOperator>(MainLoopStructure.IndVarBase)
->setHasNoSignedWrap(true);
/// TODO: support unsigned predicate.
/// To add NUW flag we need to prove that both operands of BO are
/// non-negative. E.g:
/// ...
/// %iv.next = add nsw i32 %iv, -1
/// %cmp = icmp ult i32 %iv.next, %n
/// br i1 %cmp, label %loopexit, label %loop
///
/// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
/// overflow, therefore NUW flag is not legal here.
return true;
}