[EarlyCSE] detect equivalence of selects with inverse conditions and commuted operands (PR41101)

This is 1 of the problems discussed in the post-commit thread for:
rL355741 / http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20190311/635516.html
and filed as:
https://bugs.llvm.org/show_bug.cgi?id=41101

Instcombine tries to canonicalize some of these cases (and there's room for improvement
there independently of this patch), but it can't always do that because of extra uses.
So we need to recognize these commuted operand patterns here in EarlyCSE. This is similar
to how we detect commuted compares and commuted min/max/abs.

Differential Revision: https://reviews.llvm.org/D60723

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@358523 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/EarlyCSE.cpp b/lib/Transforms/Scalar/EarlyCSE.cpp
index eb81ac6..892a123 100644
--- a/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -130,6 +130,21 @@
 
 } // end namespace llvm
 
+/// Match a 'select' including an optional 'not' of the condition.
+static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond,
+                                           Value *&T, Value *&F) {
+  if (match(V, m_Select(m_Value(Cond), m_Value(T), m_Value(F)))) {
+    // Look through a 'not' of the condition operand by swapping true/false.
+    Value *CondNot;
+    if (match(Cond, m_Not(m_Value(CondNot)))) {
+      Cond = CondNot;
+      std::swap(T, F);
+    }
+    return true;
+  }
+  return false;
+}
+
 unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) {
   Instruction *Inst = Val.Inst;
   // Hash in all of the operands as pointers.
@@ -171,6 +186,24 @@
     return hash_combine(Inst->getOpcode(), SPF, A, B);
   }
 
+  // Hash general selects to allow matching commuted true/false operands.
+  Value *Cond, *TVal, *FVal;
+  if (matchSelectWithOptionalNotCond(Inst, Cond, TVal, FVal)) {
+    // If we do not have a compare as the condition, just hash in the condition.
+    CmpInst::Predicate Pred;
+    Value *X, *Y;
+    if (!match(Cond, m_Cmp(Pred, m_Value(X), m_Value(Y))))
+      return hash_combine(Inst->getOpcode(), Cond, TVal, FVal);
+
+    // Similar to cmp normalization (above) - canonicalize the predicate value:
+    // select (icmp Pred, X, Y), T, F --> select (icmp InvPred, X, Y), F, T
+    if (CmpInst::getInversePredicate(Pred) < Pred) {
+      Pred = CmpInst::getInversePredicate(Pred);
+      std::swap(TVal, FVal);
+    }
+    return hash_combine(Inst->getOpcode(), Pred, X, Y, TVal, FVal);
+  }
+
   if (CastInst *CI = dyn_cast<CastInst>(Inst))
     return hash_combine(CI->getOpcode(), CI->getType(), CI->getOperand(0));
 
@@ -183,8 +216,7 @@
                         IVI->getOperand(1),
                         hash_combine_range(IVI->idx_begin(), IVI->idx_end()));
 
-  assert((isa<CallInst>(Inst) || isa<BinaryOperator>(Inst) ||
-          isa<GetElementPtrInst>(Inst) || isa<SelectInst>(Inst) ||
+  assert((isa<CallInst>(Inst) || isa<GetElementPtrInst>(Inst) ||
           isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) ||
           isa<ShuffleVectorInst>(Inst)) &&
          "Invalid/unknown instruction");
@@ -248,6 +280,31 @@
     }
   }
 
+  // Selects can be non-trivially equivalent via inverted conditions and swaps.
+  Value *CondL, *CondR, *TrueL, *TrueR, *FalseL, *FalseR;
+  if (matchSelectWithOptionalNotCond(LHSI, CondL, TrueL, FalseL) &&
+      matchSelectWithOptionalNotCond(RHSI, CondR, TrueR, FalseR)) {
+    // select Cond, T, F <--> select not(Cond), F, T
+    if (CondL == CondR && TrueL == TrueR && FalseL == FalseR)
+      return true;
+
+    // If the true/false operands are swapped and the conditions are compares
+    // with inverted predicates, the selects are equal:
+    // select (icmp Pred, X, Y), T, F <--> select (icmp InvPred, X, Y), F, T
+    //
+    // This also handles patterns with a double-negation because we looked
+    // through a 'not' in the matching function and swapped T/F:
+    // select (cmp Pred, X, Y), T, F <--> select (not (cmp InvPred, X, Y)), T, F
+    if (TrueL == FalseR && FalseL == TrueR) {
+      CmpInst::Predicate PredL, PredR;
+      Value *X, *Y;
+      if (match(CondL, m_Cmp(PredL, m_Value(X), m_Value(Y))) &&
+          match(CondR, m_Cmp(PredR, m_Specific(X), m_Specific(Y))) &&
+          CmpInst::getInversePredicate(PredL) == PredR)
+        return true;
+    }
+  }
+
   return false;
 }
 
diff --git a/test/Transforms/EarlyCSE/commute.ll b/test/Transforms/EarlyCSE/commute.ll
index 5ffcbf4..488acf6 100644
--- a/test/Transforms/EarlyCSE/commute.ll
+++ b/test/Transforms/EarlyCSE/commute.ll
@@ -290,15 +290,13 @@
 }
 
 ; https://bugs.llvm.org/show_bug.cgi?id=41101
-; TODO: Detect equivalence of selects with commuted operands: 'not' cond.
+; Detect equivalence of selects with commuted operands: 'not' cond.
 
 define i32 @select_not_cond(i1 %cond, i32 %t, i32 %f) {
 ; CHECK-LABEL: @select_not_cond(
 ; CHECK-NEXT:    [[NOT:%.*]] = xor i1 [[COND:%.*]], true
 ; CHECK-NEXT:    [[M1:%.*]] = select i1 [[COND]], i32 [[T:%.*]], i32 [[F:%.*]]
-; CHECK-NEXT:    [[M2:%.*]] = select i1 [[NOT]], i32 [[F]], i32 [[T]]
-; CHECK-NEXT:    [[R:%.*]] = xor i32 [[M2]], [[M1]]
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    ret i32 0
 ;
   %not = xor i1 %cond, -1
   %m1 = select i1 %cond, i32 %t, i32 %f
@@ -307,15 +305,13 @@
   ret i32 %r
 }
 
-; TODO: Detect equivalence of selects with commuted operands: 'not' cond with vector select.
+; Detect equivalence of selects with commuted operands: 'not' cond with vector select.
 
 define <2 x double> @select_not_cond_commute_vec(<2 x i1> %cond, <2 x double> %t, <2 x double> %f) {
 ; CHECK-LABEL: @select_not_cond_commute_vec(
 ; CHECK-NEXT:    [[NOT:%.*]] = xor <2 x i1> [[COND:%.*]], <i1 true, i1 true>
 ; CHECK-NEXT:    [[M1:%.*]] = select <2 x i1> [[COND]], <2 x double> [[T:%.*]], <2 x double> [[F:%.*]]
-; CHECK-NEXT:    [[M2:%.*]] = select <2 x i1> [[NOT]], <2 x double> [[F]], <2 x double> [[T]]
-; CHECK-NEXT:    [[R:%.*]] = fdiv nnan <2 x double> [[M1]], [[M2]]
-; CHECK-NEXT:    ret <2 x double> [[R]]
+; CHECK-NEXT:    ret <2 x double> <double 1.000000e+00, double 1.000000e+00>
 ;
   %not = xor <2 x i1> %cond, <i1 -1, i1 -1>
   %m1 = select <2 x i1> %cond, <2 x double> %t, <2 x double> %f
@@ -357,16 +353,14 @@
   ret i32 %r
 }
 
-; TODO: Detect equivalence of selects with commuted operands: inverted pred with fcmps.
+; Detect equivalence of selects with commuted operands: inverted pred with fcmps.
 
 define i32 @select_invert_pred_cond(float %x, i32 %t, i32 %f) {
 ; CHECK-LABEL: @select_invert_pred_cond(
 ; CHECK-NEXT:    [[COND:%.*]] = fcmp ueq float [[X:%.*]], 4.200000e+01
 ; CHECK-NEXT:    [[INVCOND:%.*]] = fcmp one float [[X]], 4.200000e+01
 ; CHECK-NEXT:    [[M1:%.*]] = select i1 [[COND]], i32 [[T:%.*]], i32 [[F:%.*]]
-; CHECK-NEXT:    [[M2:%.*]] = select i1 [[INVCOND]], i32 [[F]], i32 [[T]]
-; CHECK-NEXT:    [[R:%.*]] = xor i32 [[M2]], [[M1]]
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    ret i32 0
 ;
   %cond = fcmp ueq float %x, 42.0
   %invcond = fcmp one float %x, 42.0
@@ -376,16 +370,14 @@
   ret i32 %r
 }
 
-; TODO: Detect equivalence of selects with commuted operands: inverted pred with icmps and vectors.
+; Detect equivalence of selects with commuted operands: inverted pred with icmps and vectors.
 
 define <2 x i32> @select_invert_pred_cond_commute_vec(<2 x i8> %x, <2 x i32> %t, <2 x i32> %f) {
 ; CHECK-LABEL: @select_invert_pred_cond_commute_vec(
 ; CHECK-NEXT:    [[COND:%.*]] = icmp sgt <2 x i8> [[X:%.*]], <i8 42, i8 -1>
 ; CHECK-NEXT:    [[INVCOND:%.*]] = icmp sle <2 x i8> [[X]], <i8 42, i8 -1>
 ; CHECK-NEXT:    [[M1:%.*]] = select <2 x i1> [[COND]], <2 x i32> [[T:%.*]], <2 x i32> [[F:%.*]]
-; CHECK-NEXT:    [[M2:%.*]] = select <2 x i1> [[INVCOND]], <2 x i32> [[F]], <2 x i32> [[T]]
-; CHECK-NEXT:    [[R:%.*]] = xor <2 x i32> [[M1]], [[M2]]
-; CHECK-NEXT:    ret <2 x i32> [[R]]
+; CHECK-NEXT:    ret <2 x i32> zeroinitializer
 ;
   %cond = icmp sgt <2 x i8> %x, <i8 42, i8 -1>
   %invcond = icmp sle <2 x i8> %x, <i8 42, i8 -1>
@@ -452,7 +444,7 @@
   ret i32 %r
 }
 
-; TODO: If we have both an inverted predicate and a 'not' op, recognize the double-negation.
+; If we have both an inverted predicate and a 'not' op, recognize the double-negation.
 
 define i32 @select_not_invert_pred_cond(i8 %x, i32 %t, i32 %f) {
 ; CHECK-LABEL: @select_not_invert_pred_cond(
@@ -460,9 +452,7 @@
 ; CHECK-NEXT:    [[INVCOND:%.*]] = icmp ule i8 [[X]], 42
 ; CHECK-NEXT:    [[NOT:%.*]] = xor i1 [[INVCOND]], true
 ; CHECK-NEXT:    [[M1:%.*]] = select i1 [[COND]], i32 [[T:%.*]], i32 [[F:%.*]]
-; CHECK-NEXT:    [[M2:%.*]] = select i1 [[NOT]], i32 [[T]], i32 [[F]]
-; CHECK-NEXT:    [[R:%.*]] = sub i32 [[M1]], [[M2]]
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    ret i32 0
 ;
   %cond = icmp ugt i8 %x, 42
   %invcond = icmp ule i8 %x, 42
@@ -473,7 +463,7 @@
   ret i32 %r
 }
 
-; TODO: If we have both an inverted predicate and a 'not' op, recognize the double-negation.
+; If we have both an inverted predicate and a 'not' op, recognize the double-negation.
 
 define i32 @select_not_invert_pred_cond_commute(i8 %x, i8 %y, i32 %t, i32 %f) {
 ; CHECK-LABEL: @select_not_invert_pred_cond_commute(
@@ -481,9 +471,7 @@
 ; CHECK-NEXT:    [[NOT:%.*]] = xor i1 [[INVCOND]], true
 ; CHECK-NEXT:    [[M2:%.*]] = select i1 [[NOT]], i32 [[T:%.*]], i32 [[F:%.*]]
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ugt i8 [[X]], [[Y]]
-; CHECK-NEXT:    [[M1:%.*]] = select i1 [[COND]], i32 [[T]], i32 [[F]]
-; CHECK-NEXT:    [[R:%.*]] = sub i32 [[M2]], [[M1]]
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    ret i32 0
 ;
   %invcond = icmp ule i8 %x, %y
   %not = xor i1 %invcond, -1