[RISCV] Relax another one use restriction in performSRACombine.

When folding (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg (add X, C1), i32), C),
it's possible that the add is used by multiple sras. We should
allow the combine as long as all of the sras will eventually be updated.

After transforming all of the sras, the shls will share a single
(sext_inreg (add X, C1), i32).
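
For illustration, with two users the relevant part of the DAG goes
from roughly (node names and constants are schematic, not an actual
dump):

  t1 = add (shl X, 32), C1
  t2 = sra t1, 32 - C
  t3 = sra t1, 32 - D

to

  t4 = sext_inreg (add X, C1), i32
  t2 = shl t4, C
  t3 = shl t4, D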

This pattern occurs when an sra by 32 is used as an index in multiple
GEPs with different scales. The shls from the GEPs will be combined
with the sra before we get a chance to match the sra pattern.
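
For example, with the i32 GEP (scale 4, so C = 2) the scaled index is
roughly:

  t1 = add (shl X, 32), C1
  t2 = sra t1, 32          ; exact, from the ashr exact in the IR
  t3 = shl t2, 2           ; index scaling from the GEP

and the generic fold for a shl of an exact sra rewrites t3 into
(sra t1, 30), i.e. the 32 - C form above with C = 2.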

GitOrigin-RevId: 12a1ca9c42c45cfb4777a42f73db5d33e87577e4
diff --git a/lib/Target/RISCV/RISCVISelLowering.cpp b/lib/Target/RISCV/RISCVISelLowering.cpp
index 227f2a1..fca42d7 100644
--- a/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -8874,8 +8874,6 @@
   // We might have an ADD or SUB between the SRA and SHL.
   bool IsAdd = N0.getOpcode() == ISD::ADD;
   if ((IsAdd || N0.getOpcode() == ISD::SUB)) {
-    if (!N0.hasOneUse())
-      return SDValue();
     // Other operand needs to be a constant we can modify.
     AddC = dyn_cast<ConstantSDNode>(N0.getOperand(IsAdd ? 1 : 0));
     if (!AddC)
@@ -8885,6 +8883,16 @@
     if (AddC->getAPIntValue().countTrailingZeros() < 32)
       return SDValue();
 
+    // All users should be a shift by constant less than or equal to 32. This
+    // ensures we'll do this optimization for each of them to produce an
+    // add/sub+sext_inreg they can all share.
+    for (SDNode *U : N0->uses()) {
+      if (U->getOpcode() != ISD::SRA ||
+          !isa<ConstantSDNode>(U->getOperand(1)) ||
+          cast<ConstantSDNode>(U->getOperand(1))->getZExtValue() > 32)
+        return SDValue();
+    }
+
     Shl = N0.getOperand(IsAdd ? 0 : 1);
   } else {
     // Not an ADD or SUB.
diff --git a/test/CodeGen/RISCV/rv64i-shift-sext.ll b/test/CodeGen/RISCV/rv64i-shift-sext.ll
index ad1df83..55620af 100644
--- a/test/CodeGen/RISCV/rv64i-shift-sext.ll
+++ b/test/CodeGen/RISCV/rv64i-shift-sext.ll
@@ -196,3 +196,27 @@
   %12 = add i8 %7, %11
   ret i8 %12
 }
+
+define signext i32 @test14(i8* %0, i32* %1, i64 %2) {
+; RV64I-LABEL: test14:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    li a3, 1
+; RV64I-NEXT:    subw a2, a3, a2
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    lbu a0, 0(a0)
+; RV64I-NEXT:    slli a2, a2, 2
+; RV64I-NEXT:    add a1, a1, a2
+; RV64I-NEXT:    lw a1, 0(a1)
+; RV64I-NEXT:    addw a0, a0, a1
+; RV64I-NEXT:    ret
+  %4 = mul i64 %2, -4294967296
+  %5 = add i64 %4, 4294967296 ; 1 << 32
+  %6 = ashr exact i64 %5, 32
+  %7 = getelementptr inbounds i8, i8* %0, i64 %6
+  %8 = load i8, i8* %7, align 4
+  %9 = zext i8 %8 to i32
+  %10 = getelementptr inbounds i32, i32* %1, i64 %6
+  %11 = load i32, i32* %10, align 4
+  %12 = add i32 %9, %11
+  ret i32 %12
+}