[ValueTracking] don't recursively compute known bits using multiple llvm.assumes

This is an alternative to D99759 to avoid the compile-time explosion seen in:
https://llvm.org/PR49785

Another potential solution would make the exclusion logic stronger to avoid
blowing up, but note that we reduced the complexity of the exclusion mechanism
in D16204 because it was too costly.

So I'm questioning the need for recursion/exclusion entirely - what is the
optimization value vs. cost of recursively computing known bits based on
assumptions?
This was built into the implementation from the start with 60db058,
and we have kept adding code/cost to deal with that capability.

By clearing the query's AssumptionCache inside computeKnownBitsFromAssume(),
this patch retains all existing assume functionality except refining known
bits based on even more assumptions.

We have 1 regression test that shows a difference in optimization power.

Differential Revision: https://reviews.llvm.org/D100573

GitOrigin-RevId: bb907b26e2bf2a0d5ea3fcb13ba0ccef3420a38c
diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp
index d8963c6..46aa84f 100644
--- a/lib/Analysis/ValueTracking.cpp
+++ b/lib/Analysis/ValueTracking.cpp
@@ -107,40 +107,13 @@
   // provide it currently.
   OptimizationRemarkEmitter *ORE;
 
-  /// Set of assumptions that should be excluded from further queries.
-  /// This is because of the potential for mutual recursion to cause
-  /// computeKnownBits to repeatedly visit the same assume intrinsic. The
-  /// classic case of this is assume(x = y), which will attempt to determine
-  /// bits in x from bits in y, which will attempt to determine bits in y from
-  /// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call
-  /// isKnownNonZero, which calls computeKnownBits and isKnownToBeAPowerOfTwo
-  /// (all of which can call computeKnownBits), and so on.
-  std::array<const Value *, MaxAnalysisRecursionDepth> Excluded;
-
   /// If true, it is safe to use metadata during simplification.
   InstrInfoQuery IIQ;
 
-  unsigned NumExcluded = 0;
-
   Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI,
         const DominatorTree *DT, bool UseInstrInfo,
         OptimizationRemarkEmitter *ORE = nullptr)
       : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), IIQ(UseInstrInfo) {}
-
-  Query(const Query &Q, const Value *NewExcl)
-      : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), IIQ(Q.IIQ),
-        NumExcluded(Q.NumExcluded) {
-    Excluded = Q.Excluded;
-    Excluded[NumExcluded++] = NewExcl;
-    assert(NumExcluded <= Excluded.size());
-  }
-
-  bool isExcluded(const Value *Value) const {
-    if (NumExcluded == 0)
-      return false;
-    auto End = Excluded.begin() + NumExcluded;
-    return std::find(Excluded.begin(), End, Value) != End;
-  }
 };
 
 } // end anonymous namespace
@@ -632,8 +605,6 @@
     CallInst *I = cast<CallInst>(AssumeVH);
     assert(I->getFunction() == Q.CxtI->getFunction() &&
            "Got assumption for the wrong function!");
-    if (Q.isExcluded(I))
-      continue;
 
     // Warning: This loop can end up being somewhat performance sensitive.
     // We're running this loop for once for each value queried resulting in a
@@ -681,8 +652,6 @@
     CallInst *I = cast<CallInst>(AssumeVH);
     assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
            "Got assumption for the wrong function!");
-    if (Q.isExcluded(I))
-      continue;
 
     // Warning: This loop can end up being somewhat performance sensitive.
     // We're running this loop for once for each value queried resulting in a
@@ -713,6 +682,15 @@
     if (!Cmp)
       continue;
 
+    // We are attempting to compute known bits for the operands of an assume.
+    // Do not try to use other assumptions for those recursive calls because
+    // that can lead to mutual recursion and a compile-time explosion.
+    // An example of the mutual recursion: computeKnownBits can call
+    // isKnownNonZero which calls computeKnownBitsFromAssume (this function)
+    // and so on.
+    Query QueryNoAC = Q;
+    QueryNoAC.AC = nullptr;
+
     // Note that ptrtoint may change the bitwidth.
     Value *A, *B;
     auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
@@ -727,7 +705,7 @@
       if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         Known.Zero |= RHSKnown.Zero;
         Known.One  |= RHSKnown.One;
       // assume(v & b = a)
@@ -735,9 +713,9 @@
                        m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits MaskKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in the mask that are known to be one, we can propagate
         // known bits from the RHS to V.
@@ -748,9 +726,9 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits MaskKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in the mask that are known to be one, we can propagate
         // inverted known bits from the RHS to V.
@@ -761,9 +739,9 @@
                        m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate known
         // bits from the RHS to V.
@@ -774,9 +752,9 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate
         // inverted known bits from the RHS to V.
@@ -787,9 +765,9 @@
                        m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate known
         // bits from the RHS to V. For those bits in B that are known to be one,
@@ -803,9 +781,9 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate
         // inverted known bits from the RHS to V. For those bits in B that are
@@ -819,7 +797,7 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in RHS that are known, we can propagate them to known
         // bits in V shifted to the right by C.
@@ -832,7 +810,7 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them inverted
         // to known bits in V shifted to the right by C.
         RHSKnown.One.lshrInPlace(C);
@@ -844,7 +822,7 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them to known
         // bits in V shifted to the right by C.
         Known.Zero |= RHSKnown.Zero << C;
@@ -854,7 +832,7 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them inverted
         // to known bits in V shifted to the right by C.
         Known.Zero |= RHSKnown.One  << C;
@@ -866,7 +844,7 @@
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isNonNegative()) {
           // We know that the sign bit is zero.
@@ -879,7 +857,7 @@
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) {
           // We know that the sign bit is zero.
@@ -892,7 +870,7 @@
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isNegative()) {
           // We know that the sign bit is one.
@@ -905,7 +883,7 @@
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isZero() || RHSKnown.isNegative()) {
           // We know that the sign bit is one.
@@ -918,7 +896,7 @@
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // Whatever high bits in c are zero are known to be zero.
         Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
@@ -929,7 +907,7 @@
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // If the RHS is known zero, then this assumption must be wrong (nothing
         // is unsigned less than zero). Signal a conflict and get out of here.
@@ -941,7 +919,7 @@
 
         // Whatever high bits in c are zero are known to be zero (if c is a power
         // of 2, then one more).
-        if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I)))
+        if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, QueryNoAC))
           Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros() + 1);
         else
           Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
diff --git a/test/Transforms/InstCombine/assume.ll b/test/Transforms/InstCombine/assume.ll
index d25f1d6..403d82c 100644
--- a/test/Transforms/InstCombine/assume.ll
+++ b/test/Transforms/InstCombine/assume.ll
@@ -175,15 +175,20 @@
   ret i32 %and1
 }
 
-define i32 @bar4(i32 %a, i32 %b) {
-; CHECK-LABEL: @bar4(
+; If we allow recursive known bits queries based on
+; assumptions, we could do better here:
+; a == b and a & 7 == 1, so b & 7 == 1, so b & 3 == 1, so return 1.
+
+define i32 @known_bits_recursion_via_assumes(i32 %a, i32 %b) {
+; CHECK-LABEL: @known_bits_recursion_via_assumes(
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AND1:%.*]] = and i32 [[B:%.*]], 3
 ; CHECK-NEXT:    [[AND:%.*]] = and i32 [[A:%.*]], 7
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[AND]], 1
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[A]], [[B:%.*]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[A]], [[B]]
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP2]])
-; CHECK-NEXT:    ret i32 1
+; CHECK-NEXT:    ret i32 [[AND1]]
 ;
 entry:
   %and1 = and i32 %b, 3