[RISCV] Initial codegen support for zvqdotq extension (#137039)

This patch adds pattern matching for the basic usages of the dot product
instructions introduced by the experimental zvqdotq extension. It
specifically only handles the case where the pattern is feeding a i32
sum reduction as we need to reassociate the reduction tree to use these
instructions.

The vecreduce_add (sext) and vecreduce_add (zext) cases are included
mostly to exercise the VX matchers. For the generic matching, we fail to
match due to an order of combine issue which results in the bitcast
being separated from the splat.

I chose to do this lowering as an early combine so as to avoid having to
integrate the entire logic into the reduction lowering flow. In
particular, that would get a lot more complicated as we extend this to
handle add-trees feeding the reductions.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 86f8873..698b951 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -6971,7 +6971,7 @@
          Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(
-      RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
+      RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
       RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
       "adding target specific op should update this function");
   if (Opcode >= RISCVISD::ADD_VL && Opcode <= RISCVISD::VFMAX_VL)
@@ -6995,7 +6995,7 @@
          Opcode <= RISCVISD::LAST_STRICTFP_OPCODE &&
          "not a RISC-V target specific op");
   static_assert(
-      RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 134 &&
+      RISCVISD::LAST_VL_VECTOR_OP - RISCVISD::FIRST_VL_VECTOR_OP == 139 &&
       RISCVISD::LAST_STRICTFP_OPCODE - RISCVISD::FIRST_STRICTFP_OPCODE == 21 &&
       "adding target specific op should update this function");
   if (Opcode >= RISCVISD::TRUNCATE_VECTOR_VL && Opcode <= RISCVISD::SETCC_VL)
@@ -18101,6 +18101,118 @@
                      DAG.getBuildVector(VT, DL, RHSOps));
 }
 
+static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
+                          const SDLoc &DL, SelectionDAG &DAG,
+                          const RISCVSubtarget &Subtarget) {
+  assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
+         RISCVISD::VQDOTSU_VL == Opc);
+  MVT VT = Op0.getSimpleValueType();
+  assert(VT == Op1.getSimpleValueType() &&
+         VT.getVectorElementType() == MVT::i32);
+
+  assert(VT.isFixedLengthVector());
+  MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+  SDValue Passthru = convertToScalableVector(
+      ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
+  Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
+  Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
+
+  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
+  const unsigned Policy = RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC;
+  SDValue PolicyOp = DAG.getTargetConstant(Policy, DL, Subtarget.getXLenVT());
+  SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
+                                   {Op0, Op1, Passthru, Mask, VL, PolicyOp});
+  return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
+}
+
+static MVT getQDOTXResultType(MVT OpVT) {
+  ElementCount OpEC = OpVT.getVectorElementCount();
+  assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
+  return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
+}
+
+static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
+                                         SelectionDAG &DAG,
+                                         const RISCVSubtarget &Subtarget,
+                                         const RISCVTargetLowering &TLI) {
+  // Note: We intentionally do not check the legality of the reduction type.
+  // We want to handle the m4/m8 *src*  types, and thus need to let illegal
+  // intermediate types flow through here.
+  if (InVec.getValueType().getVectorElementType() != MVT::i32 ||
+      !InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
+    return SDValue();
+
+  // reduce (zext a) <--> reduce (mul zext a. zext 1)
+  // reduce (sext a) <--> reduce (mul sext a. sext 1)
+  if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
+      InVec.getOpcode() == ISD::SIGN_EXTEND) {
+    SDValue A = InVec.getOperand(0);
+    if (A.getValueType().getVectorElementType() != MVT::i8 ||
+        !TLI.isTypeLegal(A.getValueType()))
+      return SDValue();
+
+    MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
+    A = DAG.getBitcast(ResVT, A);
+    SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
+
+    bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
+    unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
+    return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+  }
+
+  // mul (sext, sext) -> vqdot
+  // mul (zext, zext) -> vqdotu
+  // mul (sext, zext) -> vqdotsu
+  // mul (zext, sext) -> vqdotsu (swapped)
+  // TODO: Improve .vx handling - we end up with a sub-vector insert
+  // which confuses the splat pattern matching.  Also, match vqdotus.vx
+  if (InVec.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  SDValue A = InVec.getOperand(0);
+  SDValue B = InVec.getOperand(1);
+  unsigned Opc = 0;
+  if (A.getOpcode() == B.getOpcode()) {
+    if (A.getOpcode() == ISD::SIGN_EXTEND)
+      Opc = RISCVISD::VQDOT_VL;
+    else if (A.getOpcode() == ISD::ZERO_EXTEND)
+      Opc = RISCVISD::VQDOTU_VL;
+    else
+      return SDValue();
+  } else {
+    if (B.getOpcode() != ISD::ZERO_EXTEND)
+      std::swap(A, B);
+    if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
+      return SDValue();
+    Opc = RISCVISD::VQDOTSU_VL;
+  }
+  assert(Opc);
+
+  if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
+      A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
+      !TLI.isTypeLegal(A.getValueType()))
+    return SDValue();
+
+  MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
+  A = DAG.getBitcast(ResVT, A.getOperand(0));
+  B = DAG.getBitcast(ResVT, B.getOperand(0));
+  return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
+}
+
+static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
+                                       const RISCVSubtarget &Subtarget,
+                                       const RISCVTargetLowering &TLI) {
+  if (!Subtarget.hasStdExtZvqdotq())
+    return SDValue();
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue InVec = N->getOperand(0);
+  if (SDValue V = foldReduceOperandViaVQDOT(InVec, DL, DAG, Subtarget, TLI))
+    return DAG.getNode(ISD::VECREDUCE_ADD, DL, VT, V);
+  return SDValue();
+}
+
 static SDValue performINSERT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG,
                                                const RISCVSubtarget &Subtarget,
                                                const RISCVTargetLowering &TLI) {
@@ -19878,8 +19990,11 @@
 
     return SDValue();
   }
-  case ISD::CTPOP:
   case ISD::VECREDUCE_ADD:
+    if (SDValue V = performVECREDUCECombine(N, DAG, Subtarget, *this))
+      return V;
+    [[fallthrough]];
+  case ISD::CTPOP:
     if (SDValue V = combineToVCPOP(N, DAG, Subtarget))
       return V;
     break;
@@ -22401,6 +22516,9 @@
   NODE_NAME_CASE(RI_VUNZIP2A_VL)
   NODE_NAME_CASE(RI_VUNZIP2B_VL)
   NODE_NAME_CASE(RI_VEXTRACT)
+  NODE_NAME_CASE(VQDOT_VL)
+  NODE_NAME_CASE(VQDOTU_VL)
+  NODE_NAME_CASE(VQDOTSU_VL)
   NODE_NAME_CASE(READ_CSR)
   NODE_NAME_CASE(WRITE_CSR)
   NODE_NAME_CASE(SWAP_CSR)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index ba24a0c..3f1fce5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -416,7 +416,12 @@
   RI_VUNZIP2A_VL,
   RI_VUNZIP2B_VL,
 
-  LAST_VL_VECTOR_OP = RI_VUNZIP2B_VL,
+  // zvqdot instructions with additional passthru, mask and VL operands
+  VQDOT_VL,
+  VQDOTU_VL,
+  VQDOTSU_VL,
+
+  LAST_VL_VECTOR_OP = VQDOTSU_VL,
 
   // XRivosVisni
   // VEXTRACT matches the semantics of ri.vextract.x.v. The result is always
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
index 205fffd..6018958 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZvqdotq.td
@@ -26,3 +26,34 @@
   def VQDOTSU_VX : VALUVX<0b101010, OPMVX, "vqdotsu.vx">;
   def VQDOTUS_VX : VALUVX<0b101110, OPMVX, "vqdotus.vx">;
 } // Predicates = [HasStdExtZvqdotq]
+
+
+def riscv_vqdot_vl : SDNode<"RISCVISD::VQDOT_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vqdotu_vl : SDNode<"RISCVISD::VQDOTU_VL", SDT_RISCVIntBinOp_VL>;
+def riscv_vqdotsu_vl : SDNode<"RISCVISD::VQDOTSU_VL", SDT_RISCVIntBinOp_VL>;
+
+multiclass VPseudoVQDOT_VV_VX {
+  foreach m = MxSet<32>.m in {
+    defm "" : VPseudoBinaryV_VV<m>,
+            SchedBinary<"WriteVIALUV", "ReadVIALUV", "ReadVIALUV", m.MX,
+                        forcePassthruRead=true>;
+    defm "" : VPseudoBinaryV_VX<m>,
+            SchedBinary<"WriteVIALUX", "ReadVIALUV", "ReadVIALUX", m.MX,
+                        forcePassthruRead=true>;
+  }
+}
+
+// TODO: Add pseudo and patterns for vqdotus.vx
+// TODO: Add isCommutable for VQDOT and VQDOTU
+let Predicates = [HasStdExtZvqdotq], mayLoad = 0, mayStore = 0,
+    hasSideEffects = 0 in {
+  defm PseudoVQDOT : VPseudoVQDOT_VV_VX;
+  defm PseudoVQDOTU : VPseudoVQDOT_VV_VX;
+  defm PseudoVQDOTSU : VPseudoVQDOT_VV_VX;
+}
+
+defvar AllE32Vectors = [VI32MF2, VI32M1, VI32M2, VI32M4, VI32M8];
+defm : VPatBinaryVL_VV_VX<riscv_vqdot_vl, "PseudoVQDOT", AllE32Vectors>;
+defm : VPatBinaryVL_VV_VX<riscv_vqdotu_vl, "PseudoVQDOTU", AllE32Vectors>;
+defm : VPatBinaryVL_VV_VX<riscv_vqdotsu_vl, "PseudoVQDOTSU", AllE32Vectors>;
+
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
index 25192ea..e48bc9c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll
@@ -1,21 +1,31 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s
-; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,NODOT
+; RUN: llc -mtriple=riscv32 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT32
+; RUN: llc -mtriple=riscv64 -mattr=+v,+experimental-zvqdotq -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,DOT,DOT64
 
 define i32 @vqdot_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdot_vv:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v12, v8
-; CHECK-NEXT:    vsext.vf2 v14, v9
-; CHECK-NEXT:    vwmul.vv v8, v12, v14
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v12, zero
-; CHECK-NEXT:    vredsum.vs v8, v8, v12
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdot_vv:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v8
+; NODOT-NEXT:    vsext.vf2 v14, v9
+; NODOT-NEXT:    vwmul.vv v8, v12, v14
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v12, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v12
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdot_vv:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdot.vv v10, v8, v9
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v10, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.sext = sext <16 x i8> %b to <16 x i32>
@@ -63,17 +73,27 @@
 }
 
 define i32 @vqdotu_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotu_vv:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
-; CHECK-NEXT:    vwmulu.vv v10, v8, v9
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
-; CHECK-NEXT:    vwredsumu.vs v8, v10, v8
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotu_vv:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
+; NODOT-NEXT:    vwmulu.vv v10, v8, v9
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
+; NODOT-NEXT:    vwredsumu.vs v8, v10, v8
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotu_vv:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotu.vv v10, v8, v9
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v10, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.zext = zext <16 x i8> %a to <16 x i32>
   %b.zext = zext <16 x i8> %b to <16 x i32>
@@ -102,17 +122,27 @@
 }
 
 define i32 @vqdotsu_vv(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v12, v8
-; CHECK-NEXT:    vzext.vf2 v14, v9
-; CHECK-NEXT:    vwmulsu.vv v8, v12, v14
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v12, zero
-; CHECK-NEXT:    vredsum.vs v8, v8, v12
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotsu_vv:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v8
+; NODOT-NEXT:    vzext.vf2 v14, v9
+; NODOT-NEXT:    vwmulsu.vv v8, v12, v14
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v12, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v12
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotsu.vv v10, v8, v9
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v10, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.zext = zext <16 x i8> %b to <16 x i32>
@@ -122,17 +152,27 @@
 }
 
 define i32 @vqdotsu_vv_swapped(<16 x i8> %a, <16 x i8> %b) {
-; CHECK-LABEL: vqdotsu_vv_swapped:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
-; CHECK-NEXT:    vsext.vf2 v12, v8
-; CHECK-NEXT:    vzext.vf2 v14, v9
-; CHECK-NEXT:    vwmulsu.vv v8, v12, v14
-; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
-; CHECK-NEXT:    vmv.s.x v12, zero
-; CHECK-NEXT:    vredsum.vs v8, v8, v12
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: vqdotsu_vv_swapped:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e16, m2, ta, ma
+; NODOT-NEXT:    vsext.vf2 v12, v8
+; NODOT-NEXT:    vzext.vf2 v14, v9
+; NODOT-NEXT:    vwmulsu.vv v8, v12, v14
+; NODOT-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; NODOT-NEXT:    vmv.s.x v12, zero
+; NODOT-NEXT:    vredsum.vs v8, v8, v12
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT-LABEL: vqdotsu_vv_swapped:
+; DOT:       # %bb.0: # %entry
+; DOT-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT-NEXT:    vmv.v.i v10, 0
+; DOT-NEXT:    vqdotsu.vv v10, v8, v9
+; DOT-NEXT:    vmv.s.x v8, zero
+; DOT-NEXT:    vredsum.vs v8, v10, v8
+; DOT-NEXT:    vmv.x.s a0, v8
+; DOT-NEXT:    ret
 entry:
   %a.sext = sext <16 x i8> %a to <16 x i32>
   %b.zext = zext <16 x i8> %b to <16 x i32>
@@ -181,14 +221,38 @@
 }
 
 define i32 @reduce_of_sext(<16 x i8> %a) {
-; CHECK-LABEL: reduce_of_sext:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vsext.vf4 v12, v8
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v12, v8
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: reduce_of_sext:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT:    vsext.vf4 v12, v8
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v12, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT32-LABEL: reduce_of_sext:
+; DOT32:       # %bb.0: # %entry
+; DOT32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT32-NEXT:    vmv.v.i v9, 0
+; DOT32-NEXT:    lui a0, 4112
+; DOT32-NEXT:    addi a0, a0, 257
+; DOT32-NEXT:    vqdot.vx v9, v8, a0
+; DOT32-NEXT:    vmv.s.x v8, zero
+; DOT32-NEXT:    vredsum.vs v8, v9, v8
+; DOT32-NEXT:    vmv.x.s a0, v8
+; DOT32-NEXT:    ret
+;
+; DOT64-LABEL: reduce_of_sext:
+; DOT64:       # %bb.0: # %entry
+; DOT64-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT64-NEXT:    vmv.v.i v9, 0
+; DOT64-NEXT:    lui a0, 4112
+; DOT64-NEXT:    addiw a0, a0, 257
+; DOT64-NEXT:    vqdot.vx v9, v8, a0
+; DOT64-NEXT:    vmv.s.x v8, zero
+; DOT64-NEXT:    vredsum.vs v8, v9, v8
+; DOT64-NEXT:    vmv.x.s a0, v8
+; DOT64-NEXT:    ret
 entry:
   %a.ext = sext <16 x i8> %a to <16 x i32>
   %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)
@@ -196,14 +260,38 @@
 }
 
 define i32 @reduce_of_zext(<16 x i8> %a) {
-; CHECK-LABEL: reduce_of_zext:
-; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
-; CHECK-NEXT:    vzext.vf4 v12, v8
-; CHECK-NEXT:    vmv.s.x v8, zero
-; CHECK-NEXT:    vredsum.vs v8, v12, v8
-; CHECK-NEXT:    vmv.x.s a0, v8
-; CHECK-NEXT:    ret
+; NODOT-LABEL: reduce_of_zext:
+; NODOT:       # %bb.0: # %entry
+; NODOT-NEXT:    vsetivli zero, 16, e32, m4, ta, ma
+; NODOT-NEXT:    vzext.vf4 v12, v8
+; NODOT-NEXT:    vmv.s.x v8, zero
+; NODOT-NEXT:    vredsum.vs v8, v12, v8
+; NODOT-NEXT:    vmv.x.s a0, v8
+; NODOT-NEXT:    ret
+;
+; DOT32-LABEL: reduce_of_zext:
+; DOT32:       # %bb.0: # %entry
+; DOT32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT32-NEXT:    vmv.v.i v9, 0
+; DOT32-NEXT:    lui a0, 4112
+; DOT32-NEXT:    addi a0, a0, 257
+; DOT32-NEXT:    vqdotu.vx v9, v8, a0
+; DOT32-NEXT:    vmv.s.x v8, zero
+; DOT32-NEXT:    vredsum.vs v8, v9, v8
+; DOT32-NEXT:    vmv.x.s a0, v8
+; DOT32-NEXT:    ret
+;
+; DOT64-LABEL: reduce_of_zext:
+; DOT64:       # %bb.0: # %entry
+; DOT64-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
+; DOT64-NEXT:    vmv.v.i v9, 0
+; DOT64-NEXT:    lui a0, 4112
+; DOT64-NEXT:    addiw a0, a0, 257
+; DOT64-NEXT:    vqdotu.vx v9, v8, a0
+; DOT64-NEXT:    vmv.s.x v8, zero
+; DOT64-NEXT:    vredsum.vs v8, v9, v8
+; DOT64-NEXT:    vmv.x.s a0, v8
+; DOT64-NEXT:    ret
 entry:
   %a.ext = zext <16 x i8> %a to <16 x i32>
   %res = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %a.ext)