Fix array attribute in bindings for linalg.init_tensor
Reviewed By: ftynse
Differential Revision: https://reviews.llvm.org/D101998
diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py
index 4714e69..0aea4e6 100644
--- a/mlir/python/mlir/dialects/_linalg_ops_ext.py
+++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py
@@ -74,9 +74,9 @@
result_type = RankedTensorType.get(sizes, element_type)
static_size_ints = sizes
- index_type = IndexType.get(context)
+ i64_type = IntegerType.get_signless(64)
attributes["static_sizes"] = ArrayAttr.get(
- [IntegerAttr.get(index_type, s) for s in static_size_ints],
+ [IntegerAttr.get(i64_type, s) for s in static_size_ints],
context=context)
op = self.build_generic(results=[result_type],
operands=operands,
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index f153ecb..de8e4b2 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -38,6 +38,17 @@
print(module)
+# CHECK-LABEL: TEST: testInitTensorStaticSizesAttribute
+@run
+def testInitTensorStaticSizesAttribute():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ op = linalg.InitTensorOp([3, 4], f32)
+ # CHECK: [3, 4]
+ print(op.attributes['static_sizes'])
+
# CHECK-LABEL: TEST: testFill
@run
def testFill():
@@ -153,7 +164,7 @@
# CHECK-NEXT: std.mulf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: std.addf{{.*}} (f32, f32) -> f32
# CHECK-NEXT: linalg.yield{{.*}} (f32) -> ()
- # CHECK-NEXT: {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
+ # CHECK-NEXT: {linalg.memoized_indexing_maps{{.*}}operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} :
# CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return linalg.matmul(lhs, rhs, outs=[init_result.result])