[RISCV] Teach targetShrinkDemandedConstant to preserve (and X, 0xffffffff).

We look for this pattern frequently in isel patterns, so it's a
good idea to try to preserve it.

This also lets us remove our special isel handling for srliw
and use a direct pattern match of (srl (and X, 0xffffffff), C),
since no bits will be removed from the and mask.
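
As a rough illustration of the new hook's logic (a minimal standalone
sketch using plain uint64_t instead of APInt, with example values that
are only representative of the tests below, not copied from the
compiler's internal state): a replacement mask is accepted when it keeps
every demanded bit of the original mask and only sets bits that are
either already in the mask or not demanded, i.e. ShrunkMask <= NewMask
<= ExpandedMask as bit sets.

  #include <cassert>
  #include <cstdint>

  // NewMask is usable when ShrunkMask (Mask & Demanded) is a subset of
  // it and it is a subset of ExpandedMask (Mask | ~Demanded).
  static bool isLegalMask(uint64_t ShrunkMask, uint64_t ExpandedMask,
                          uint64_t NewMask) {
    return (ShrunkMask & ~NewMask) == 0 && (NewMask & ~ExpandedMask) == 0;
  }

  int main() {
    // Roughly the srli_demandedbits case: type legalization inserts
    // (and X, 0xffffffff); a later srl/or means the low bits of X are
    // not demanded, yet 0xffffffff still lies between the shrunk and
    // expanded masks, so the zext_inreg form needed for SRLIW survives.
    uint64_t Mask = 0xffffffffULL;
    uint64_t Demanded = 0xfffffff0ULL;
    assert(isLegalMask(Mask & Demanded, Mask | ~Demanded, 0xffffffffULL));

    // Roughly the zextw_demandedbits_i64 case: (and X, 0xfffffffe) where
    // bit 0 is overwritten by a later (or ..., 1), so widening the mask
    // to 0xffffffff is legal and zext.w can be selected.
    Mask = 0xfffffffeULL;
    Demanded = ~1ULL;
    assert(isLegalMask(Mask & Demanded, Mask | ~Demanded, 0xffffffffULL));
    return 0;
  }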

Differential Revision: https://reviews.llvm.org/D99042

GitOrigin-RevId: c40cea6f083a8a67ea950e058e16d37bb04e8c4b
diff --git a/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index c1ae6de..1e7516e 100644
--- a/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -1143,27 +1143,6 @@
   return false;
 }
 
-// Match (srl (and val, mask), imm) where the result would be a
-// zero-extended 32-bit integer. i.e. the mask is 0xffffffff or the result
-// is equivalent to this (SimplifyDemandedBits may have removed lower bits
-// from the mask that aren't necessary due to the right-shifting).
-bool RISCVDAGToDAGISel::MatchSRLIW(SDNode *N) const {
-  assert(N->getOpcode() == ISD::SRL);
-  assert(N->getOperand(0).getOpcode() == ISD::AND);
-  assert(isa<ConstantSDNode>(N->getOperand(1)));
-  assert(isa<ConstantSDNode>(N->getOperand(0).getOperand(1)));
-
-  // The IsRV64 predicate is checked after PatFrag predicates so we can get
-  // here even on RV32.
-  if (!Subtarget->is64Bit())
-    return false;
-
-  SDValue And = N->getOperand(0);
-  uint64_t ShAmt = N->getConstantOperandVal(1);
-  uint64_t Mask = And.getConstantOperandVal(1);
-  return (Mask | maskTrailingOnes<uint64_t>(ShAmt)) == 0xffffffff;
-}
-
 // Check that it is a SLLIUW (Shift Logical Left Immediate Unsigned i32
 // on RV64).
 // SLLIUW is the same as SLLI except for the fact that it clears the bits
diff --git a/lib/Target/RISCV/RISCVISelDAGToDAG.h b/lib/Target/RISCV/RISCVISelDAGToDAG.h
index 4fa6f54..e83e62d 100644
--- a/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -57,7 +57,6 @@
   bool selectSExti32(SDValue N, SDValue &Val);
   bool selectZExti32(SDValue N, SDValue &Val);
 
-  bool MatchSRLIW(SDNode *N) const;
   bool MatchSLLIUW(SDNode *N) const;
 
   bool selectVLOp(SDValue N, SDValue &VL);
diff --git a/lib/Target/RISCV/RISCVISelLowering.cpp b/lib/Target/RISCV/RISCVISelLowering.cpp
index c14a6cb..23966e6 100644
--- a/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -4954,16 +4954,36 @@
   // Clear all non-demanded bits initially.
   APInt ShrunkMask = Mask & DemandedBits;
 
+  // Try to make a smaller immediate by setting undemanded bits.
+
+  APInt ExpandedMask = Mask | ~DemandedBits;
+
+  auto IsLegalMask = [ShrunkMask, ExpandedMask](const APInt &Mask) -> bool {
+    return ShrunkMask.isSubsetOf(Mask) && Mask.isSubsetOf(ExpandedMask);
+  };
+  auto UseMask = [Mask, Op, VT, &TLO](const APInt &NewMask) -> bool {
+    if (NewMask == Mask)
+      return true;
+    SDLoc DL(Op);
+    SDValue NewC = TLO.DAG.getConstant(NewMask, DL, VT);
+    SDValue NewOp = TLO.DAG.getNode(ISD::AND, DL, VT, Op.getOperand(0), NewC);
+    return TLO.CombineTo(Op, NewOp);
+  };
+
   // If the shrunk mask fits in sign extended 12 bits, let the target
   // independent code apply it.
   if (ShrunkMask.isSignedIntN(12))
     return false;
 
-  // Try to make a smaller immediate by setting undemanded bits.
+  // Try to preserve (and X, 0xffffffff), the (zext_inreg X, i32) pattern.
+  if (VT == MVT::i64) {
+    APInt NewMask = APInt(64, 0xffffffff);
+    if (IsLegalMask(NewMask))
+      return UseMask(NewMask);
+  }
 
-  // We need to be able to make a negative number through a combination of mask
-  // and undemanded bits.
-  APInt ExpandedMask = Mask | ~DemandedBits;
+  // For the remaining optimizations, we need to be able to make a negative
+  // number through a combination of mask and undemanded bits.
   if (!ExpandedMask.isNegative())
     return false;
 
@@ -4981,18 +5001,8 @@
     return false;
 
   // Sanity check that our new mask is a subset of the demanded mask.
-  assert(NewMask.isSubsetOf(ExpandedMask));
-
-  // If we aren't changing the mask, just return true to keep it and prevent
-  // the caller from optimizing.
-  if (NewMask == Mask)
-    return true;
-
-  // Replace the constant with the new mask.
-  SDLoc DL(Op);
-  SDValue NewC = TLO.DAG.getConstant(NewMask, DL, VT);
-  SDValue NewOp = TLO.DAG.getNode(ISD::AND, DL, VT, Op.getOperand(0), NewC);
-  return TLO.CombineTo(Op, NewOp);
+  assert(IsLegalMask(NewMask));
+  return UseMask(NewMask);
 }
 
 void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
diff --git a/lib/Target/RISCV/RISCVInstrInfo.td b/lib/Target/RISCV/RISCVInstrInfo.td
index 2f73586b..b07204c 100644
--- a/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/lib/Target/RISCV/RISCVInstrInfo.td
@@ -847,11 +847,6 @@
 }]>;
 def zexti32 : ComplexPattern<i64, 1, "selectZExti32">;
 
-def SRLIWPat : PatFrag<(ops node:$A, node:$B),
-                       (srl (and node:$A, imm), node:$B), [{
-  return MatchSRLIW(N);
-}]>;
-
 // Check that it is a SLLIUW (Shift Logical Left Immediate Unsigned i32
 // on RV64). Also used to optimize the same sequence without SLLIUW.
 def SLLIUWPat : PatFrag<(ops node:$A, node:$B),
@@ -1164,7 +1159,7 @@
           (SUBW GPR:$rs1, GPR:$rs2)>;
 def : Pat<(sext_inreg (shl GPR:$rs1, uimm5:$shamt), i32),
           (SLLIW GPR:$rs1, uimm5:$shamt)>;
-def : Pat<(i64 (SRLIWPat GPR:$rs1, uimm5:$shamt)),
+def : Pat<(i64 (srl (and GPR:$rs1, 0xffffffff), uimm5:$shamt)),
           (SRLIW GPR:$rs1, uimm5:$shamt)>;
 def : Pat<(i64 (srl (shl GPR:$rs1, (i64 32)), uimm6gt32:$shamt)),
           (SRLIW GPR:$rs1, (ImmSub32 uimm6gt32:$shamt))>;
diff --git a/lib/Target/RISCV/RISCVInstrInfoB.td b/lib/Target/RISCV/RISCVInstrInfoB.td
index 1b28740..dcb2196 100644
--- a/lib/Target/RISCV/RISCVInstrInfoB.td
+++ b/lib/Target/RISCV/RISCVInstrInfoB.td
@@ -871,6 +871,6 @@
                            i32)),
           (PACKW GPR:$rs1, GPR:$rs2)>;
 def : Pat<(i64 (or (and (assertsexti32 GPR:$rs2), 0xFFFFFFFFFFFF0000),
-                   (SRLIWPat GPR:$rs1, (i64 16)))),
+                   (srl (and GPR:$rs1, 0xFFFFFFFF), (i64 16)))),
           (PACKUW GPR:$rs1, GPR:$rs2)>;
 } // Predicates = [HasStdExtZbp, IsRV64]
diff --git a/test/CodeGen/RISCV/alu32.ll b/test/CodeGen/RISCV/alu32.ll
index fa0c4bb..d9fb08b5 100644
--- a/test/CodeGen/RISCV/alu32.ll
+++ b/test/CodeGen/RISCV/alu32.ll
@@ -129,8 +129,8 @@
   ret i32 %1
 }
 
-; FIXME: This should use srliw on RV64, but SimplifyDemandedBits breaks the
-; (and X, 0xffffffff) that type legalization inserts.
+; This makes sure SimplifyDemandedBits doesn't prevent us from matching SRLIW
+; on RV64.
 define i32 @srli_demandedbits(i32 %0) {
 ; RV32I-LABEL: srli_demandedbits:
 ; RV32I:       # %bb.0:
@@ -140,11 +140,7 @@
 ;
 ; RV64I-LABEL: srli_demandedbits:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    addi a1, zero, 1
-; RV64I-NEXT:    slli a1, a1, 32
-; RV64I-NEXT:    addi a1, a1, -16
-; RV64I-NEXT:    and a0, a0, a1
-; RV64I-NEXT:    srli a0, a0, 3
+; RV64I-NEXT:    srliw a0, a0, 3
 ; RV64I-NEXT:    ori a0, a0, 1
 ; RV64I-NEXT:    ret
   %2 = lshr i32 %0, 3
diff --git a/test/CodeGen/RISCV/rv64zba.ll b/test/CodeGen/RISCV/rv64zba.ll
index 3174ecc..1f25ef3 100644
--- a/test/CodeGen/RISCV/rv64zba.ll
+++ b/test/CodeGen/RISCV/rv64zba.ll
@@ -126,34 +126,26 @@
   ret i64 %and
 }
 
-; FIXME: This can use zext.w, but we need targetShrinkDemandedConstant to
-; to adjust the immediate.
+; This makes sure targetShrinkDemandedConstant changes the and immediate to
+; allow zext.w or slli+srli.
 define i64 @zextw_demandedbits_i64(i64 %0) {
 ; RV64I-LABEL: zextw_demandedbits_i64:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    addi a1, zero, 1
-; RV64I-NEXT:    slli a1, a1, 32
-; RV64I-NEXT:    addi a1, a1, -2
-; RV64I-NEXT:    and a0, a0, a1
 ; RV64I-NEXT:    ori a0, a0, 1
+; RV64I-NEXT:    slli a0, a0, 32
+; RV64I-NEXT:    srli a0, a0, 32
 ; RV64I-NEXT:    ret
 ;
 ; RV64IB-LABEL: zextw_demandedbits_i64:
 ; RV64IB:       # %bb.0:
-; RV64IB-NEXT:    addi a1, zero, 1
-; RV64IB-NEXT:    slli a1, a1, 32
-; RV64IB-NEXT:    addi a1, a1, -2
-; RV64IB-NEXT:    and a0, a0, a1
 ; RV64IB-NEXT:    ori a0, a0, 1
+; RV64IB-NEXT:    zext.w a0, a0
 ; RV64IB-NEXT:    ret
 ;
 ; RV64IBA-LABEL: zextw_demandedbits_i64:
 ; RV64IBA:       # %bb.0:
-; RV64IBA-NEXT:    addi a1, zero, 1
-; RV64IBA-NEXT:    slli a1, a1, 32
-; RV64IBA-NEXT:    addi a1, a1, -2
-; RV64IBA-NEXT:    and a0, a0, a1
 ; RV64IBA-NEXT:    ori a0, a0, 1
+; RV64IBA-NEXT:    zext.w a0, a0
 ; RV64IBA-NEXT:    ret
   %2 = and i64 %0, 4294967294
   %3 = or i64 %2, 1