[llvm][RISCV] Support rounding mulh for P extension codegen (#171593)
In p extension spec, rounding is performed by adding 1 << (elt_bits - 1)
to its result.
Stack on: #171581
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 03f1cee..017bb8d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -15339,14 +15339,18 @@
}
case RISCVISD::PASUB:
case RISCVISD::PASUBU:
- case RISCVISD::PMULHSU: {
+ case RISCVISD::PMULHSU:
+ case RISCVISD::PMULHR:
+ case RISCVISD::PMULHRU:
+ case RISCVISD::PMULHRSU: {
MVT VT = N->getSimpleValueType(0);
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
unsigned Opcode = N->getOpcode();
- // PMULHSU doesn't support i8 variants
- assert(VT == MVT::v2i16 ||
- (Opcode != RISCVISD::PMULHSU && VT == MVT::v4i8));
+ // PMULH* variants don't support i8
+ bool IsMulH = Opcode == RISCVISD::PMULHSU || Opcode == RISCVISD::PMULHR ||
+ Opcode == RISCVISD::PMULHRU || Opcode == RISCVISD::PMULHRSU;
+ assert(VT == MVT::v2i16 || (!IsMulH && VT == MVT::v4i8));
MVT NewVT = MVT::v4i16;
if (VT == MVT::v4i8)
NewVT = MVT::v8i8;
@@ -16465,6 +16469,7 @@
// Handle P extension truncate patterns:
// PASUB/PASUBU: (trunc (srl (sub ([s|z]ext a), ([s|z]ext b)), 1))
// PMULHSU: (trunc (srl (mul (sext a), (zext b)), EltBits))
+// PMULHR*: (trunc (srl (add (mul (sext a), (zext b)), round_const), EltBits))
static SDValue combinePExtTruncate(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
SDValue N0 = N->getOperand(0);
@@ -16494,6 +16499,24 @@
SDValue Op = N0.getOperand(0);
unsigned ShAmtVal = C->getZExtValue();
+ unsigned EltBits = VecVT.getScalarSizeInBits();
+
+ // Check for rounding pattern: (add (mul ...), round_const)
+ bool IsRounding = false;
+ if (Op.getOpcode() == ISD::ADD && (EltBits == 16 || EltBits == 32)) {
+ SDValue AddRHS = Op.getOperand(1);
+ if (auto *RndBV = dyn_cast<BuildVectorSDNode>(AddRHS.getNode())) {
+ if (auto *RndC =
+ dyn_cast_or_null<ConstantSDNode>(RndBV->getSplatValue())) {
+ uint64_t ExpectedRnd = 1ULL << (EltBits - 1);
+ if (RndC->getZExtValue() == ExpectedRnd &&
+ Op.getOperand(0).getOpcode() == ISD::MUL) {
+ Op = Op.getOperand(0);
+ IsRounding = true;
+ }
+ }
+ }
+ }
SDValue LHS = Op.getOperand(0);
SDValue RHS = Op.getOperand(1);
@@ -16528,17 +16551,31 @@
return SDValue();
break;
case ISD::MUL:
- // PMULHSU: shift amount must be element size, only for i16/i32
- unsigned EltBits = VecVT.getScalarSizeInBits();
+ // PMULH*/PMULHR*: shift amount must be element size, only for i16/i32
if (ShAmtVal != EltBits || (EltBits != 16 && EltBits != 32))
return SDValue();
- if ((LHSIsSExt && RHSIsZExt) || (LHSIsZExt && RHSIsSExt)) {
- Opc = RISCVISD::PMULHSU;
- // commuted case
- if (LHSIsZExt && RHSIsSExt)
- std::swap(A, B);
- } else
- return SDValue();
+ if (IsRounding) {
+ if (LHSIsSExt && RHSIsSExt) {
+ Opc = RISCVISD::PMULHR;
+ } else if (LHSIsZExt && RHSIsZExt) {
+ Opc = RISCVISD::PMULHRU;
+ } else if ((LHSIsSExt && RHSIsZExt) || (LHSIsZExt && RHSIsSExt)) {
+ Opc = RISCVISD::PMULHRSU;
+ // commuted case
+ if (LHSIsZExt && RHSIsSExt)
+ std::swap(A, B);
+ } else {
+ return SDValue();
+ }
+ } else {
+ if ((LHSIsSExt && RHSIsZExt) || (LHSIsZExt && RHSIsSExt)) {
+ Opc = RISCVISD::PMULHSU;
+ // commuted case
+ if (LHSIsZExt && RHSIsSExt)
+ std::swap(A, B);
+ } else
+ return SDValue();
+ }
break;
}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
index d2b122d..f2aeacd 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoP.td
@@ -1471,6 +1471,9 @@
def riscv_pasub : RVSDNode<"PASUB", SDT_RISCVPBinOp>;
def riscv_pasubu : RVSDNode<"PASUBU", SDT_RISCVPBinOp>;
def riscv_pmulhsu : RVSDNode<"PMULHSU", SDT_RISCVPBinOp>;
+def riscv_pmulhr : RVSDNode<"PMULHR", SDT_RISCVPBinOp>;
+def riscv_pmulhru : RVSDNode<"PMULHRU", SDT_RISCVPBinOp>;
+def riscv_pmulhrsu : RVSDNode<"PMULHRSU", SDT_RISCVPBinOp>;
let Predicates = [HasStdExtP] in {
def : PatGpr<abs, ABS>;
@@ -1538,6 +1541,11 @@
def: Pat<(XLenVecI16VT (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_H GPR:$rs1, GPR:$rs2)>;
def: Pat<(XLenVecI16VT (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_H GPR:$rs1, GPR:$rs2)>;
+ // 16-bit multiply high rounding patterns
+ def: Pat<(XLenVecI16VT (riscv_pmulhr GPR:$rs1, GPR:$rs2)), (PMULHR_H GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(XLenVecI16VT (riscv_pmulhru GPR:$rs1, GPR:$rs2)), (PMULHRU_H GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(XLenVecI16VT (riscv_pmulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_H GPR:$rs1, GPR:$rs2)>;
+
// 8-bit logical shift left/right patterns
def: Pat<(XLenVecI8VT (shl GPR:$rs1, (XLenVecI8VT (splat_vector uimm3:$shamt)))),
(PSLLI_B GPR:$rs1, uimm3:$shamt)>;
@@ -1674,6 +1682,11 @@
def: Pat<(v2i32 (mulhu GPR:$rs1, GPR:$rs2)), (PMULHU_W GPR:$rs1, GPR:$rs2)>;
def: Pat<(v2i32 (riscv_pmulhsu GPR:$rs1, GPR:$rs2)), (PMULHSU_W GPR:$rs1, GPR:$rs2)>;
+ // 32-bit multiply high rounding patterns
+ def: Pat<(v2i32 (riscv_pmulhr GPR:$rs1, GPR:$rs2)), (PMULHR_W GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i32 (riscv_pmulhru GPR:$rs1, GPR:$rs2)), (PMULHRU_W GPR:$rs1, GPR:$rs2)>;
+ def: Pat<(v2i32 (riscv_pmulhrsu GPR:$rs1, GPR:$rs2)), (PMULHRSU_W GPR:$rs1, GPR:$rs2)>;
+
// 8/16/32-bit multiply low patterns
def: Pat<(v8i8 (mul GPR:$rs1, GPR:$rs2)),
(PPAIRE_B (PMUL_H_B00 GPR:$rs1, GPR:$rs2), (PMUL_H_B11 GPR:$rs1, GPR:$rs2))>;
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
index ed467ee..ba9aa18 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv32.ll
@@ -1235,6 +1235,89 @@
ret void
}
+; Test packed multiply high rounding signed for v2i16
+define void @test_pmulhr_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhr.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = sext <2 x i16> %a to <2 x i32>
+ %b_ext = sext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+ %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding unsigned for v2i16
+define void @test_pmulhru_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhru.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = zext <2 x i16> %a to <2 x i32>
+ %b_ext = zext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+ %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding signed-unsigned for v2i16
+define void @test_pmulhrsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhrsu.h a1, a1, a2
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = sext <2 x i16> %a to <2 x i32>
+ %b_ext = zext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+ %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulhrsu_h_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h_commuted:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lw a1, 0(a1)
+; CHECK-NEXT: lw a2, 0(a2)
+; CHECK-NEXT: pmulhrsu.h a1, a2, a1
+; CHECK-NEXT: sw a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i16>, ptr %a_ptr
+ %b = load <2 x i16>, ptr %b_ptr
+ %a_ext = zext <2 x i16> %a to <2 x i32>
+ %b_ext = sext <2 x i16> %b to <2 x i32>
+ %mul = mul <2 x i32> %a_ext, %b_ext
+ %add = add <2 x i32> %mul, <i32 32768, i32 32768>
+ %shift = lshr <2 x i32> %add, <i32 16, i32 16>
+ %res = trunc <2 x i32> %shift to <2 x i16>
+ store <2 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
; Test packed multiply low for v2i16
define void @test_pmul_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-RV32-LABEL: test_pmul_h:
diff --git a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
index 1838739..ba04c95 100644
--- a/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
+++ b/llvm/test/CodeGen/RISCV/rvp-ext-rv64.ll
@@ -1328,6 +1328,169 @@
ret void
}
+; Test packed multiply high rounding signed
+define void @test_pmulhr_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhr.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = sext <4 x i16> %a to <4 x i32>
+ %b_ext = sext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+ %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ store <4 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulhr_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhr_w:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhr.w a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %b_ptr
+ %a_ext = sext <2 x i32> %a to <2 x i64>
+ %b_ext = sext <2 x i32> %b to <2 x i64>
+ %mul = mul <2 x i64> %a_ext, %b_ext
+ %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+ %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+ %res = trunc <2 x i64> %shift to <2 x i32>
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding unsigned
+define void @test_pmulhru_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhru.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = zext <4 x i16> %a to <4 x i32>
+ %b_ext = zext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+ %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ store <4 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulhru_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhru_w:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhru.w a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %b_ptr
+ %a_ext = zext <2 x i32> %a to <2 x i64>
+ %b_ext = zext <2 x i32> %b to <2 x i64>
+ %mul = mul <2 x i64> %a_ext, %b_ext
+ %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+ %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+ %res = trunc <2 x i64> %shift to <2 x i32>
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
+; Test packed multiply high rounding signed-unsigned
+define void @test_pmulhrsu_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhrsu.h a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = sext <4 x i16> %a to <4 x i32>
+ %b_ext = zext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+ %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ store <4 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulhrsu_h_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_h_commuted:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhrsu.h a1, a2, a1
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <4 x i16>, ptr %a_ptr
+ %b = load <4 x i16>, ptr %b_ptr
+ %a_ext = zext <4 x i16> %a to <4 x i32>
+ %b_ext = sext <4 x i16> %b to <4 x i32>
+ %mul = mul <4 x i32> %a_ext, %b_ext
+ %add = add <4 x i32> %mul, <i32 32768, i32 32768, i32 32768, i32 32768>
+ %shift = lshr <4 x i32> %add, <i32 16, i32 16, i32 16, i32 16>
+ %res = trunc <4 x i32> %shift to <4 x i16>
+ store <4 x i16> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulhrsu_w(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_w:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhrsu.w a1, a1, a2
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %b_ptr
+ %a_ext = sext <2 x i32> %a to <2 x i64>
+ %b_ext = zext <2 x i32> %b to <2 x i64>
+ %mul = mul <2 x i64> %a_ext, %b_ext
+ %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+ %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+ %res = trunc <2 x i64> %shift to <2 x i32>
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
+define void @test_pmulhrsu_w_commuted(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
+; CHECK-LABEL: test_pmulhrsu_w_commuted:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ld a1, 0(a1)
+; CHECK-NEXT: ld a2, 0(a2)
+; CHECK-NEXT: pmulhrsu.w a1, a2, a1
+; CHECK-NEXT: sd a1, 0(a0)
+; CHECK-NEXT: ret
+ %a = load <2 x i32>, ptr %a_ptr
+ %b = load <2 x i32>, ptr %b_ptr
+ %a_ext = zext <2 x i32> %a to <2 x i64>
+ %b_ext = sext <2 x i32> %b to <2 x i64>
+ %mul = mul <2 x i64> %a_ext, %b_ext
+ %add = add <2 x i64> %mul, <i64 2147483648, i64 2147483648>
+ %shift = lshr <2 x i64> %add, <i64 32, i64 32>
+ %res = trunc <2 x i64> %shift to <2 x i32>
+ store <2 x i32> %res, ptr %ret_ptr
+ ret void
+}
+
; Test packed multiply low for v4i16
define void @test_pmul_h(ptr %ret_ptr, ptr %a_ptr, ptr %b_ptr) {
; CHECK-LABEL: test_pmul_h: