[InstCombine] refactor folds for (icmp (bitcast X), Y); NFCI


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@353462 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 399aa2d..5a63d55 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2490,20 +2490,72 @@
     // the entire original Cmp can be simplified to a false.
     Value *Cond = Builder.getFalse();
     if (TrueWhenLessThan)
-      Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SLT, OrigLHS, OrigRHS));
+      Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SLT,
+                                                       OrigLHS, OrigRHS));
     if (TrueWhenEqual)
-      Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_EQ, OrigLHS, OrigRHS));
+      Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_EQ,
+                                                       OrigLHS, OrigRHS));
     if (TrueWhenGreaterThan)
-      Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SGT, OrigLHS, OrigRHS));
+      Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SGT,
+                                                       OrigLHS, OrigRHS));
 
     return replaceInstUsesWith(Cmp, Cond);
   }
   return nullptr;
 }
 
-Instruction *InstCombiner::foldICmpBitCastConstant(ICmpInst &Cmp,
-                                                   BitCastInst *Bitcast,
-                                                   const APInt &C) {
+static Instruction *foldICmpBitCast(ICmpInst &Cmp,
+                                    InstCombiner::BuilderTy &Builder) {
+  auto *Bitcast = dyn_cast<BitCastInst>(Cmp.getOperand(0));
+  if (!Bitcast)
+    return nullptr;
+
+  // Zero-equality and sign-bit checks are preserved through sitofp + bitcast.
+  // FIXME: This needs to check that the bitcast does not change the number of
+  // elements in a vector.
+  ICmpInst::Predicate Pred = Cmp.getPredicate();
+  Value *Op1 = Cmp.getOperand(1);
+  Value *BCSrcOp = Bitcast->getOperand(0);
+  Value *X;
+  if (match(BCSrcOp, m_SIToFP(m_Value(X)))) {
+    // icmp  eq (bitcast (sitofp X)), 0 --> icmp  eq X, 0
+    // icmp  ne (bitcast (sitofp X)), 0 --> icmp  ne X, 0
+    // icmp slt (bitcast (sitofp X)), 0 --> icmp slt X, 0
+    // icmp sgt (bitcast (sitofp X)), 0 --> icmp sgt X, 0
+    if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT ||
+         Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) &&
+        match(Op1, m_Zero()))
+      return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
+
+    // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1
+    if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One()))
+      return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1));
+
+    // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1
+    if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))
+      return new ICmpInst(Pred, X, ConstantInt::getAllOnesValue(X->getType()));
+  }
+
+  // Zero-equality checks are preserved through unsigned floating-point casts:
+  // icmp eq (bitcast (uitofp X)), 0 --> icmp eq X, 0
+  // icmp ne (bitcast (uitofp X)), 0 --> icmp ne X, 0
+  if (match(BCSrcOp, m_UIToFP(m_Value(X))))
+    if (Cmp.isEquality() && match(Op1, m_Zero()))
+      return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
+
+  // Test to see if the operands of the icmp are casted versions of other
+  // values. If the ptr->ptr cast can be stripped off both arguments, do so.
+  if (Bitcast->getType()->isPointerTy() &&
+      (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) {
+    // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast
+    // so eliminate it as well.
+    if (auto *BC2 = dyn_cast<BitCastInst>(Op1))
+      Op1 = BC2->getOperand(0);
+
+    Op1 = Builder.CreateBitCast(Op1, BCSrcOp->getType());
+    return new ICmpInst(Pred, BCSrcOp, Op1);
+  }
+
   // Folding: icmp <pred> iN X, C
   //  where X = bitcast <M x iK> (shufflevector <M x iK> %vec, undef, SC)) to iN
   //    and C is a splat of a K-bit pattern
@@ -2511,28 +2563,28 @@
   // Into:
   //   %E = extractelement <M x iK> %vec, i32 C'
   //   icmp <pred> iK %E, trunc(C)
-  if (!Bitcast->getType()->isIntegerTy() ||
+  const APInt *C;
+  if (!match(Cmp.getOperand(1), m_APInt(C)) ||
+      !Bitcast->getType()->isIntegerTy() ||
       !Bitcast->getSrcTy()->isIntOrIntVectorTy())
     return nullptr;
 
-  Value *BCIOp = Bitcast->getOperand(0);
-  Value *Vec = nullptr;     // 1st vector arg of the shufflevector
-  Constant *Mask = nullptr; // Mask arg of the shufflevector
-  if (match(BCIOp,
+  Value *Vec;
+  Constant *Mask;
+  if (match(BCSrcOp,
             m_ShuffleVector(m_Value(Vec), m_Undef(), m_Constant(Mask)))) {
     // Check whether every element of Mask is the same constant
     if (auto *Elem = dyn_cast_or_null<ConstantInt>(Mask->getSplatValue())) {
-      auto *VecTy = cast<VectorType>(BCIOp->getType());
+      auto *VecTy = cast<VectorType>(BCSrcOp->getType());
       auto *EltTy = cast<IntegerType>(VecTy->getElementType());
-      auto Pred = Cmp.getPredicate();
-      if (C.isSplat(EltTy->getBitWidth())) {
+      if (C->isSplat(EltTy->getBitWidth())) {
         // Fold the icmp based on the value of C
         // If C is M copies of an iK sized bit pattern,
         // then:
         //   =>  %E = extractelement <N x iK> %vec, i32 Elem
         //       icmp <pred> iK %SplatVal, <pattern>
         Value *Extract = Builder.CreateExtractElement(Vec, Elem);
-        Value *NewC = ConstantInt::get(EltTy, C.trunc(EltTy->getBitWidth()));
+        Value *NewC = ConstantInt::get(EltTy, C->trunc(EltTy->getBitWidth()));
         return new ICmpInst(Pred, Extract, NewC);
       }
     }
@@ -2614,11 +2666,6 @@
       return I;
   }
 
-  if (auto *BCI = dyn_cast<BitCastInst>(Cmp.getOperand(0))) {
-    if (Instruction *I = foldICmpBitCastConstant(Cmp, BCI, *C))
-      return I;
-  }
-
   if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)))
     if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, II, *C))
       return I;
@@ -4945,61 +4992,8 @@
         return New;
   }
 
-  // Zero-equality and sign-bit checks are preserved through sitofp + bitcast.
-  Value *X;
-  if (match(Op0, m_BitCast(m_SIToFP(m_Value(X))))) {
-    // icmp  eq (bitcast (sitofp X)), 0 --> icmp  eq X, 0
-    // icmp  ne (bitcast (sitofp X)), 0 --> icmp  ne X, 0
-    // icmp slt (bitcast (sitofp X)), 0 --> icmp slt X, 0
-    // icmp sgt (bitcast (sitofp X)), 0 --> icmp sgt X, 0
-    if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT ||
-         Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) &&
-        match(Op1, m_Zero()))
-      return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
-
-    // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1
-    if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One()))
-      return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1));
-
-    // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1
-    if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))
-      return new ICmpInst(Pred, X, ConstantInt::getAllOnesValue(X->getType()));
-  }
-
-  // Zero-equality checks are preserved through unsigned floating-point casts:
-  // icmp eq (bitcast (uitofp X)), 0 --> icmp eq X, 0
-  // icmp ne (bitcast (uitofp X)), 0 --> icmp ne X, 0
-  if (match(Op0, m_BitCast(m_UIToFP(m_Value(X)))))
-    if (I.isEquality() && match(Op1, m_Zero()))
-      return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
-
-  // Test to see if the operands of the icmp are casted versions of other
-  // values.  If the ptr->ptr cast can be stripped off both arguments, we do so
-  // now.
-  if (BitCastInst *CI = dyn_cast<BitCastInst>(Op0)) {
-    if (Op0->getType()->isPointerTy() &&
-        (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) {
-      // We keep moving the cast from the left operand over to the right
-      // operand, where it can often be eliminated completely.
-      Op0 = CI->getOperand(0);
-
-      // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast
-      // so eliminate it as well.
-      if (BitCastInst *CI2 = dyn_cast<BitCastInst>(Op1))
-        Op1 = CI2->getOperand(0);
-
-      // If Op1 is a constant, we can fold the cast into the constant.
-      if (Op0->getType() != Op1->getType()) {
-        if (Constant *Op1C = dyn_cast<Constant>(Op1)) {
-          Op1 = ConstantExpr::getBitCast(Op1C, Op0->getType());
-        } else {
-          // Otherwise, cast the RHS right before the icmp
-          Op1 = Builder.CreateBitCast(Op1, Op0->getType());
-        }
-      }
-      return new ICmpInst(I.getPredicate(), Op0, Op1);
-    }
-  }
+  if (Instruction *Res = foldICmpBitCast(I, Builder))
+    return Res;
 
   if (isa<CastInst>(Op0)) {
     // Handle the special case of: icmp (cast bool to X), <cst>
diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h
index 35876a6..11993a4 100644
--- a/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -865,8 +865,6 @@
 
   Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select,
                                       ConstantInt *C);
-  Instruction *foldICmpBitCastConstant(ICmpInst &Cmp, BitCastInst *Bitcast,
-                                       const APInt &C);
   Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc,
                                      const APInt &C);
   Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And,