[LoopPred] Stop passing around builders [NFC]

This is a preparatory patch for D60093. This patch itself is NFC, but while preparing this I noticed and committed a small hoisting change in rL358419.

The basic structure of the new scheme is that we pass around the guard ("the using instruction"), and select an optimal insert point by examining operands at each construction point. This seems conceptually a bit cleaner to start with as it isolates the knowledge about insertion safety at the actual insertion point.

Note that the non-hoisting path is not actually used at the moment. That's not exercised until D60093 is rebased on this one.

Differential Revision: https://reviews.llvm.org/D60718



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@358434 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/LoopPredication.cpp b/lib/Transforms/Scalar/LoopPredication.cpp
index 3b27364..de148ff 100644
--- a/lib/Transforms/Scalar/LoopPredication.cpp
+++ b/lib/Transforms/Scalar/LoopPredication.cpp
@@ -269,24 +269,29 @@
   /// trivial result would be the at the User itself, but we try to return a
   /// loop invariant location if possible.  
   Instruction *findInsertPt(Instruction *User, ArrayRef<Value*> Ops);
+  /// Same as above, *except* that this uses the SCEV definition of invariant
+  /// which is that an expression *can be made* invariant via SCEVExpander.
+  /// Thus, this version is only suitable for finding an insert point to be be
+  /// passed to SCEVExpander!
+  Instruction *findInsertPt(Instruction *User, ArrayRef<const SCEV*> Ops);
 
   bool CanExpand(const SCEV* S);
-  Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder,
+  Value *expandCheck(SCEVExpander &Expander, Instruction *Guard,
                      ICmpInst::Predicate Pred, const SCEV *LHS,
                      const SCEV *RHS);
 
   Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander,
-                                        IRBuilder<> &Builder);
+                                        Instruction *Guard);
   Optional<Value *> widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck,
                                                         LoopICmp RangeCheck,
                                                         SCEVExpander &Expander,
-                                                        IRBuilder<> &Builder);
+                                                        Instruction *Guard);
   Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck,
                                                         LoopICmp RangeCheck,
                                                         SCEVExpander &Expander,
-                                                        IRBuilder<> &Builder);
+                                                        Instruction *Guard);
   unsigned collectChecks(SmallVectorImpl<Value *> &Checks, Value *Condition,
-                         SCEVExpander &Expander, IRBuilder<> &Builder);
+                         SCEVExpander &Expander, Instruction *Guard);
   bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander);
   bool widenWidenableBranchGuardConditions(BranchInst *Guard, SCEVExpander &Expander);
   // If the loop always exits through another block in the loop, we should not
@@ -394,21 +399,24 @@
 }
 
 Value *LoopPredication::expandCheck(SCEVExpander &Expander,
-                                    IRBuilder<> &Builder,
+                                    Instruction *Guard, 
                                     ICmpInst::Predicate Pred, const SCEV *LHS,
                                     const SCEV *RHS) {
   Type *Ty = LHS->getType();
   assert(Ty == RHS->getType() && "expandCheck operands have different types?");
 
-  if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
-    return Builder.getTrue();
-  if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred),
-                                   LHS, RHS))
-    return Builder.getFalse();
+  if (SE->isLoopInvariant(LHS, L) && SE->isLoopInvariant(RHS, L)) {
+    IRBuilder<> Builder(Guard);
+    if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
+      return Builder.getTrue();
+    if (SE->isLoopEntryGuardedByCond(L, ICmpInst::getInversePredicate(Pred),
+                                     LHS, RHS))
+      return Builder.getFalse();
+  }
 
-  Instruction *InsertAt = &*Builder.GetInsertPoint();
-  Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt);
-  Value *RHSV = Expander.expandCodeFor(RHS, Ty, InsertAt);
+  Value *LHSV = Expander.expandCodeFor(LHS, Ty, findInsertPt(Guard, {LHS}));
+  Value *RHSV = Expander.expandCodeFor(RHS, Ty, findInsertPt(Guard, {RHS}));
+  IRBuilder<> Builder(findInsertPt(Guard, {LHSV, RHSV}));
   return Builder.CreateICmp(Pred, LHSV, RHSV);
 }
 
@@ -452,13 +460,22 @@
   return Preheader->getTerminator();
 }
 
+Instruction *LoopPredication::findInsertPt(Instruction *Use,
+                                           ArrayRef<const SCEV*> Ops) {
+  for (const SCEV *Op : Ops)
+    if (!SE->isLoopInvariant(Op, L))
+      return Use;
+  return Preheader->getTerminator();
+}
+
+
 bool LoopPredication::CanExpand(const SCEV* S) {
   return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE);
 }
 
 Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
     LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
-    SCEVExpander &Expander, IRBuilder<> &Builder) {
+    SCEVExpander &Expander, Instruction *Guard) {
   auto *Ty = RangeCheck.IV->getType();
   // Generate the widened condition for the forward loop:
   //   guardStart u< guardLimit &&
@@ -488,15 +505,16 @@
   LLVM_DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");
  
   auto *LimitCheck =
-      expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, RHS);
-  auto *FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck.Pred,
+      expandCheck(Expander, Guard, LimitCheckPred, LatchLimit, RHS);
+  auto *FirstIterationCheck = expandCheck(Expander, Guard, RangeCheck.Pred,
                                           GuardStart, GuardLimit);
+  IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
   return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
 }
 
 Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
     LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
-    SCEVExpander &Expander, IRBuilder<> &Builder) {
+    SCEVExpander &Expander, Instruction *Guard) {
   auto *Ty = RangeCheck.IV->getType();
   const SCEV *GuardStart = RangeCheck.IV->getStart();
   const SCEV *GuardLimit = RangeCheck.Limit;
@@ -522,10 +540,12 @@
   // See the header comment for reasoning of the checks.
   auto LimitCheckPred =
       ICmpInst::getFlippedStrictnessPredicate(LatchCheck.Pred);
-  auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT,
+  auto *FirstIterationCheck = expandCheck(Expander, Guard,
+                                          ICmpInst::ICMP_ULT,
                                           GuardStart, GuardLimit);
-  auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit,
+  auto *LimitCheck = expandCheck(Expander, Guard, LimitCheckPred, LatchLimit,
                                  SE->getOne(Ty));
+  IRBuilder<> Builder(findInsertPt(Guard, {FirstIterationCheck, LimitCheck}));
   return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
 }
 
@@ -534,7 +554,7 @@
 /// returns None.
 Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
                                                        SCEVExpander &Expander,
-                                                       IRBuilder<> &Builder) {
+                                                       Instruction *Guard) {
   LLVM_DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
   LLVM_DEBUG(ICI->dump());
 
@@ -588,18 +608,18 @@
 
   if (Step->isOne())
     return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
-                                               Expander, Builder);
+                                               Expander, Guard);
   else {
     assert(Step->isAllOnesValue() && "Step should be -1!");
     return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
-                                               Expander, Builder);
+                                               Expander, Guard);
   }
 }
 
 unsigned LoopPredication::collectChecks(SmallVectorImpl<Value *> &Checks,
                                         Value *Condition,
                                         SCEVExpander &Expander,
-                                        IRBuilder<> &Builder) {
+                                        Instruction *Guard) {
   unsigned NumWidened = 0;
   // The guard condition is expected to be in form of:
   //   cond1 && cond2 && cond3 ...
@@ -631,7 +651,7 @@
 
     if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
       if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander,
-                                                   Builder)) {
+                                                   Guard)) {
         Checks.push_back(NewRangeCheck.getValue());
         NumWidened++;
         continue;
@@ -657,16 +677,15 @@
 
   TotalConsidered++;
   SmallVector<Value *, 4> Checks;
-  IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator()));
   unsigned NumWidened = collectChecks(Checks, Guard->getOperand(0), Expander,
-                                      Builder);
+                                      Guard);
   if (NumWidened == 0)
     return false;
 
   TotalWidened += NumWidened;
 
   // Emit the new guard condition
-  Builder.SetInsertPoint(findInsertPt(Guard, Checks));
+  IRBuilder<> Builder(findInsertPt(Guard, Checks));
   Value *LastCheck = nullptr;
   for (auto *Check : Checks)
     if (!LastCheck)
@@ -689,16 +708,15 @@
 
   TotalConsidered++;
   SmallVector<Value *, 4> Checks;
-  IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator()));
   unsigned NumWidened = collectChecks(Checks, BI->getCondition(),
-                                      Expander, Builder);
+                                      Expander, BI);
   if (NumWidened == 0)
     return false;
 
   TotalWidened += NumWidened;
 
   // Emit the new guard condition
-  Builder.SetInsertPoint(findInsertPt(BI, Checks));
+  IRBuilder<> Builder(findInsertPt(BI, Checks));
   Value *LastCheck = nullptr;
   for (auto *Check : Checks)
     if (!LastCheck)