[llvm][RISCV] Support rounding mulh for P extension codegen (#171593)

In p extension spec, rounding is performed by adding 1 << (elt_bits - 1)
to its result.

Stack on: #171581
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 03f1cee..017bb8d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15339,14 +15339,18 @@
   }
   case RISCVISD::PASUB:
   case RISCVISD::PASUBU:
-  case RISCVISD::PMULHSU: {
+  case RISCVISD::PMULHSU:
+  case RISCVISD::PMULHR:
+  case RISCVISD::PMULHRU:
+  case RISCVISD::PMULHRSU: {
     MVT VT = N->getSimpleValueType(0);
     SDValue Op0 = N->getOperand(0);
     SDValue Op1 = N->getOperand(1);
     unsigned Opcode = N->getOpcode();
-    // PMULHSU doesn't support i8 variants
-    assert(VT == MVT::v2i16 ||
-           (Opcode != RISCVISD::PMULHSU && VT == MVT::v4i8));
+    // PMULH* variants don't support i8
+    bool IsMulH = Opcode == RISCVISD::PMULHSU || Opcode == RISCVISD::PMULHR ||
+                  Opcode == RISCVISD::PMULHRU || Opcode == RISCVISD::PMULHRSU;
+    assert(VT == MVT::v2i16 || (!IsMulH && VT == MVT::v4i8));
     MVT NewVT = MVT::v4i16;
     if (VT == MVT::v4i8)
       NewVT = MVT::v8i8;
@@ -16465,6 +16469,7 @@
 // Handle P extension truncate patterns:
 // PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
 // PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
+// PMULHR*: (trunc (srl (add (mul (sext a), (zext b)), round_const), EltBits))
 static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
                                    const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
@@ -16494,6 +16499,24 @@
 
   SDValue Op = N0.getOperand(0);
   unsigned ShAmtVal = C->getZExtValue();
+  unsigned EltBits = VecVT.getScalarSizeInBits();
+
+  // Check for rounding pattern: (add (mul ...), round_const)
+  bool IsRounding = false;
+  if (Op.getOpcode() == ISD::ADD && (EltBits == 16 || EltBits == 32)) {
+    SDValue AddRHS = Op.getOperand(1);
+    if (auto *RndBV = dyn_cast<BuildVectorSDNode>(AddRHS.getNode())) {
+      if (auto *RndC =
+              dyn_cast_or_null<ConstantSDNode>(RndBV->getSplatValue())) {
+        uint64_t ExpectedRnd = 1ULL << (EltBits - 1);
+        if (RndC->getZExtValue() == ExpectedRnd &&
+            Op.getOperand(0).getOpcode() == ISD::MUL) {
+          Op = Op.getOperand(0);
+          IsRounding = true;
+        }
+      }
+    }
+  }
 
   SDValue LHS = Op.getOperand(0);
   SDValue RHS = Op.getOperand(1);
@@ -16528,17 +16551,31 @@
       return SDValue();
     break;
   case ISD::MUL:
-    // PMULHSU: shift amount must be element size, only for i16/i32
-    unsigned EltBits = VecVT.getScalarSizeInBits();
+    // PMULH*/PMULHR*: shift amount must be element size, only for i16/i32
     if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
       return SDValue();
-    if ((LHSIsSExt && RHSIsZExt) || (LHSIsZExt && RHSIsSExt)) {
-      Opc = RISCVISD::PMULHSU;
-      // commuted case
-      if (LHSIsZExt && RHSIsSExt)
-        std::swap(A, B);
-    } else
-      return SDValue();
+    if (IsRounding) {
+      if (LHSIsSExt && RHSIsSExt) {
+        Opc = RISCVISD::PMULHR;
+      } else if (LHSIsZExt && RHSIsZExt) {
+        Opc = RISCVISD::PMULHRU;
+      } else if ((LHSIsSExt && RHSIsZExt) || (LHSIsZExt && RHSIsSExt)) {
+        Opc = RISCVISD::PMULHRSU;
+        // commuted case
+        if (LHSIsZExt && RHSIsSExt)
+          std::swap(A, B);
+      } else {
+        return SDValue();
+      }
+    } else {
+      if ((LHSIsSExt && RHSIsZExt) || (LHSIsZExt && RHSIsSExt)) {
+        Opc = RISCVISD::PMULHSU;
+        // commuted case
+        if (LHSIsZExt && RHSIsSExt)
+          std::swap(A, B);
+      } else
+        return SDValue();
+    }
     break;
   }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index d2b122d..f2aeacd 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1471,6 +1471,9 @@
 def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPBinOp>;
 def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPBinOp>;
 def riscv_pmulhsu : RVSDNode<"PMULHSU", SDT_RISCVPBinOp>;
+def riscv_pmulhr : RVSDNode<"PMULHR", SDT_RISCVPBinOp>;
+def riscv_pmulhru : RVSDNode<"PMULHRU", SDT_RISCVPBinOp>;
+def riscv_pmulhrsu : RVSDNode<"PMULHRSU", SDT_RISCVPBinOp>;
 
 let Predicates = [HasStdExtP] in {
   def : PatGpr<abs, ABS>;
@@ -1538,6 +1541,11 @@
   def: Pat<(XLenVecI16VT (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_H GPR:$rs1, GPR:$rs2)>;
   def: Pat<(XLenVecI16VT (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_H GPR:$rs1, GPR:$rs2)>;
 
+  // 16-bit multiply high rounding patterns
+  def: Pat<(XLenVecI16VT (riscv_pmulhr GPR:$rs1, GPR:$rs2)), (PMULHR_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (riscv_pmulhru GPR:$rs1, GPR:$rs2)), (PMULHRU_H GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(XLenVecI16VT (riscv_pmulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_H GPR:$rs1, GPR:$rs2)>;
+
   // 8-bit logical shift left/right patterns
   def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
            (PSLLI_B GPR:$rs1, uimm3:$shamt)>;
@@ -1674,6 +1682,11 @@
   def: Pat<(v2i32 (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_W GPR:$rs1, GPR:$rs2)>;
   def: Pat<(v2i32 (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_W GPR:$rs1, GPR:$rs2)>;
 
+  // 32-bit multiply high rounding patterns
+  def: Pat<(v2i32 (riscv_pmulhr GPR:$rs1, GPR:$rs2)), (PMULHR_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (riscv_pmulhru GPR:$rs1, GPR:$rs2)), (PMULHRU_W GPR:$rs1, GPR:$rs2)>;
+  def: Pat<(v2i32 (riscv_pmulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_W GPR:$rs1, GPR:$rs2)>;
+
   // 8/16/32-bit multiply low patterns
   def: Pat<(v8i8 (mul GPR:$rs1, GPR:$rs2)),
            (PPAIRE_B (PMUL_H_B00 GPR:$rs1, GPR:$rs2), (PMUL_H_B11 GPR:$rs1, GPR:$rs2))>;
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index ed467ee..ba9aa18 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -1235,6 +1235,89 @@
   ret void
 }
 
+; Test packed multiply high rounding signed for v2i16
+define void @test_pmulhr_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhr.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = sext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+  %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding unsigned for v2i16
+define void @test_pmulhru_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhru.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = zext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+  %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding signed-unsigned for v2i16
+define void @test_pmulhrsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhrsu.h a1, a1, a2
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = sext <2 x i16> %a to <2 x i32>
+  %b_ext = zext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+  %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhrsu_h_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h_commuted:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lw a1, 0(a1)
+; CHECK-NEXT:    lw a2, 0(a2)
+; CHECK-NEXT:    pmulhrsu.h a1, a2, a1
+; CHECK-NEXT:    sw a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i16>, ptr %a_ptr
+  %b = load <2 x i16>, ptr %b_ptr
+  %a_ext = zext <2 x i16> %a to <2 x i32>
+  %b_ext = sext <2 x i16> %b to <2 x i32>
+  %mul = mul <2 x i32> %a_ext, %b_ext
+  %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+  %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+  %res = trunc <2 x i32> %shift to <2 x i16>
+  store <2 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
 ; Test packed multiply low for v2i16
 define void @test_pmul_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
 ; CHECK-RV32-LABEL: test_pmul_h:
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index 1838739..ba04c95 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -1328,6 +1328,169 @@
   ret void
 }
 
+; Test packed multiply high rounding signed
+define void @test_pmulhr_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhr.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = sext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+  %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhr_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhr.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = sext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+  %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding unsigned
+define void @test_pmulhru_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhru.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = zext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+  %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhru_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhru.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = zext <2 x i32> %a to <2 x i64>
+  %b_ext = zext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+  %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+; Test packed multiply high rounding signed-unsigned
+define void @test_pmulhrsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhrsu.h a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = sext <4 x i16> %a to <4 x i32>
+  %b_ext = zext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+  %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhrsu_h_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h_commuted:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhrsu.h a1, a2, a1
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <4 x i16>, ptr %a_ptr
+  %b = load <4 x i16>, ptr %b_ptr
+  %a_ext = zext <4 x i16> %a to <4 x i32>
+  %b_ext = sext <4 x i16> %b to <4 x i32>
+  %mul = mul <4 x i32> %a_ext, %b_ext
+  %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+  %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+  %res = trunc <4 x i32> %shift to <4 x i16>
+  store <4 x i16> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhrsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_w:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhrsu.w a1, a1, a2
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = sext <2 x i32> %a to <2 x i64>
+  %b_ext = zext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+  %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
+define void @test_pmulhrsu_w_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_w_commuted:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ld a1, 0(a1)
+; CHECK-NEXT:    ld a2, 0(a2)
+; CHECK-NEXT:    pmulhrsu.w a1, a2, a1
+; CHECK-NEXT:    sd a1, 0(a0)
+; CHECK-NEXT:    ret
+  %a = load <2 x i32>, ptr %a_ptr
+  %b = load <2 x i32>, ptr %b_ptr
+  %a_ext = zext <2 x i32> %a to <2 x i64>
+  %b_ext = sext <2 x i32> %b to <2 x i64>
+  %mul = mul <2 x i64> %a_ext, %b_ext
+  %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+  %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+  %res = trunc <2 x i64> %shift to <2 x i32>
+  store <2 x i32> %res, ptr %ret_ptr
+  ret void
+}
+
 ; Test packed multiply low for v4i16
 define void @test_pmul_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
 ; CHECK-LABEL: test_pmul_h: