[mlir][SPIRV] Do not rewrite CompositeInsert for coopmatrix (#137837)
When rewriting multiple CompositeInserts to CompositeConstruct, we need
to know the number of elements of the result type. However, we cannot
query the number of elements for cooperative matrix types.
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
index f38282f..2e31172 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/RewriteInsertsPass.cpp
@@ -84,6 +84,9 @@
LogicalResult RewriteInsertsPass::collectInsertionChain(
spirv::CompositeInsertOp op,
SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
+ if (isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
+ return failure();
+
auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
// TODO: handle nested composite object.
if (indicesArrayAttr.size() == 1) {
diff --git a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
index 6d755be..a83c3f7d 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/rewrite-inserts.mlir
@@ -29,3 +29,15 @@
spirv.ReturnValue %3 : vector<3xf32>
}
}
+
+// -----
+
+spirv.module Logical GLSL450 {
+ spirv.func @insertCoopMatrix(%value : f32) -> !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> "None" {
+ %0 = spirv.Undef : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+ // CHECK: spirv.CompositeInsert {{%.*}}, {{%.*}} : f32 into !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+ %1 = spirv.CompositeInsert %value, %0[0 : i32] : f32 into !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+
+ spirv.ReturnValue %1 : !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+ }
+}