[DAG] Constant fold ISD::FSHL/FSHR nodes (#154480)

Fixes #153612.
This patch handles trinary scalar integers for FSHL/R in
`FoldConstantArithmetic`.
Pending until #153790 is merged.
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 15d7e76..27b5a0d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -11281,6 +11281,11 @@
   unsigned BitWidth = VT.getScalarSizeInBits();
   SDLoc DL(N);
 
+  // fold (fshl/fshr C0, C1, C2) -> C3
+  if (SDValue C =
+          DAG.FoldConstantArithmetic(N->getOpcode(), DL, VT, {N0, N1, N2}))
+    return C;
+
   // fold (fshl N0, N1, 0) -> N0
   // fold (fshr N0, N1, 0) -> N1
   if (isPowerOf2_32(BitWidth))
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 531297b..3672a91 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7175,6 +7175,45 @@
     }
   }
 
+  // Handle fshl/fshr special cases.
+  if (Opcode == ISD::FSHL || Opcode == ISD::FSHR) {
+    auto *C1 = dyn_cast<ConstantSDNode>(Ops[0]);
+    auto *C2 = dyn_cast<ConstantSDNode>(Ops[1]);
+    auto *C3 = dyn_cast<ConstantSDNode>(Ops[2]);
+
+    if (C1 && C2 && C3) {
+      if (C1->isOpaque() || C2->isOpaque() || C3->isOpaque())
+        return SDValue();
+      const APInt &V1 = C1->getAPIntValue(), &V2 = C2->getAPIntValue(),
+                  &V3 = C3->getAPIntValue();
+
+      APInt FoldedVal = Opcode == ISD::FSHL ? APIntOps::fshl(V1, V2, V3)
+                                            : APIntOps::fshr(V1, V2, V3);
+      return getConstant(FoldedVal, DL, VT);
+    }
+  }
+
+  // Handle fma/fmad special cases.
+  if (Opcode == ISD::FMA || Opcode == ISD::FMAD) {
+    assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
+    assert(Ops[0].getValueType() == VT && Ops[1].getValueType() == VT &&
+           Ops[2].getValueType() == VT && "FMA types must match!");
+    ConstantFPSDNode *C1 = dyn_cast<ConstantFPSDNode>(Ops[0]);
+    ConstantFPSDNode *C2 = dyn_cast<ConstantFPSDNode>(Ops[1]);
+    ConstantFPSDNode *C3 = dyn_cast<ConstantFPSDNode>(Ops[2]);
+    if (C1 && C2 && C3) {
+      APFloat V1 = C1->getValueAPF();
+      const APFloat &V2 = C2->getValueAPF();
+      const APFloat &V3 = C3->getValueAPF();
+      if (Opcode == ISD::FMAD) {
+        V1.multiply(V2, APFloat::rmNearestTiesToEven);
+        V1.add(V3, APFloat::rmNearestTiesToEven);
+      } else
+        V1.fusedMultiplyAdd(V2, V3, APFloat::rmNearestTiesToEven);
+      return getConstantFP(V1, DL, VT);
+    }
+  }
+
   // This is for vector folding only from here on.
   if (!VT.isVector())
     return SDValue();
@@ -8137,27 +8176,6 @@
          "Operand is DELETED_NODE!");
   // Perform various simplifications.
   switch (Opcode) {
-  case ISD::FMA:
-  case ISD::FMAD: {
-    assert(VT.isFloatingPoint() && "This operator only applies to FP types!");
-    assert(N1.getValueType() == VT && N2.getValueType() == VT &&
-           N3.getValueType() == VT && "FMA types must match!");
-    ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
-    ConstantFPSDNode *N2CFP = dyn_cast<ConstantFPSDNode>(N2);
-    ConstantFPSDNode *N3CFP = dyn_cast<ConstantFPSDNode>(N3);
-    if (N1CFP && N2CFP && N3CFP) {
-      APFloat  V1 = N1CFP->getValueAPF();
-      const APFloat &V2 = N2CFP->getValueAPF();
-      const APFloat &V3 = N3CFP->getValueAPF();
-      if (Opcode == ISD::FMAD) {
-        V1.multiply(V2, APFloat::rmNearestTiesToEven);
-        V1.add(V3, APFloat::rmNearestTiesToEven);
-      } else
-        V1.fusedMultiplyAdd(V2, V3, APFloat::rmNearestTiesToEven);
-      return getConstantFP(V1, DL, VT);
-    }
-    break;
-  }
   case ISD::BUILD_VECTOR: {
     // Attempt to simplify BUILD_VECTOR.
     SDValue Ops[] = {N1, N2, N3};
@@ -8183,12 +8201,6 @@
     // Use FoldSetCC to simplify SETCC's.
     if (SDValue V = FoldSetCC(VT, N1, N2, cast<CondCodeSDNode>(N3)->get(), DL))
       return V;
-    // Vector constant folding.
-    SDValue Ops[] = {N1, N2, N3};
-    if (SDValue V = FoldConstantArithmetic(Opcode, DL, VT, Ops)) {
-      NewSDValueDbgMsg(V, "New node vector constant folding: ", this);
-      return V;
-    }
     break;
   }
   case ISD::SELECT:
@@ -8324,6 +8336,19 @@
   }
   }
 
+  // Perform trivial constant folding for arithmetic operators.
+  switch (Opcode) {
+  case ISD::FMA:
+  case ISD::FMAD:
+  case ISD::SETCC:
+  case ISD::FSHL:
+  case ISD::FSHR:
+    if (SDValue SV =
+            FoldConstantArithmetic(Opcode, DL, VT, {N1, N2, N3}, Flags))
+      return SV;
+    break;
+  }
+
   // Memoize node if it doesn't produce a glue result.
   SDNode *N;
   SDVTList VTs = getVTList(VT);
diff --git a/llvm/test/CodeGen/X86/fshl-fshr-constant.ll b/llvm/test/CodeGen/X86/fshl-fshr-constant.ll
new file mode 100644
index 0000000..fdc34f5
--- /dev/null
+++ b/llvm/test/CodeGen/X86/fshl-fshr-constant.ll
@@ -0,0 +1,149 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx2 | FileCheck %s --check-prefixes=CHECK,CHECK-EXPAND
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512vl,+avx512vbmi2 | FileCheck %s --check-prefixes=CHECK,CHECK-UNEXPAND
+
+define <4 x i32> @test_fshl_constants() {
+; CHECK-EXPAND-LABEL: test_fshl_constants:
+; CHECK-EXPAND:       # %bb.0:
+; CHECK-EXPAND-NEXT:    vmovaps {{.*#+}} xmm0 = [0,512,2048,6144]
+; CHECK-EXPAND-NEXT:    retq
+;
+; CHECK-UNEXPAND-LABEL: test_fshl_constants:
+; CHECK-UNEXPAND:       # %bb.0:
+; CHECK-UNEXPAND-NEXT:    vpmovsxwd {{.*#+}} xmm0 = [0,512,2048,6144]
+; CHECK-UNEXPAND-NEXT:    retq
+  %res = call <4 x i32> @llvm.fshl.v4i32(<4 x i32> <i32 0, i32 1, i32 2, i32 3>, <4 x i32> <i32 4, i32 5, i32 6, i32 7>, <4 x i32> <i32 8, i32 9, i32 10, i32 11>)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @test_fshl_splat_constants() {
+; CHECK-LABEL: test_fshl_splat_constants:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vbroadcastss {{.*#+}} xmm0 = [256,256,256,256]
+; CHECK-NEXT:    retq
+  %res = call <4 x i32> @llvm.fshl.v4i32(<4 x i32> <i32 1, i32 1, i32 1, i32 1>, <4 x i32> <i32 4, i32 4, i32 4, i32 4>, <4 x i32> <i32 8, i32 8, i32 8, i32 8>)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @test_fshl_two_constants(<4 x i32> %a) {
+; CHECK-EXPAND-LABEL: test_fshl_two_constants:
+; CHECK-EXPAND:       # %bb.0:
+; CHECK-EXPAND-NEXT:    vpsllvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; CHECK-EXPAND-NEXT:    retq
+;
+; CHECK-UNEXPAND-LABEL: test_fshl_two_constants:
+; CHECK-UNEXPAND:       # %bb.0:
+; CHECK-UNEXPAND-NEXT:    vpmovsxbd {{.*#+}} xmm1 = [4,5,6,7]
+; CHECK-UNEXPAND-NEXT:    vpshldvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; CHECK-UNEXPAND-NEXT:    retq
+  %res = call <4 x i32> @llvm.fshl.v4i32(<4 x i32> %a, <4 x i32> <i32 4, i32 5, i32 6, i32 7>, <4 x i32> <i32 8, i32 9, i32 10, i32 11>)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @test_fshl_one_constant(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-EXPAND-LABEL: test_fshl_one_constant:
+; CHECK-EXPAND:       # %bb.0:
+; CHECK-EXPAND-NEXT:    vpsrlvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; CHECK-EXPAND-NEXT:    vpsllvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; CHECK-EXPAND-NEXT:    vpor %xmm1, %xmm0, %xmm0
+; CHECK-EXPAND-NEXT:    retq
+;
+; CHECK-UNEXPAND-LABEL: test_fshl_one_constant:
+; CHECK-UNEXPAND:       # %bb.0:
+; CHECK-UNEXPAND-NEXT:    vpshldvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm0
+; CHECK-UNEXPAND-NEXT:    retq
+  %res = call <4 x i32> @llvm.fshl.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i32> <i32 8, i32 9, i32 10, i32 11>)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @test_fshl_none_constant(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c) {
+; CHECK-EXPAND-LABEL: test_fshl_none_constant:
+; CHECK-EXPAND:       # %bb.0:
+; CHECK-EXPAND-NEXT:    vpbroadcastd {{.*#+}} xmm3 = [31,31,31,31]
+; CHECK-EXPAND-NEXT:    vpandn %xmm3, %xmm2, %xmm4
+; CHECK-EXPAND-NEXT:    vpsrld $1, %xmm1, %xmm1
+; CHECK-EXPAND-NEXT:    vpsrlvd %xmm4, %xmm1, %xmm1
+; CHECK-EXPAND-NEXT:    vpand %xmm3, %xmm2, %xmm2
+; CHECK-EXPAND-NEXT:    vpsllvd %xmm2, %xmm0, %xmm0
+; CHECK-EXPAND-NEXT:    vpor %xmm1, %xmm0, %xmm0
+; CHECK-EXPAND-NEXT:    retq
+;
+; CHECK-UNEXPAND-LABEL: test_fshl_none_constant:
+; CHECK-UNEXPAND:       # %bb.0:
+; CHECK-UNEXPAND-NEXT:    vpshldvd %xmm2, %xmm1, %xmm0
+; CHECK-UNEXPAND-NEXT:    retq
+  %res = call <4 x i32> @llvm.fshl.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @test_fshr_constants() {
+; CHECK-LABEL: test_fshr_constants:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmovaps {{.*#+}} xmm0 = [0,8388608,8388608,6291456]
+; CHECK-NEXT:    retq
+  %res = call <4 x i32> @llvm.fshr.v4i32(<4 x i32> <i32 0, i32 1, i32 2, i32 3>, <4 x i32> <i32 4, i32 5, i32 6, i32 7>, <4 x i32> <i32 8, i32 9, i32 10, i32 11>)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @test_fshr_two_constants(<4 x i32> %a) {
+; CHECK-EXPAND-LABEL: test_fshr_two_constants:
+; CHECK-EXPAND:       # %bb.0:
+; CHECK-EXPAND-NEXT:    vpsllvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; CHECK-EXPAND-NEXT:    retq
+;
+; CHECK-UNEXPAND-LABEL: test_fshr_two_constants:
+; CHECK-UNEXPAND:       # %bb.0:
+; CHECK-UNEXPAND-NEXT:    vpmovsxbd {{.*#+}} xmm1 = [4,5,6,7]
+; CHECK-UNEXPAND-NEXT:    vpshrdvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1
+; CHECK-UNEXPAND-NEXT:    vmovdqa %xmm1, %xmm0
+; CHECK-UNEXPAND-NEXT:    retq
+  %res = call <4 x i32> @llvm.fshr.v4i32(<4 x i32> %a, <4 x i32> <i32 4, i32 5, i32 6, i32 7>, <4 x i32> <i32 8, i32 9, i32 10, i32 11>)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @test_fshr_one_constant(<4 x i32> %a, <4 x i32> %b) {
+; CHECK-EXPAND-LABEL: test_fshr_one_constant:
+; CHECK-EXPAND:       # %bb.0:
+; CHECK-EXPAND-NEXT:    vpsrlvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
+; CHECK-EXPAND-NEXT:    vpsllvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; CHECK-EXPAND-NEXT:    vpor %xmm1, %xmm0, %xmm0
+; CHECK-EXPAND-NEXT:    retq
+;
+; CHECK-UNEXPAND-LABEL: test_fshr_one_constant:
+; CHECK-UNEXPAND:       # %bb.0:
+; CHECK-UNEXPAND-NEXT:    vpshrdvd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm1
+; CHECK-UNEXPAND-NEXT:    vmovdqa %xmm1, %xmm0
+; CHECK-UNEXPAND-NEXT:    retq
+  %res = call <4 x i32> @llvm.fshr.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i32> <i32 8, i32 9, i32 10, i32 11>)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @test_fshr_none_constant(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c) {
+; CHECK-EXPAND-LABEL: test_fshr_none_constant:
+; CHECK-EXPAND:       # %bb.0:
+; CHECK-EXPAND-NEXT:    vpbroadcastd {{.*#+}} xmm3 = [31,31,31,31]
+; CHECK-EXPAND-NEXT:    vpand %xmm3, %xmm2, %xmm4
+; CHECK-EXPAND-NEXT:    vpsrlvd %xmm4, %xmm1, %xmm1
+; CHECK-EXPAND-NEXT:    vpandn %xmm3, %xmm2, %xmm2
+; CHECK-EXPAND-NEXT:    vpaddd %xmm0, %xmm0, %xmm0
+; CHECK-EXPAND-NEXT:    vpsllvd %xmm2, %xmm0, %xmm0
+; CHECK-EXPAND-NEXT:    vpor %xmm1, %xmm0, %xmm0
+; CHECK-EXPAND-NEXT:    retq
+;
+; CHECK-UNEXPAND-LABEL: test_fshr_none_constant:
+; CHECK-UNEXPAND:       # %bb.0:
+; CHECK-UNEXPAND-NEXT:    vpshrdvd %xmm2, %xmm0, %xmm1
+; CHECK-UNEXPAND-NEXT:    vmovdqa %xmm1, %xmm0
+; CHECK-UNEXPAND-NEXT:    retq
+  %res = call <4 x i32> @llvm.fshr.v4i32(<4 x i32> %a, <4 x i32> %b, <4 x i32> %c)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @test_fshr_splat_constants() {
+; CHECK-LABEL: test_fshr_splat_constants:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vbroadcastss {{.*#+}} xmm0 = [16777216,16777216,16777216,16777216]
+; CHECK-NEXT:    retq
+  %res = call <4 x i32> @llvm.fshr.v4i32(<4 x i32> <i32 1, i32 1, i32 1, i32 1>, <4 x i32> <i32 4, i32 4, i32 4, i32 4>, <4 x i32> <i32 8, i32 8, i32 8, i32 8>)
+  ret <4 x i32> %res
+}