[mlir][ArmSME] Support decomposing constant splats into ArmSME tiles (#88762)

This adds a simple rewrite/legalization to decompose constant splats
larger than a single ArmSME tile into multiple SME virtual tile sized
splats. E.g. a constant splat to `vector<[8]x[8]xi32>` would decompose
into four `vector<[4]x[4]xi32>` splats.

GitOrigin-RevId: dadcaf82274805456b7d85131cf94f921b5398b7
diff --git a/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 31500c6..b595c6d 100644
--- a/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -165,6 +165,35 @@
   return (vectorRows * vectorCols) / (minNumElts * minNumElts);
 }
 
+/// Legalize `arith.constant dense<value>` splat operations to fit within SME
+/// tiles by decomposing them into tile-sized operations.
+struct LegalizeArithConstantOpsByDecomposition
+    : public OneToNOpConversionPattern<arith::ConstantOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    auto vectorType = dyn_cast<VectorType>(constantOp.getType());
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+    if (!vectorType || !denseAttr || !denseAttr.isSplat())
+      return failure();
+
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(constantOp,
+                                         kMatchFailureNotSMETileTypeMultiple);
+
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
+    auto tileSplat = rewriter.create<arith::ConstantOp>(
+        constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
+    rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
+                       adaptor.getResultMapping());
+
+    return success();
+  }
+};
+
 /// Legalize `vector.outerproduct` operations to fit within SME tiles by
 /// decomposing them into tile-sized operations.
 struct LegalizeVectorOuterProductOpsByDecomposition
@@ -637,7 +666,8 @@
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);
-    patterns.add<LegalizeVectorOuterProductOpsByDecomposition,
+    patterns.add<LegalizeArithConstantOpsByDecomposition,
+                 LegalizeVectorOuterProductOpsByDecomposition,
                  LegalizeTransferReadOpsByDecomposition,
                  LegalizeTransferWriteOpsByDecomposition>(converter, context);
     populateFuncTypeConversionPatterns(converter, patterns);
diff --git a/test/Dialect/ArmSME/vector-legalization.mlir b/test/Dialect/ArmSME/vector-legalization.mlir
index f8be697..f43ef1c 100644
--- a/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/test/Dialect/ArmSME/vector-legalization.mlir
@@ -433,3 +433,14 @@
   %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
   return %cast : vector<[4]xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @multi_tile_splat
+func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
+{
+  // CHECK: %[[SPLAT:.*]] = arith.constant dense<42> : vector<[4]x[4]xi32>
+  // CHECK-NEXT: return %[[SPLAT]], %[[SPLAT]], %[[SPLAT]], %[[SPLAT]] : vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>, vector<[4]x[4]xi32>
+  %0 = arith.constant dense<42> : vector<[8]x[8]xi32>
+  return %0 : vector<[8]x[8]xi32>
+}