Update m_Undef to match vectors/aggrs with undefs and poisons mixed

This fixes https://reviews.llvm.org/D93990#2666922
by teaching `m_Undef` to match vectors/aggrs with poison elements.

As suggested, fixes in InstCombine files to use the `m_Undef` matcher instead
of `isa<UndefValue>` will be followed.

Reviewed By: lebedev.ri

Differential Revision: https://reviews.llvm.org/D100122

GitOrigin-RevId: 2813acb7d1f70ec1d4a2a1360028950a0bc2ac8f
diff --git a/include/llvm/Analysis/InstructionSimplify.h b/include/llvm/Analysis/InstructionSimplify.h
index dda90e8..e1e7da1 100644
--- a/include/llvm/Analysis/InstructionSimplify.h
+++ b/include/llvm/Analysis/InstructionSimplify.h
@@ -37,6 +37,7 @@
 
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
 
 namespace llvm {
 
@@ -133,7 +134,9 @@
   bool isUndefValue(Value *V) const {
     if (!CanUseUndef)
       return false;
-    return isa<UndefValue>(V);
+
+    using namespace PatternMatch;
+    return match(V, m_Undef());
   }
 };
 
diff --git a/include/llvm/IR/PatternMatch.h b/include/llvm/IR/PatternMatch.h
index 440e449..cdc8d27 100644
--- a/include/llvm/IR/PatternMatch.h
+++ b/include/llvm/IR/PatternMatch.h
@@ -88,8 +88,52 @@
 /// Matches any compare instruction and ignore it.
 inline class_match<CmpInst> m_Cmp() { return class_match<CmpInst>(); }
 
-/// Match an arbitrary undef constant.
-inline class_match<UndefValue> m_Undef() { return class_match<UndefValue>(); }
+struct undef_match {
+  static bool check(const Value *V) {
+    if (isa<UndefValue>(V))
+      return true;
+
+    const auto *CA = dyn_cast<ConstantAggregate>(V);
+    if (!CA)
+      return false;
+
+    SmallPtrSet<const ConstantAggregate *, 8> Seen;
+    SmallVector<const ConstantAggregate *, 8> Worklist;
+
+    // Either UndefValue, PoisonValue, or an aggregate that only contains
+    // these is accepted by matcher.
+    // CheckValue returns false if CA cannot satisfy this constraint.
+    auto CheckValue = [&](const ConstantAggregate *CA) {
+      for (const Value *Op : CA->operand_values()) {
+        if (isa<UndefValue>(Op))
+          continue;
+
+        const auto *CA = dyn_cast<ConstantAggregate>(Op);
+        if (!CA)
+          return false;
+        if (Seen.insert(CA).second)
+          Worklist.emplace_back(CA);
+      }
+
+      return true;
+    };
+
+    if (!CheckValue(CA))
+      return false;
+
+    while (!Worklist.empty()) {
+      if (!CheckValue(Worklist.pop_back_val()))
+        return false;
+    }
+    return true;
+  }
+  template <typename ITy> bool match(ITy *V) { return check(V); }
+};
+
+/// Match an arbitrary undef constant. This matches poison as well.
+/// If this is an aggregate and contains a non-aggregate element that is
+/// neither undef nor poison, the aggregate is not matched.
+inline auto m_Undef() { return undef_match(); }
 
 /// Match an arbitrary poison constant.
 inline class_match<PoisonValue> m_Poison() { return class_match<PoisonValue>(); }
diff --git a/test/Transforms/InstSimplify/icmp-constant.ll b/test/Transforms/InstSimplify/icmp-constant.ll
index dbbce4b..77d4a38 100644
--- a/test/Transforms/InstSimplify/icmp-constant.ll
+++ b/test/Transforms/InstSimplify/icmp-constant.ll
@@ -1069,8 +1069,7 @@
 
 define <2 x i1> @heterogeneous_constvector(<2 x i8> %x) {
 ; CHECK-LABEL: @heterogeneous_constvector(
-; CHECK-NEXT:    [[C:%.*]] = icmp ult <2 x i8> [[X:%.*]], <i8 undef, i8 poison>
-; CHECK-NEXT:    ret <2 x i1> [[C]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %c = icmp ult <2 x i8> %x, <i8 undef, i8 poison>
   ret <2 x i1> %c
diff --git a/unittests/IR/PatternMatch.cpp b/unittests/IR/PatternMatch.cpp
index 57840b3..fd4e446 100644
--- a/unittests/IR/PatternMatch.cpp
+++ b/unittests/IR/PatternMatch.cpp
@@ -1022,6 +1022,37 @@
   EXPECT_TRUE(A == Val);
 }
 
+TEST_F(PatternMatchTest, UndefPoisonMix) {
+  Type *ScalarTy = IRB.getInt8Ty();
+  ArrayType *ArrTy = ArrayType::get(ScalarTy, 2);
+  StructType *StTy = StructType::get(ScalarTy, ScalarTy);
+  StructType *StTy2 = StructType::get(ScalarTy, StTy);
+  StructType *StTy3 = StructType::get(StTy, ScalarTy);
+  Constant *Zero = ConstantInt::getNullValue(ScalarTy);
+  UndefValue *U = UndefValue::get(ScalarTy);
+  UndefValue *P = PoisonValue::get(ScalarTy);
+
+  EXPECT_TRUE(match(ConstantVector::get({U, P}), m_Undef()));
+  EXPECT_TRUE(match(ConstantVector::get({P, U}), m_Undef()));
+
+  EXPECT_TRUE(match(ConstantArray::get(ArrTy, {U, P}), m_Undef()));
+  EXPECT_TRUE(match(ConstantArray::get(ArrTy, {P, U}), m_Undef()));
+
+  auto *UP = ConstantStruct::get(StTy, {U, P});
+  EXPECT_TRUE(match(ConstantStruct::get(StTy2, {U, UP}), m_Undef()));
+  EXPECT_TRUE(match(ConstantStruct::get(StTy2, {P, UP}), m_Undef()));
+  EXPECT_TRUE(match(ConstantStruct::get(StTy3, {UP, U}), m_Undef()));
+  EXPECT_TRUE(match(ConstantStruct::get(StTy3, {UP, P}), m_Undef()));
+
+  EXPECT_FALSE(match(ConstantStruct::get(StTy, {U, Zero}), m_Undef()));
+  EXPECT_FALSE(match(ConstantStruct::get(StTy, {Zero, U}), m_Undef()));
+  EXPECT_FALSE(match(ConstantStruct::get(StTy, {P, Zero}), m_Undef()));
+  EXPECT_FALSE(match(ConstantStruct::get(StTy, {Zero, P}), m_Undef()));
+
+  EXPECT_FALSE(match(ConstantStruct::get(StTy2, {Zero, UP}), m_Undef()));
+  EXPECT_FALSE(match(ConstantStruct::get(StTy3, {UP, Zero}), m_Undef()));
+}
+
 TEST_F(PatternMatchTest, VectorUndefInt) {
   Type *ScalarTy = IRB.getInt8Ty();
   Type *VectorTy = FixedVectorType::get(ScalarTy, 4);