[GlobalISel] Add matchers for constant splat.

This change exposes isBuildVectorConstantSplat() in the llvm namespace
(it previously lived in an anonymous namespace in Utils.cpp) and uses
it to implement m_SpecificICstSplat() and m_SpecificICstOrSplat(), the
constant-splat counterparts of m_SpecificICst().
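
A minimal usage sketch of the new matchers (Reg and MRI are
placeholders for an in-scope register and MachineRegisterInfo):

  // Matches only a G_BUILD_VECTOR(_TRUNC) whose elements are all 2;
  // undef elements are rejected because AllowUndef is false.
  bool IsSplatOfTwo = mi_match(Reg, MRI, m_SpecificICstSplat(2));
  // Additionally accepts a scalar G_CONSTANT 2.
  bool IsTwoOrSplat = mi_match(Reg, MRI, m_SpecificICstOrSplat(2));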

CombinerHelper::matchOrShiftToFunnelShift() can now handle vector
types, and CombinerHelper::matchMulOBy2()'s check for a constant splat
of 2 is simplified.
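
For reference, a sketch of calling the now-public helper directly,
with the signature added to Utils.h (Reg and MRI are again
placeholders):

  // True if Reg is defined by a G_BUILD_VECTOR(_TRUNC) whose elements
  // are all 2; passing AllowUndef = true would also accept undefs.
  bool IsSplat = isBuildVectorConstantSplat(Reg, MRI, /*SplatValue=*/2,
                                            /*AllowUndef=*/false);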

Differential Revision: https://reviews.llvm.org/D114625
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
index e813d03..84aad98 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -129,6 +129,43 @@
   return SpecificConstantMatch(RequestedValue);
 }
 
+/// Matcher for a specific constant splat.
+struct SpecificConstantSplatMatch {
+  int64_t RequestedVal;
+  SpecificConstantSplatMatch(int64_t RequestedVal)
+      : RequestedVal(RequestedVal) {}
+  bool match(const MachineRegisterInfo &MRI, Register Reg) {
+    return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
+                                      /* AllowUndef */ false);
+  }
+};
+
+/// Matches a constant splat of \p RequestedValue.
+inline SpecificConstantSplatMatch m_SpecificICstSplat(int64_t RequestedValue) {
+  return SpecificConstantSplatMatch(RequestedValue);
+}
+
+/// Matcher for a specific constant or constant splat.
+struct SpecificConstantOrSplatMatch {
+  int64_t RequestedVal;
+  SpecificConstantOrSplatMatch(int64_t RequestedVal)
+      : RequestedVal(RequestedVal) {}
+  bool match(const MachineRegisterInfo &MRI, Register Reg) {
+    int64_t MatchedVal;
+    if (mi_match(Reg, MRI, m_ICst(MatchedVal)) && MatchedVal == RequestedVal)
+      return true;
+    return isBuildVectorConstantSplat(Reg, MRI, RequestedVal,
+                                      /* AllowUndef */ false);
+  }
+};
+
+/// Matches a \p RequestedValue constant or a constant splat of \p
+/// RequestedValue.
+inline SpecificConstantOrSplatMatch
+m_SpecificICstOrSplat(int64_t RequestedValue) {
+  return SpecificConstantOrSplatMatch(RequestedValue);
+}
+
 ///{
 /// Convenience matchers for specific integer values.
 inline SpecificConstantMatch m_ZeroInt() { return SpecificConstantMatch(0); }
diff --git a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
index 86545b9..4126e2a 100644
--- a/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/Utils.h
@@ -378,6 +378,18 @@
                                            const MachineRegisterInfo &MRI,
                                            bool AllowUndef = true);
 
+/// Return true if the specified register is defined by G_BUILD_VECTOR or
+/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
+bool isBuildVectorConstantSplat(const Register Reg,
+                                const MachineRegisterInfo &MRI,
+                                int64_t SplatValue, bool AllowUndef);
+
+/// Return true if the specified instruction is a G_BUILD_VECTOR or
+/// G_BUILD_VECTOR_TRUNC where all of the elements are \p SplatValue or undef.
+bool isBuildVectorConstantSplat(const MachineInstr &MI,
+                                const MachineRegisterInfo &MRI,
+                                int64_t SplatValue, bool AllowUndef);
+
 /// Return true if the specified instruction is a G_BUILD_VECTOR or
 /// G_BUILD_VECTOR_TRUNC where all of the elements are 0 or undef.
 bool isBuildVectorAllZeros(const MachineInstr &MI,
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index b3dee82..755b3b8 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -3878,21 +3878,21 @@
   Register ShlSrc, ShlAmt, LShrSrc, LShrAmt;
   unsigned FshOpc = 0;
 
-  // TODO: Handle vector types.
   // Match (or (shl x, amt), (lshr y, sub(bw, amt))).
-  if (mi_match(Dst, MRI,
-               // m_GOr() handles the commuted version as well.
-               m_GOr(m_GShl(m_Reg(ShlSrc), m_Reg(ShlAmt)),
-                     m_GLShr(m_Reg(LShrSrc), m_GSub(m_SpecificICst(BitWidth),
-                                                    m_Reg(LShrAmt)))))) {
+  if (mi_match(
+          Dst, MRI,
+          // m_GOr() handles the commuted version as well.
+          m_GOr(m_GShl(m_Reg(ShlSrc), m_Reg(ShlAmt)),
+                m_GLShr(m_Reg(LShrSrc), m_GSub(m_SpecificICstOrSplat(BitWidth),
+                                               m_Reg(LShrAmt)))))) {
     FshOpc = TargetOpcode::G_FSHL;
 
     // Match (or (shl x, sub(bw, amt)), (lshr y, amt)).
-  } else if (mi_match(
-                 Dst, MRI,
-                 m_GOr(m_GLShr(m_Reg(LShrSrc), m_Reg(LShrAmt)),
-                       m_GShl(m_Reg(ShlSrc), m_GSub(m_SpecificICst(BitWidth),
-                                                    m_Reg(ShlAmt)))))) {
+  } else if (mi_match(Dst, MRI,
+                      m_GOr(m_GLShr(m_Reg(LShrSrc), m_Reg(LShrAmt)),
+                            m_GShl(m_Reg(ShlSrc),
+                                   m_GSub(m_SpecificICstOrSplat(BitWidth),
+                                          m_Reg(ShlAmt)))))) {
     FshOpc = TargetOpcode::G_FSHR;
 
   } else {
@@ -4543,20 +4543,9 @@
 bool CombinerHelper::matchMulOBy2(MachineInstr &MI, BuildFnTy &MatchInfo) {
   unsigned Opc = MI.getOpcode();
   assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO);
-  // Check for a constant 2 or a splat of 2 on the RHS.
-  auto RHS = MI.getOperand(3).getReg();
-  bool IsVector = MRI.getType(RHS).isVector();
-  if (!IsVector && !mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICst(2)))
+
+  if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(2)))
     return false;
-  if (IsVector) {
-    // FIXME: There's no mi_match pattern for this yet.
-    auto *RHSDef = getDefIgnoringCopies(RHS, MRI);
-    if (!RHSDef)
-      return false;
-    auto Splat = getBuildVectorConstantSplat(*RHSDef, MRI);
-    if (!Splat || *Splat != 2)
-      return false;
-  }
 
   MatchInfo = [=, &MI](MachineIRBuilder &B) {
     Observer.changingInstr(MI);
diff --git a/llvm/lib/CodeGen/GlobalISel/Utils.cpp b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
index 6d2925f..b0b8476 100644
--- a/llvm/lib/CodeGen/GlobalISel/Utils.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/Utils.cpp
@@ -1030,16 +1030,22 @@
   return SplatValAndReg;
 }
 
-bool isBuildVectorConstantSplat(const MachineInstr &MI,
-                                const MachineRegisterInfo &MRI,
-                                int64_t SplatValue, bool AllowUndef) {
-  if (auto SplatValAndReg =
-          getAnyConstantSplat(MI.getOperand(0).getReg(), MRI, AllowUndef))
+} // end anonymous namespace
+
+bool llvm::isBuildVectorConstantSplat(const Register Reg,
+                                      const MachineRegisterInfo &MRI,
+                                      int64_t SplatValue, bool AllowUndef) {
+  if (auto SplatValAndReg = getAnyConstantSplat(Reg, MRI, AllowUndef))
     return mi_match(SplatValAndReg->VReg, MRI, m_SpecificICst(SplatValue));
   return false;
 }
 
-} // end anonymous namespace
+bool llvm::isBuildVectorConstantSplat(const MachineInstr &MI,
+                                      const MachineRegisterInfo &MRI,
+                                      int64_t SplatValue, bool AllowUndef) {
+  return isBuildVectorConstantSplat(MI.getOperand(0).getReg(), MRI, SplatValue,
+                                    AllowUndef);
+}
 
 Optional<int64_t>
 llvm::getBuildVectorConstantSplat(const MachineInstr &MI,
diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir
index 04907c8d..0e2816c 100644
--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-fsh.mir
@@ -28,6 +28,33 @@
 ...
 
 ---
+name: fshl_v2i32
+tracksRegLiveness: true
+body: |
+  bb.0:
+    liveins: $vgpr0_vgpr1, $vgpr2_vgpr3, $vgpr4_vgpr5, $vgpr6_vgpr7
+
+    ; CHECK-LABEL: name: fshl_v2i32
+    ; CHECK: liveins: $vgpr0_vgpr1, $vgpr2_vgpr3, $vgpr4_vgpr5, $vgpr6_vgpr7
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
+    ; CHECK-NEXT: %b:_(<2 x s32>) = COPY $vgpr2_vgpr3
+    ; CHECK-NEXT: %amt:_(<2 x s32>) = COPY $vgpr4_vgpr5
+    ; CHECK-NEXT: %or:_(<2 x s32>) = G_FSHL %a, %b, %amt(<2 x s32>)
+    ; CHECK-NEXT: $vgpr6_vgpr7 = COPY %or(<2 x s32>)
+    %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
+    %b:_(<2 x s32>) = COPY $vgpr2_vgpr3
+    %amt:_(<2 x s32>) = COPY $vgpr4_vgpr5
+    %scalar_bw:_(s32) = G_CONSTANT i32 32
+    %bw:_(<2 x s32>) = G_BUILD_VECTOR %scalar_bw(s32), %scalar_bw(s32)
+    %shl:_(<2 x s32>) = G_SHL %a:_, %amt:_(<2 x s32>)
+    %sub:_(<2 x s32>) = G_SUB %bw:_, %amt:_
+    %lshr:_(<2 x s32>) = G_LSHR %b:_, %sub:_(<2 x s32>)
+    %or:_(<2 x s32>) = G_OR %shl:_, %lshr:_
+    $vgpr6_vgpr7 = COPY %or
+...
+
+---
 name: fshl_commute_i32
 tracksRegLiveness: true
 body: |
diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir
index 5db8df9..6079184 100644
--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/combine-rot.mir
@@ -26,6 +26,31 @@
 ...
 
 ---
+name: rotl_v2i32
+tracksRegLiveness: true
+body: |
+  bb.0:
+    liveins: $vgpr0_vgpr1, $vgpr2_vgpr3, $vgpr4_vgpr5
+
+    ; CHECK-LABEL: name: rotl_v2i32
+    ; CHECK: liveins: $vgpr0_vgpr1, $vgpr2_vgpr3, $vgpr4_vgpr5
+    ; CHECK-NEXT: {{  $}}
+    ; CHECK-NEXT: %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
+    ; CHECK-NEXT: %amt:_(<2 x s32>) = COPY $vgpr2_vgpr3
+    ; CHECK-NEXT: %or:_(<2 x s32>) = G_ROTL %a, %amt(<2 x s32>)
+    ; CHECK-NEXT: $vgpr4_vgpr5 = COPY %or(<2 x s32>)
+    %a:_(<2 x s32>) = COPY $vgpr0_vgpr1
+    %amt:_(<2 x s32>) = COPY $vgpr2_vgpr3
+    %scalar_bw:_(s32) = G_CONSTANT i32 32
+    %bw:_(<2 x s32>) = G_BUILD_VECTOR %scalar_bw(s32), %scalar_bw(s32)
+    %shl:_(<2 x s32>) = G_SHL %a:_, %amt:_(<2 x s32>)
+    %sub:_(<2 x s32>) = G_SUB %bw:_, %amt:_
+    %lshr:_(<2 x s32>) = G_LSHR %a:_, %sub:_(<2 x s32>)
+    %or:_(<2 x s32>) = G_OR %shl:_, %lshr:_
+    $vgpr4_vgpr5 = COPY %or
+...
+
+---
 name: rotl_commute_i32
 tracksRegLiveness: true
 body: |
@@ -55,6 +80,7 @@
 body: |
   bb.0:
     liveins: $vgpr0, $vgpr1, $vgpr2
+
     ; CHECK-LABEL: name: rotr_i32
     ; CHECK: liveins: $vgpr0, $vgpr1, $vgpr2
     ; CHECK-NEXT: {{  $}}
diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
index da56cb2..da0aee6 100644
--- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -533,6 +533,67 @@
   EXPECT_FALSE(mi_match(MIBAdd.getReg(1), *MRI, m_SpecificICst(42)));
 }
 
+TEST_F(AArch64GISelMITest, MatchSpecificConstantSplat) {
+  setUp();
+  if (!TM)
+    return;
+
+  LLT s64 = LLT::scalar(64);
+  LLT v4s64 = LLT::fixed_vector(4, s64);
+
+  MachineInstrBuilder FortyTwoSplat =
+      B.buildSplatVector(v4s64, B.buildConstant(s64, 42));
+  MachineInstrBuilder FortyTwo = B.buildConstant(s64, 42);
+
+  EXPECT_TRUE(mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstSplat(42)));
+  EXPECT_FALSE(
+      mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstSplat(43)));
+  EXPECT_FALSE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstSplat(42)));
+
+  MachineInstrBuilder NonConstantSplat =
+      B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
+
+  MachineInstrBuilder AddSplat =
+      B.buildAdd(v4s64, NonConstantSplat, FortyTwoSplat);
+  EXPECT_TRUE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(42)));
+  EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstSplat(43)));
+  EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstSplat(42)));
+
+  MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
+  EXPECT_FALSE(mi_match(Add.getReg(2), *MRI, m_SpecificICstSplat(42)));
+}
+
+TEST_F(AArch64GISelMITest, MatchSpecificConstantOrSplat) {
+  setUp();
+  if (!TM)
+    return;
+
+  LLT s64 = LLT::scalar(64);
+  LLT v4s64 = LLT::fixed_vector(4, s64);
+
+  MachineInstrBuilder FortyTwoSplat =
+      B.buildSplatVector(v4s64, B.buildConstant(s64, 42));
+  MachineInstrBuilder FortyTwo = B.buildConstant(s64, 42);
+
+  EXPECT_TRUE(
+      mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstOrSplat(42)));
+  EXPECT_FALSE(
+      mi_match(FortyTwoSplat.getReg(0), *MRI, m_SpecificICstOrSplat(43)));
+  EXPECT_TRUE(mi_match(FortyTwo.getReg(0), *MRI, m_SpecificICstOrSplat(42)));
+
+  MachineInstrBuilder NonConstantSplat =
+      B.buildBuildVector(v4s64, {Copies[0], Copies[0], Copies[0], Copies[0]});
+
+  MachineInstrBuilder AddSplat =
+      B.buildAdd(v4s64, NonConstantSplat, FortyTwoSplat);
+  EXPECT_TRUE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(42)));
+  EXPECT_FALSE(mi_match(AddSplat.getReg(2), *MRI, m_SpecificICstOrSplat(43)));
+  EXPECT_FALSE(mi_match(AddSplat.getReg(1), *MRI, m_SpecificICstOrSplat(42)));
+
+  MachineInstrBuilder Add = B.buildAdd(s64, Copies[0], FortyTwo);
+  EXPECT_TRUE(mi_match(Add.getReg(2), *MRI, m_SpecificICstOrSplat(42)));
+}
+
 TEST_F(AArch64GISelMITest, MatchZeroInt) {
   setUp();
   if (!TM)