[MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla (#135634)

GitOrigin-RevId: e9c9c33fa4e26b7e18947dfefa960f68945d1899
diff --git a/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 3a990f8..7385bb7 100644
--- a/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -147,11 +147,9 @@
     "$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
 }
 
-def SdotOp : ArmSVE_Op<"sdot",
-               [Pure,
-               AllTypesMatch<["src1", "src2"]>,
-               AllTypesMatch<["acc", "dst"]>,
-             ]> {
+def SdotOp : ArmSVE_Op<"sdot", [Pure,
+                                AllTypesMatch<["src1", "src2"]>,
+                                AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Vector-vector dot product and accumulate op";
   let description = [{
     SDOT: Signed integer addition of dot product.
@@ -178,11 +176,9 @@
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def SmmlaOp : ArmSVE_Op<"smmla",
-                [Pure,
-                AllTypesMatch<["src1", "src2"]>,
-                AllTypesMatch<["acc", "dst"]>,
-              ]> {
+def SmmlaOp : ArmSVE_Op<"smmla", [Pure,
+                                  AllTypesMatch<["src1", "src2"]>,
+                                  AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Matrix-matrix multiply and accumulate op";
   let description = [{
     SMMLA: Signed integer matrix multiply-accumulate.
@@ -210,11 +206,9 @@
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def UdotOp : ArmSVE_Op<"udot",
-               [Pure,
-               AllTypesMatch<["src1", "src2"]>,
-               AllTypesMatch<["acc", "dst"]>,
-             ]> {
+def UdotOp : ArmSVE_Op<"udot", [Pure,
+                                AllTypesMatch<["src1", "src2"]>,
+                                AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Vector-vector dot product and accumulate op";
   let description = [{
     UDOT: Unsigned integer addition of dot product.
@@ -241,11 +235,9 @@
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
-def UmmlaOp : ArmSVE_Op<"ummla",
-                [Pure,
-                AllTypesMatch<["src1", "src2"]>,
-                AllTypesMatch<["acc", "dst"]>,
-              ]> {
+def UmmlaOp : ArmSVE_Op<"ummla", [Pure,
+                                  AllTypesMatch<["src1", "src2"]>,
+                                  AllTypesMatch<["acc", "dst"]>]> {
   let summary = "Matrix-matrix multiply and accumulate op";
   let description = [{
     UMMLA: Unsigned integer matrix multiply-accumulate.
@@ -273,14 +265,42 @@
     "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
 }
 
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+                                    AllTypesMatch<["src1", "src2"]>,
+                                    AllTypesMatch<["acc", "dst"]>]> {
+  let summary = "Matrix-matrix multiply and accumulate op";
+  let description = [{
+    USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+    The unsigned by signed integer matrix multiply-accumulate operation
+    multiplies the 2×8 matrix of unsigned 8-bit integer values held
+    the first source vector by the 8×2 matrix of signed 8-bit integer
+    values in the second source vector. The resulting 2×2 widened 32-bit
+    integer matrix product is then added to the 32-bit integer matrix
+    accumulator.
+
+    Source:
+    https://developer.arm.com/documentation/100987/0000
+  }];
+  // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+  let arguments = (ins
+          ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+          ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+  );
+  let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+  let assemblyFormat =
+    "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
 class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
       "expected corresponding svbool type widened to [16]xi1",
       lhsArg, rhsArg,
       "VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
 
 def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
-                            [Pure, SvboolTypeConstraint<"result", "source">]>
-{
+                                    [Pure,
+                                     SvboolTypeConstraint<"result", "source">]> {
   let summary = "Convert a svbool type to a SVE predicate type";
   let description = [{
     Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
@@ -313,8 +333,8 @@
 }
 
 def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
-                            [Pure, SvboolTypeConstraint<"source", "result">]>
-{
+                                  [Pure,
+                                   SvboolTypeConstraint<"source", "result">]> {
   let summary = "Convert a SVE predicate type to a svbool type";
   let description = [{
     Converts SVE predicate types (or vectors of predicate types, e.g.
@@ -356,10 +376,9 @@
   Scalable1DVectorOfLength<16, [I8]>],
   "an SVE vector with element size <= 64-bit">;
 
-def ZipX2Op  : ArmSVE_Op<"zip.x2", [
-  Pure,
-  AllTypesMatch<["sourceV1", "sourceV2", "resultV1", "resultV2"]>]
-> {
+def ZipX2Op : ArmSVE_Op<"zip.x2", [Pure,
+                                   AllTypesMatch<["sourceV1", "sourceV2",
+                                                  "resultV1", "resultV2"]>]> {
   let summary = "Multi-vector two-way zip op";
 
   let description = [{
@@ -400,12 +419,11 @@
   }];
 }
 
-def ZipX4Op  : ArmSVE_Op<"zip.x4", [
-  Pure,
-  AllTypesMatch<[
-    "sourceV1", "sourceV2", "sourceV3", "sourceV4",
-    "resultV1", "resultV2", "resultV3", "resultV4"]>]
-> {
+def ZipX4Op
+  : ArmSVE_Op<"zip.x4",
+              [Pure,
+               AllTypesMatch<["sourceV1", "sourceV2", "sourceV3", "sourceV4",
+                              "resultV1", "resultV2", "resultV3", "resultV4"]>]> {
   let summary = "Multi-vector four-way zip op";
 
   let description = [{
@@ -463,10 +481,7 @@
   }];
 }
 
-def PselOp : ArmSVE_Op<"psel", [
-  Pure,
-  AllTypesMatch<["p1", "result"]>,
-]> {
+def PselOp : ArmSVE_Op<"psel", [Pure, AllTypesMatch<["p1", "result"]>]> {
   let summary = "Predicate select";
 
   let description = [{
@@ -571,6 +586,10 @@
   ArmSVE_IntrBinaryOverloadedOp<"smmla">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
 
+def UsmmlaIntrOp :
+  ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
+  Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+
 def SdotIntrOp :
   ArmSVE_IntrBinaryOverloadedOp<"sdot">,
   Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
diff --git a/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 536373b..35f2a02 100644
--- a/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,7 @@
 using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
 using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
 using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
 using DupQLaneLowering =
     OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
 using ScalableMaskedAddIOpLowering =
@@ -206,6 +207,7 @@
                SmmlaOpLowering,
                UdotOpLowering,
                UmmlaOpLowering,
+               UsmmlaOpLowering,
                ZipX2OpLowering,
                ZipX4OpLowering,
                SdotOpLowering>(converter);
@@ -234,6 +236,7 @@
                     SmmlaIntrOp,
                     UdotIntrOp,
                     UmmlaIntrOp,
+                    UsmmlaIntrOp,
                     WhileLTIntrOp,
                     ZipX2IntrOp,
                     ZipX4IntrOp,
@@ -254,6 +257,7 @@
                       SmmlaOp,
                       UdotOp,
                       UmmlaOp,
+                      UsmmlaOp,
                       ZipX2Op,
                       ZipX4Op,
                       SdotOp>();
diff --git a/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 650b3e7..8c658db 100644
--- a/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -48,6 +48,18 @@
 
 // -----
 
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+                    %b: vector<[16]xi8>,
+                    %c: vector<[4]xi32>)
+    -> vector<[4]xi32> {
+  // CHECK: arm_sve.intr.usmmla
+  %0 = arm_sve.usmmla %c, %a, %b :
+               vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+// -----
+
 func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/test/Dialect/ArmSVE/roundtrip.mlir b/test/Dialect/ArmSVE/roundtrip.mlir
index 0f0c5a8..64e0cff 100644
--- a/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/test/Dialect/ArmSVE/roundtrip.mlir
@@ -44,6 +44,17 @@
 
 // -----
 
+func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
+                    %b: vector<[16]xi8>,
+                    %c: vector<[4]xi32>) -> vector<[4]xi32> {
+  // CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi3
+  %0 = arm_sve.usmmla %c, %a, %b :
+             vector<[16]xi8> to vector<[4]xi32>
+  return %0 : vector<[4]xi32>
+}
+
+// -----
+
 func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
                             %b: vector<[4]xi32>,
                             %c: vector<[4]xi32>,
diff --git a/test/Target/LLVMIR/arm-sve.mlir b/test/Target/LLVMIR/arm-sve.mlir
index 14c68b2..da71cb5 100644
--- a/test/Target/LLVMIR/arm-sve.mlir
+++ b/test/Target/LLVMIR/arm-sve.mlir
@@ -48,6 +48,18 @@
   llvm.return %0 : vector<[4]xi32>
 }
 
+// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
+llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
+                         %arg1: vector<[16]xi8>,
+                         %arg2: vector<[4]xi32>)
+                         -> vector<[4]xi32> {
+  // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
+  %0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) :
+    (vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
+        -> vector<[4]xi32>
+  llvm.return %0 : vector<[4]xi32>
+}
+
 // CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
 llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
                           %arg1: vector<[4]xi32>,