[mlir][spirv] Tighten types of SPIR-V TOSA op definitions (#192623)
Tighten the SPIR-V TOSA op definitions by introducing stricter named
type constraints and aligning verifier coverage with the new type
surface.
Remove implication checks that are now enforced directly by
operand/result type constraints.
Drop the corresponding negative tests that no longer exercise those
verifier paths.
Signed-off-by: Davide Grohmann <davide.grohmann@arm.com>
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
index a83ea4b..c873e30 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaOps.td
@@ -77,11 +77,11 @@
SPIRV_TosaElementwiseUnaryOp<mnemonic, opcode, !listconcat(traits, [Pure])> {
let arguments = (ins
- SPIRV_TosaFloat_TensorArm: $input1
+ SPIRV_F16OrF32OrBF16_TensorArm: $input1
);
let results = (outs
- SPIRV_TosaFloat_TensorArm: $output
+ SPIRV_F16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -115,12 +115,12 @@
SPIRV_TosaElementwiseBinaryOp<mnemonic, opcode, traits> {
let arguments = (ins
- SPIRV_TosaInteger_TensorArm: $input1,
- SPIRV_TosaInteger_TensorArm: $input2
+ SPIRV_I8OrI16OrI32_TensorArm: $input1,
+ SPIRV_I8OrI16OrI32_TensorArm: $input2
);
let results = (outs
- SPIRV_TosaInteger_TensorArm: $output
+ SPIRV_I8OrI16OrI32_TensorArm: $output
);
let assemblyFormat = [{
@@ -159,8 +159,6 @@
TypeConstraintImplicationOn<"input", BF16, "weight", [BF16]>,
TypeConstraintImplicationOn<"input", F16, "weight", [F16]>,
TypeConstraintImplicationOn<"input", F32, "weight", [F32]>,
- TypeConstraintImplicationOn<"input", AnyInteger, "input", [I8, I16]>,
- TypeConstraintImplicationOn<"weight", AnyInteger, "weight", [I8]>,
TypeImpliesAccType<"input", I8, ["INT32"]>,
TypeImpliesAccType<"input", I16, ["INT48"]>,
TypeImpliesAccType<"input", F16, ["FP16", "FP32"]>,
@@ -189,8 +187,8 @@
MatchBroadcastableShapes<"input1", "input2", "output">])> {
let arguments = (ins
- SPIRV_TosaNumerical_TensorArm: $input1,
- SPIRV_TosaNumerical_TensorArm: $input2
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $input2
);
let results = (outs
@@ -251,11 +249,11 @@
let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
- SPIRV_TosaNumerical_TensorArm: $input
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm: $input
);
let results = (outs
- SPIRV_Int32_TensorArmUpTo5D: $output
+ SPIRV_I32_TensorArmUpTo5D: $output
);
let assemblyFormat = [{
@@ -302,17 +300,17 @@
}];
let arguments = (ins
- SPIRV_Int32_1DTensorArmOfLength2Attr: $kernel,
- SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
- SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_I32_1DTensorArmOfLength2Attr: $kernel,
+ SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
SPIRV_TosaExtAccTypeAttr: $acc_type,
- SPIRV_TosaNumerical_TensorArm4D: $input,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $output_zp
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
+ SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
+ SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $output_zp
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm4D: $output
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $output
);
let assemblyFormat = [{
@@ -358,20 +356,20 @@
}];
let arguments = (ins
- SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
- SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
- SPIRV_Int32_1DTensorArmOfLength2Attr: $dilation,
+ SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
- SPIRV_TosaNumerical_TensorArm4D: $input,
- SPIRV_TosaNumerical_TensorArm4D: $weight,
- SPIRV_TosaNumerical_TensorArm1D: $bias,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
+ SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
+ SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
+ SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
+ SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm4D: $output
+ SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm4D: $output
);
let assemblyFormat = [{
@@ -413,20 +411,20 @@
}];
let arguments = (ins
- SPIRV_Int32_1DTensorArmOfLength6Attr: $pad,
- SPIRV_Int32_1DTensorArmOfLength3Attr: $stride,
- SPIRV_Int32_1DTensorArmOfLength3Attr: $dilation,
+ SPIRV_I32_1DTensorArmOfLength6Attr: $pad,
+ SPIRV_I32_1DTensorArmOfLength3Attr: $stride,
+ SPIRV_I32_1DTensorArmOfLength3Attr: $dilation,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
- SPIRV_TosaNumerical_TensorArm5D: $input,
- SPIRV_TosaNumerical_TensorArm5D: $weight,
- SPIRV_TosaNumerical_TensorArm1D: $bias,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm5D: $input,
+ SPIRV_I8OrF16OrF32OrBF16_TensorArm5D: $weight,
+ SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
+ SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
+ SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm5D: $output
+ SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm5D: $output
);
let assemblyFormat = [{
@@ -469,20 +467,20 @@
}];
let arguments = (ins
- SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
- SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
- SPIRV_Int32_1DTensorArmOfLength2Attr: $dilation,
+ SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_I32_1DTensorArmOfLength2Attr: $dilation,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
- SPIRV_TosaNumerical_TensorArm4D: $input,
- SPIRV_TosaNumerical_TensorArm4D: $weight,
- SPIRV_TosaNumerical_TensorArm1D: $bias,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
+ SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
+ SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
+ SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
+ SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm4D: $output
+ SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm4D: $output
);
let assemblyFormat = [{
@@ -526,12 +524,12 @@
let arguments = (ins
SPIRV_BoolConstAttr: $inverse,
SPIRV_BoolConstAttr: $local_bound,
- SPIRV_Float32_TensorArm3D: $input_real,
- SPIRV_Float32_TensorArm3D: $input_imag
+ SPIRV_F32_TensorArm3D: $input_real,
+ SPIRV_F32_TensorArm3D: $input_imag
);
let results = (outs
- SPIRV_Struct_2_Float32_TensorArm3D: $output
+ SPIRV_Struct_2_F32_TensorArm3D: $output
);
let assemblyFormat = [{
@@ -581,14 +579,14 @@
}];
let arguments = (ins
- SPIRV_TosaNumerical_TensorArm3D: $A,
- SPIRV_TosaNumerical_TensorArm3D: $B,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $A_zp,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $B_zp
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D: $A,
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D: $B,
+ SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $A_zp,
+ SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $B_zp
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm3D: $output
+ SPIRV_I32OrI64OrF16OrF32_TensorArm3D: $output
);
let assemblyFormat = [{
@@ -632,15 +630,15 @@
}];
let arguments = (ins
- SPIRV_Int32_1DTensorArmOfLength2Attr: $kernel,
- SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
- SPIRV_Int32_1DTensorArmOfLength4Attr: $pad,
+ SPIRV_I32_1DTensorArmOfLength2Attr: $kernel,
+ SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_I32_1DTensorArmOfLength4Attr: $pad,
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
- SPIRV_TosaNumerical_TensorArm4D: $input
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm4D: $output
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $output
);
let assemblyFormat = [{
@@ -687,11 +685,11 @@
let arguments = (ins
SPIRV_BoolConstAttr: $local_bound,
- SPIRV_Float32_TensorArm3D: $input_real
+ SPIRV_F32_TensorArm3D: $input_real
);
let results = (outs
- SPIRV_Struct_2_Float32_TensorArm3D: $output
+ SPIRV_Struct_2_F32_TensorArm3D: $output
);
let assemblyFormat = [{
@@ -732,19 +730,19 @@
}];
let arguments = (ins
- SPIRV_Int32_1DTensorArmOfLength4Attr: $out_pad,
- SPIRV_Int32_1DTensorArmOfLength2Attr: $stride,
+ SPIRV_I32_1DTensorArmOfLength4Attr: $out_pad,
+ SPIRV_I32_1DTensorArmOfLength2Attr: $stride,
SPIRV_TosaExtAccTypeAttr: $acc_type,
SPIRV_BoolConstAttr: $local_bound,
- SPIRV_TosaNumerical_TensorArm4D: $input,
- SPIRV_TosaNumerical_TensorArm4D: $weight,
- SPIRV_TosaNumerical_TensorArm1D: $bias,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $input_zp,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $weight_zp
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
+ SPIRV_I8OrF16OrF32OrBF16_TensorArm4D: $weight,
+ SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D: $bias,
+ SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1: $input_zp,
+ SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1: $weight_zp
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm4D: $output
+ SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm4D: $output
);
let assemblyFormat = [{
@@ -787,14 +785,14 @@
}];
let arguments = (ins
- SPIRV_TosaNumericalAttr: $min_val,
- SPIRV_TosaNumericalAttr: $max_val,
+ SPIRV_I8OrI16OrF16OrF32OrBF16ConstAttr: $min_val,
+ SPIRV_I8OrI16OrF16OrF32OrBF16ConstAttr: $max_val,
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
- SPIRV_TosaNumerical_TensorArm: $input
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm: $input
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm: $output
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -832,11 +830,11 @@
}];
let arguments = (ins
- SPIRV_TosaFloat_TensorArm: $input
+ SPIRV_F16OrF32OrBF16_TensorArm: $input
);
let results = (outs
- SPIRV_TosaFloat_TensorArm: $output
+ SPIRV_F16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -873,11 +871,11 @@
}];
let arguments = (ins
- SPIRV_TosaFloat_TensorArm: $input
+ SPIRV_F16OrF32OrBF16_TensorArm: $input
);
let results = (outs
- SPIRV_TosaFloat_TensorArm: $output
+ SPIRV_F16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -913,11 +911,11 @@
}];
let arguments = (ins
- SPIRV_TosaFloat_TensorArm: $input
+ SPIRV_F16OrF32OrBF16_TensorArm: $input
);
let results = (outs
- SPIRV_TosaFloat_TensorArm: $output
+ SPIRV_F16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -954,12 +952,12 @@
}];
let arguments = (ins
- SPIRV_TosaNumerical_TensorArm: $input1,
- SPIRV_TosaNumerical_TensorArm: $input2
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $input2
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm: $output
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -996,12 +994,12 @@
let arguments = (ins
SPIRV_BoolConstAttr: $round,
- SPIRV_TosaInteger_TensorArm: $input1,
- SPIRV_TosaInteger_TensorArm: $input2
+ SPIRV_I8OrI16OrI32_TensorArm: $input1,
+ SPIRV_I8OrI16OrI32_TensorArm: $input2
);
let results = (outs
- SPIRV_TosaInteger_TensorArm: $output
+ SPIRV_I8OrI16OrI32_TensorArm: $output
);
let assemblyFormat = [{
@@ -1099,12 +1097,12 @@
}];
let arguments = (ins
- SPIRV_Int32_TensorArm: $input1,
- SPIRV_Int32_TensorArm: $input2
+ SPIRV_I32_TensorArm: $input1,
+ SPIRV_I32_TensorArm: $input2
);
let results = (outs
- SPIRV_Int32_TensorArm: $output
+ SPIRV_I32_TensorArm: $output
);
let assemblyFormat = [{
@@ -1222,10 +1220,7 @@
}
-def SPIRV_TosaMaximumOp : SPIRV_TosaElementwiseBinaryOp<"Maximum", 25, [Pure,
- TypeConstraintImplicationOn<"input1", AnyInteger, "input1", [I32]>,
- TypeConstraintImplicationOn<"input2", AnyInteger, "input2", [I32]>,
- TypeConstraintImplicationOn<"output", AnyInteger, "output", [I32]>]> {
+def SPIRV_TosaMaximumOp : SPIRV_TosaElementwiseBinaryOp<"Maximum", 25, [Pure]> {
let summary = "Maximum.";
let description = [{
@@ -1247,12 +1242,12 @@
let arguments = (ins
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
- SPIRV_TosaNumerical_TensorArm: $input1,
- SPIRV_TosaNumerical_TensorArm: $input2
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $input2
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm: $output
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -1264,10 +1259,7 @@
}
-def SPIRV_TosaMinimumOp : SPIRV_TosaElementwiseBinaryOp<"Minimum", 26, [Pure,
- TypeConstraintImplicationOn<"input1", AnyInteger, "input1", [I32]>,
- TypeConstraintImplicationOn<"input2", AnyInteger, "input2", [I32]>,
- TypeConstraintImplicationOn<"output", AnyInteger, "output", [I32]>]> {
+def SPIRV_TosaMinimumOp : SPIRV_TosaElementwiseBinaryOp<"Minimum", 26, [Pure]> {
let summary = "Minimum.";
let description = [{
@@ -1289,12 +1281,12 @@
let arguments = (ins
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
- SPIRV_TosaNumerical_TensorArm: $input1,
- SPIRV_TosaNumerical_TensorArm: $input2
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $input2
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm: $output
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -1341,13 +1333,13 @@
}];
let arguments = (ins
- SPIRV_TosaNumerical_TensorArm: $input1,
- SPIRV_TosaNumerical_TensorArm: $input2,
- SPIRV_Int8_1DTensorArmOfLength1: $shift
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input2,
+ SPIRV_I8_1DTensorArmOfLength1: $shift
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm: $output
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -1380,12 +1372,12 @@
}];
let arguments = (ins
- SPIRV_TosaFloat_TensorArm: $input1,
- SPIRV_TosaFloat_TensorArm: $input2
+ SPIRV_F16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_F16OrF32OrBF16_TensorArm: $input2
);
let results = (outs
- SPIRV_TosaFloat_TensorArm: $output
+ SPIRV_F16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -1396,10 +1388,7 @@
}
-def SPIRV_TosaSubOp : SPIRV_TosaElementwiseBinaryOp<"Sub", 29, [NoMemoryEffect,
- TypeConstraintImplicationOn<"input1", AnyInteger, "input1", [I32]>,
- TypeConstraintImplicationOn<"input2", AnyInteger, "input2", [I32]>,
- TypeConstraintImplicationOn<"output", AnyInteger, "output", [I32]>]> {
+def SPIRV_TosaSubOp : SPIRV_TosaElementwiseBinaryOp<"Sub", 29, [NoMemoryEffect]> {
let summary = "Subtraction operator.";
let description = [{
@@ -1420,12 +1409,12 @@
}];
let arguments = (ins
- SPIRV_TosaNumerical_TensorArm: $input1,
- SPIRV_TosaNumerical_TensorArm: $input2
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $input2
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm: $output
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -1478,12 +1467,12 @@
}];
let arguments = (ins
- SPIRV_TosaInteger_TensorArm: $input1,
- SPIRV_TosaInteger_TensorArm1D: $table
+ SPIRV_I8OrI16_TensorArm: $input1,
+ SPIRV_I8OrI16_TensorArm1D: $table
);
let results = (outs
- SPIRV_TosaInteger_TensorArm: $output
+ SPIRV_I8OrI32_TensorArm: $output
);
let assemblyFormat = [{
@@ -1520,11 +1509,11 @@
}];
let arguments = (ins
- SPIRV_TosaNumerical_TensorArm: $input1
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $input1
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm: $output
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -1551,11 +1540,11 @@
}];
let arguments = (ins
- SPIRV_TosaInteger_TensorArm: $input1
+ SPIRV_I8OrI16OrI32_TensorArm: $input1
);
let results = (outs
- SPIRV_TosaInteger_TensorArm: $output
+ SPIRV_I8OrI16OrI32_TensorArm: $output
);
let assemblyFormat = [{
@@ -1600,11 +1589,11 @@
}];
let arguments = (ins
- SPIRV_Int32_TensorArm: $input1
+ SPIRV_I32_TensorArm: $input1
);
let results = (outs
- SPIRV_Int32_TensorArm: $output
+ SPIRV_I32_TensorArm: $output
);
let assemblyFormat = [{
@@ -1738,13 +1727,13 @@
}];
let arguments = (ins
- SPIRV_TosaNumerical_TensorArm: $input1,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $input1_zp,
- SPIRV_TosaNumerical_1DTensorArmOfLength1: $output_zp
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1: $input1_zp,
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1: $output_zp
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm: $output
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -1835,12 +1824,12 @@
let arguments = (ins
SPIRV_Bool_TensorArm: $condition,
- SPIRV_TosaAny_TensorArm: $true_value,
- SPIRV_TosaAny_TensorArm: $false_value
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $true_value,
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $false_value
);
let results = (outs
- SPIRV_TosaAny_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
);
let hasVerifier = 1;
@@ -2031,11 +2020,11 @@
let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
- SPIRV_TosaNumerical_TensorArm: $input
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm: $output
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -2068,11 +2057,11 @@
let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
SPIRV_TosaExtNaNPropagationModeAttr: $nan_mode,
- SPIRV_TosaNumerical_TensorArm: $input
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm: $output
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -2102,11 +2091,11 @@
let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
- SPIRV_TosaFloat_TensorArm: $input
+ SPIRV_F16OrF32OrBF16_TensorArm: $input
);
let results = (outs
- SPIRV_TosaFloat_TensorArm: $output
+ SPIRV_F16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -2139,11 +2128,11 @@
let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
- SPIRV_TosaNumerical_TensorArm: $input
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $input
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm: $output
+ SPIRV_I32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -2178,11 +2167,11 @@
let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
- Variadic<SPIRV_TosaAny_TensorArm>: $input1
+ Variadic<SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm>: $input1
);
let results = (outs
- SPIRV_TosaAny_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -2225,13 +2214,13 @@
}];
let arguments = (ins
- SPIRV_TosaAny_TensorArm: $input1,
- SPIRV_Int32_1DTensorArmOfEvenLength2To12: $padding,
- SPIRV_TosaAny_1DTensorArmOfLength1: $pad_const
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_I32_1DTensorArmOfEvenLength2To12: $padding,
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1: $pad_const
);
let results = (outs
- SPIRV_TosaAny_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -2278,12 +2267,12 @@
}];
let arguments = (ins
- SPIRV_TosaAny_TensorArm: $input1,
- SPIRV_Int32_1DTensorArmOfLength1To6: $shape
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_I32_1DTensorArmOfLength1To6: $shape
);
let results = (outs
- SPIRV_TosaAny_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -2326,11 +2315,11 @@
let arguments = (ins
SPIRV_TensorArmAxisAttr: $axis,
- SPIRV_TosaAny_TensorArm: $input1
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1
);
let results = (outs
- SPIRV_TosaAny_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -2373,13 +2362,13 @@
}];
let arguments = (ins
- SPIRV_TosaAny_TensorArm: $input1,
- SPIRV_Int32_1DTensorArmOfLength1To6: $start,
- SPIRV_Int32_1DTensorArmOfLength1To6: $size
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_I32_1DTensorArmOfLength1To6: $start,
+ SPIRV_I32_1DTensorArmOfLength1To6: $size
);
let results = (outs
- SPIRV_TosaAny_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -2427,12 +2416,12 @@
}];
let arguments = (ins
- SPIRV_TosaAny_TensorArm: $input1,
- SPIRV_Int32_1DTensorArmOfLength1To6: $multiples
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1,
+ SPIRV_I32_1DTensorArmOfLength1To6: $multiples
);
let results = (outs
- SPIRV_TosaAny_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -2476,12 +2465,12 @@
}];
let arguments = (ins
- SPIRV_Int32_1DTensorArmOfLength1To6Attr: $perms,
- SPIRV_TosaAny_TensorArm: $input1
+ SPIRV_I32_1DTensorArmOfLength1To6Attr: $perms,
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $input1
);
let results = (outs
- SPIRV_TosaAny_TensorArm: $output
+ SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -2523,12 +2512,12 @@
}];
let arguments = (ins
- SPIRV_TosaNumerical_TensorArm3D: $values,
- SPIRV_Int32_TensorArm2D: $indices
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values,
+ SPIRV_I32_TensorArm2D: $indices
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm3D: $output
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $output
);
let assemblyFormat = [{
@@ -2577,13 +2566,13 @@
}];
let arguments = (ins
- SPIRV_TosaNumerical_TensorArm3D: $values_in,
- SPIRV_Int32_TensorArm2D: $indices,
- SPIRV_TosaNumerical_TensorArm3D: $input
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values_in,
+ SPIRV_I32_TensorArm2D: $indices,
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $input
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm3D: $values_out
+ SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D: $values_out
);
let assemblyFormat = [{
@@ -2660,14 +2649,14 @@
let arguments = (ins
SPIRV_TosaExtResizeModeAttr: $mode,
- SPIRV_TosaNumerical_TensorArm4D: $input,
- SPIRV_Int32_1DTensorArmOfLength4: $scale,
- SPIRV_Int32_1DTensorArmOfLength2: $offset,
- SPIRV_Int32_1DTensorArmOfLength2: $border
+ SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D: $input,
+ SPIRV_I32_1DTensorArmOfLength4: $scale,
+ SPIRV_I32_1DTensorArmOfLength2: $offset,
+ SPIRV_I32_1DTensorArmOfLength2: $border
);
let results = (outs
- SPIRV_TosaNumerical_TensorArm4D: $output
+ SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16_TensorArm4D: $output
);
let assemblyFormat = [{
@@ -2761,11 +2750,11 @@
}];
let arguments = (ins
- SPIRV_TosaAny_TensorArm: $input
+ SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32_TensorArm: $input
);
let results = (outs
- SPIRV_TosaAny_TensorArm: $output
+ SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16_TensorArm: $output
);
let assemblyFormat = [{
@@ -2788,10 +2777,6 @@
ElementTypeMatchesScale32<"multiplier">,
TensorLengthMatchesPerChannel<"multiplier">,
TensorLengthMatchesPerChannel<"shift">,
- TypeConstraintImplicationOn<"input", I8, "output", [I8, I16, I32]>,
- TypeConstraintImplicationOn<"input", I16, "output", [I8, I16, I32]>,
- TypeConstraintImplicationOn<"input", I32, "output", [I8, I16, I32]>,
- TypeConstraintImplicationOn<"input", I64, "output", [I8, I16, I32]>,
BoolAttrTypeConstraintImplicationOn<"input_unsigned", "input", [I8, I16]>,
BoolAttrTypeConstraintImplicationOn<"output_unsigned", "input", [I8, I16]>,
BoolAttrTypeConstraintImplicationOn<"input_unsigned", "output", [I8, I16]>,
@@ -2843,15 +2828,15 @@
SPIRV_BoolConstAttr: $per_channel,
SPIRV_BoolConstAttr: $input_unsigned,
SPIRV_BoolConstAttr: $output_unsigned,
- SPIRV_TosaInteger_TensorArm: $input,
- SPIRV_Int16OrInt32_TensorArm1D: $multiplier,
- SPIRV_Int8_TensorArm1D: $shift,
- SPIRV_TosaInteger_1DTensorArmOfLength1: $input_zp,
- SPIRV_TosaInteger_1DTensorArmOfLength1: $output_zp
+ SPIRV_I8OrI16OrI32OrI64_TensorArm: $input,
+ SPIRV_I16OrI32_TensorArm1D: $multiplier,
+ SPIRV_I8_TensorArm1D: $shift,
+ SPIRV_I8OrI16OrI32OrI64_1DTensorArmOfLength1: $input_zp,
+ SPIRV_I8OrI16OrI32_1DTensorArmOfLength1: $output_zp
);
let results = (outs
- SPIRV_TosaInteger_TensorArm: $output
+ SPIRV_I8OrI16OrI32_TensorArm: $output
);
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
index 5a610aa..6c918ae 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTosaTypes.td
@@ -16,14 +16,27 @@
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
-def SPIRV_TosaInteger : AnyIntOfWidths<[8, 16, 32, 64]>;
-def SPIRV_TosaFloat : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR]>;
-def SPIRV_TosaNumerical : AnyTypeOf<[SPIRV_TosaInteger, SPIRV_TosaFloat]>;
-def SPIRV_TosaAny : AnyTypeOf<[SPIRV_TosaNumerical, SPIRV_Bool]>;
+def SPIRV_I8OrI16 : AnyIntOfWidths<[8, 16]>;
+def SPIRV_I8OrI16OrI32 : AnyIntOfWidths<[8, 16, 32]>;
+def SPIRV_I8OrI16OrI32OrI64 : AnyIntOfWidths<[8, 16, 32, 64]>;
+def SPIRV_I16OrI32 : AnyIntOfWidths<[16, 32]>;
+def SPIRV_I32OrI64 : AnyIntOfWidths<[32, 64]>;
+def SPIRV_F16OrF32OrBF16 : AnyTypeOf<[SPIRV_Float16, SPIRV_Float32, SPIRV_BFloat16KHR]>;
+def SPIRV_I8OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Int8, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Int32, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I32OrI64OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I32OrI64, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I32OrI64OrF16OrF32 : AnyTypeOf<[SPIRV_I32OrI64, SPIRV_Float16, SPIRV_Float32]>;
+def SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32OrI64, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_Bool, SPIRV_I8OrI16OrI32, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16 : AnyTypeOf<[SPIRV_I8OrI16OrI32, SPIRV_Bool, SPIRV_F16OrF32OrBF16]>;
+def SPIRV_I8OrI32 : AnyTypeOf<[SPIRV_Int8, SPIRV_Int32]>;
def SPIRV_TensorArmAxisAttr : ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<5>]>;
def SPIRV_BoolConstAttr : ConfinedAttr<BoolAttr, []>;
-def SPIRV_TosaNumericalAttr: AnyAttrOf<[I8Attr, I16Attr, I32Attr, I64Attr, F16Attr, F32Attr, BF16Attr]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16ConstAttr : AnyAttrOf<[I8Attr, I16Attr, F16Attr, F32Attr, BF16Attr]>;
// TensorARM Types
@@ -38,23 +51,37 @@
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensorArm">;
-def SPIRV_Int8_TensorArm1D : TensorArmRankOf<[SPIRV_Int8], [1]>;
-def SPIRV_Int16OrInt32_TensorArm1D : TensorArmRankOf<[SPIRV_Int16, SPIRV_Int32], [1]>;
-def SPIRV_Int32_TensorArm2D : TensorArmRankOf<[SPIRV_Int32], [2]>;
-def SPIRV_Float32_TensorArm3D: TensorArmRankOf<[SPIRV_Float32], [3]>;
-def SPIRV_TosaInteger_TensorArm1D : TensorArmRankOf<[SPIRV_TosaInteger], [1]>;
-def SPIRV_TosaNumerical_TensorArm1D : TensorArmRankOf<[SPIRV_TosaNumerical], [1]>;
-def SPIRV_TosaNumerical_TensorArm3D : TensorArmRankOf<[SPIRV_TosaNumerical], [3]>;
-def SPIRV_TosaNumerical_TensorArm4D : TensorArmRankOf<[SPIRV_TosaNumerical], [4]>;
-def SPIRV_TosaNumerical_TensorArm5D : TensorArmRankOf<[SPIRV_TosaNumerical], [5]>;
+def SPIRV_I8_TensorArm1D : TensorArmRankOf<[SPIRV_Int8], [1]>;
+def SPIRV_I16OrI32_TensorArm1D : TensorArmRankOf<[SPIRV_I16OrI32], [1]>;
+def SPIRV_I32_TensorArm2D : TensorArmRankOf<[SPIRV_Int32], [2]>;
+def SPIRV_F32_TensorArm3D: TensorArmRankOf<[SPIRV_Float32], [3]>;
+def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm1D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [1]>;
+def SPIRV_I8OrI16_TensorArm1D : TensorArmRankOf<[SPIRV_I8OrI16], [1]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [3]>;
+def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm3D : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrF16OrF32OrBF16], [3]>;
+def SPIRV_I32OrI64OrF16OrF32_TensorArm3D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32], [3]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [4]>;
+def SPIRV_I8OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16], [4]>;
+def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [4]>;
+def SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16_TensorArm4D : TensorArmRankOf<[SPIRV_I32OrI8OrI64OrI16OrF16OrF32OrBF16], [4]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm5D : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [5]>;
+def SPIRV_I8OrF16OrF32OrBF16_TensorArm5D : TensorArmRankOf<[SPIRV_I8OrF16OrF32OrBF16], [5]>;
+def SPIRV_I32OrI64OrF16OrF32OrBF16_TensorArm5D : TensorArmRankOf<[SPIRV_I32OrI64OrF16OrF32OrBF16], [5]>;
-def SPIRV_TosaAny_TensorArm : TensorArmRankOf<[SPIRV_TosaAny], [1, 2, 3, 4, 5, 6]>;
-def SPIRV_TosaNumerical_TensorArm : TensorArmRankOf<[SPIRV_TosaNumerical], [1, 2, 3, 4, 5, 6]>;
-def SPIRV_TosaInteger_TensorArm : TensorArmRankOf<[SPIRV_TosaInteger], [1, 2, 3, 4, 5, 6]>;
-def SPIRV_TosaFloat_TensorArm : TensorArmRankOf<[SPIRV_TosaFloat], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32_TensorArm : TensorArmRankOf<[SPIRV_BoolOrI8OrI16OrI32OrBF16OrF16OrF32], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_F16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_F16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_I32OrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_I32OrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrBoolOrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrF16OrF32OrBF16], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_I8OrI16OrI32OrI64_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32OrI64], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_I8OrI16OrI32_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16OrI32], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_I8OrI16_TensorArm : TensorArmRankOf<[SPIRV_I8OrI16], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_I8OrI32_TensorArm : TensorArmRankOf<[SPIRV_I8OrI32], [1, 2, 3, 4, 5, 6]>;
def SPIRV_Bool_TensorArm : TensorArmRankOf<[SPIRV_Bool], [1, 2, 3, 4, 5, 6]>;
-def SPIRV_Int32_TensorArm : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5, 6]>;
-def SPIRV_Int32_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5]>;
+def SPIRV_I32_TensorArm : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5, 6]>;
+def SPIRV_I32_TensorArmUpTo5D : TensorArmRankOf<[SPIRV_Int32], [1, 2, 3, 4, 5]>;
class Is1DTensorArmOfLength<list<int> allowedLengths> :
And<[HasAnyRankOfPred<[1]>,
@@ -68,21 +95,21 @@
"rank 1 tensorArm of length " # !interleave(allowedLengths, "/"),
"::mlir::spirv::TensorArmType">;
-def SPIRV_Int32_1DTensorArmOfLength2 : SPIRV_1DTensorArmOfLengthAndType<[2], [SPIRV_Int32]>;
-def SPIRV_Int32_1DTensorArmOfLength4 : SPIRV_1DTensorArmOfLengthAndType<[4], [SPIRV_Int32]>;
+def SPIRV_I32_1DTensorArmOfLength2 : SPIRV_1DTensorArmOfLengthAndType<[2], [SPIRV_Int32]>;
+def SPIRV_I32_1DTensorArmOfLength4 : SPIRV_1DTensorArmOfLengthAndType<[4], [SPIRV_Int32]>;
-def SPIRV_Int32_1DTensorArmOfLength1To6 : SPIRV_1DTensorArmOfLengthAndType<[1, 2, 3, 4, 5, 6], [SPIRV_Int32]>;
-def SPIRV_Int32_1DTensorArmOfEvenLength2To12 : SPIRV_1DTensorArmOfLengthAndType<[2, 4, 6, 8, 10, 12], [SPIRV_Int32]>;
+def SPIRV_I32_1DTensorArmOfLength1To6 : SPIRV_1DTensorArmOfLengthAndType<[1, 2, 3, 4, 5, 6], [SPIRV_Int32]>;
+def SPIRV_I32_1DTensorArmOfEvenLength2To12 : SPIRV_1DTensorArmOfLengthAndType<[2, 4, 6, 8, 10, 12], [SPIRV_Int32]>;
def SPIRV_DenseElementAttrsWithTensorArmType : AttrConstraint<
CPred<"::llvm::isa<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType())">,
"Attr with type = spirv::TensorArmType">;
-def SPIRV_Int32_1DTensorArmOfLength2Attr : ConfinedAttr<RankedI32ElementsAttr<[2]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
-def SPIRV_Int32_1DTensorArmOfLength3Attr : ConfinedAttr<RankedI32ElementsAttr<[3]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
-def SPIRV_Int32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
-def SPIRV_Int32_1DTensorArmOfLength5Attr : ConfinedAttr<RankedI32ElementsAttr<[5]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
-def SPIRV_Int32_1DTensorArmOfLength6Attr : ConfinedAttr<RankedI32ElementsAttr<[6]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+def SPIRV_I32_1DTensorArmOfLength2Attr : ConfinedAttr<RankedI32ElementsAttr<[2]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+def SPIRV_I32_1DTensorArmOfLength3Attr : ConfinedAttr<RankedI32ElementsAttr<[3]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+def SPIRV_I32_1DTensorArmOfLength4Attr : ConfinedAttr<RankedI32ElementsAttr<[4]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+def SPIRV_I32_1DTensorArmOfLength5Attr : ConfinedAttr<RankedI32ElementsAttr<[5]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
+def SPIRV_I32_1DTensorArmOfLength6Attr : ConfinedAttr<RankedI32ElementsAttr<[6]>, [SPIRV_DenseElementAttrsWithTensorArmType]>;
class Is1DTensorArmAttrOfLength<list<int> allowedLengths> :
AttrConstraint<And<[CPred<[{::llvm::cast<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType()).getShape().size() == 1 }]>,
@@ -90,13 +117,16 @@
CPred<[{::llvm::cast<::mlir::spirv::TensorArmType>(::llvm::cast<::mlir::DenseElementsAttr>($_self).getType()).getShape()[0] == }]
# allowedlength>)>]>>;
-def SPIRV_Int32_1DTensorArmOfLength1To6Attr : ConfinedAttr<
+def SPIRV_I32_1DTensorArmOfLength1To6Attr : ConfinedAttr<
I32ElementsAttr, [SPIRV_DenseElementAttrsWithTensorArmType, Is1DTensorArmAttrOfLength<[1, 2, 3, 4, 5, 6]>]>;
-def SPIRV_Int8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_Int8]>;
-def SPIRV_TosaInteger_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaInteger]>;
-def SPIRV_TosaNumerical_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaNumerical]>;
-def SPIRV_TosaAny_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_TosaAny]>;
+def SPIRV_I8_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_Int8]>;
+def SPIRV_I8OrI16OrF16OrF32OrBF16_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrF16OrF32OrBF16]>;
+def SPIRV_I8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrI32OrF16OrF32OrBF16]>;
+def SPIRV_I8OrI16OrI32OrI64_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrI32OrI64]>;
+def SPIRV_I8OrI16OrI32_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrI16OrI32]>;
+def SPIRV_I8OrF16OrF32OrBF16_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_I8OrF16OrF32OrBF16]>;
+def SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16_1DTensorArmOfLength1 : SPIRV_1DTensorArmOfLengthAndType<[1], [SPIRV_BoolOrI8OrI16OrI32OrF16OrF32OrBF16]>;
// Struct type
@@ -111,7 +141,7 @@
"::llvm::cast<::mlir::spirv::StructType>($_self).getElementTypes()",
"Struct">;
-def SPIRV_Struct_2_Float32_TensorArm3D : IsStructOfNumElementsAndType<2, [SPIRV_Float32_TensorArm3D]>;
+def SPIRV_Struct_2_F32_TensorArm3D : IsStructOfNumElementsAndType<2, [SPIRV_F32_TensorArm3D]>;
// Op Trait constraints:
diff --git a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
index f5a8a3c..c238de3 100644
--- a/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/tosa-ops-verification.mlir
@@ -79,27 +79,19 @@
// spirv.TOSA.Conv2D
//===----------------------------------------------------------------------===//
-spirv.ARM.Graph @conv2d_wrong_input_integer_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi32>, %arg1: !spirv.arm.tensor<7x1x1x1xi32>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) {
- %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi32>
- %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi32>
- // expected-error @+1 {{op failed to verify that if input has type integer then input must have a type in [8-bit signless integer,16-bit signless integer]}}
- %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi32>, !spirv.arm.tensor<7x1x1x1xi32>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<1x65536x2x7xi64>
- spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64>
-}
-
-spirv.ARM.Graph @conv2d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi16>) {
+spirv.ARM.Graph @conv2d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) {
%5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
%6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
// expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [32-bit signless integer]}}
- %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi16>
- spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi16>
+ %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi64>
+ spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64>
}
-spirv.ARM.Graph @conv2d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+spirv.ARM.Graph @conv2d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
%5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16>
- %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16>
+ %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
// expected-error @+1 {{op failed to verify that if input has type 16-bit signless integer then output must have a type in [64-bit signless integer]}}
- %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi32>
+ %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
}
@@ -119,11 +111,11 @@
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16>
}
-spirv.ARM.Graph @conv2d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+spirv.ARM.Graph @conv2d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
%5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
%6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
// expected-error @+1 {{op failed to verify that all of {bias, output} have same element type}}
- %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+ %7 = spirv.Tosa.Conv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
}
@@ -163,20 +155,12 @@
// spirv.TOSA.Conv3D
//===----------------------------------------------------------------------===//
-spirv.ARM.Graph @conv3d_wrong_input_integer_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi32>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi32>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7x1xi64>) {
- %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi32>
- %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi32>
- // expected-error @+1 {{ op failed to verify that if input has type integer then input must have a type in [8-bit signless integer,16-bit signless integer]}}
- %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi32>, !spirv.arm.tensor<7x1x1x1x1xi32>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<1x65536x2x7x1xi64>
- spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi64>
-}
-
-spirv.ARM.Graph @conv3d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7x1xi16>) {
+spirv.ARM.Graph @conv3d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7x1xi64>) {
%5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
%6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
// expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [32-bit signless integer]}}
- %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi8>, !spirv.arm.tensor<7x1x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7x1xi16>
- spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi16>
+ %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi8>, !spirv.arm.tensor<7x1x1x1x1xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7x1xi64>
+ spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi64>
}
spirv.ARM.Graph @conv3d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7x1xi32>) {
@@ -203,11 +187,11 @@
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11x1xf16>
}
-spirv.ARM.Graph @conv3d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7x1xi32>) {
+spirv.ARM.Graph @conv3d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7x1xi32>) {
%5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
%6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
// expected-error @+1 {{op failed to verify that all of {bias, output} have same element type}}
- %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi8>, !spirv.arm.tensor<7x1x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7x1xi32>
+ %7 = spirv.Tosa.Conv3D pad = [1, 0, 0, 0, 0, 0], stride = [1, 2, 3], dilation = [7, 1, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1x1xi8>, !spirv.arm.tensor<7x1x1x1x1xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7x1xi32>
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7x1xi32>
}
@@ -247,20 +231,12 @@
// spirv.TOSA.DepthwiseConv2D
//===----------------------------------------------------------------------===//
-spirv.ARM.Graph @depthwise_conv2d_wrong_input_integer_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi32>, %arg1: !spirv.arm.tensor<7x1x1x1xi32>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) {
- %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi32>
- %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi32>
- // expected-error @+1 {{op failed to verify that if input has type integer then input must have a type in [8-bit signless integer,16-bit signless integer]}}
- %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi32>, !spirv.arm.tensor<7x1x1x1xi32>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<1x65536x2x7xi64>
- spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64>
-}
-
-spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi16>) {
+spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) {
%5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
%6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
// expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [32-bit signless integer]}}
- %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi16>
- spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi16>
+ %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi64>
+ spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64>
}
spirv.ARM.Graph @depthwise_conv2d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
@@ -287,11 +263,11 @@
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16>
}
-spirv.ARM.Graph @depthwise_conv2d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+spirv.ARM.Graph @depthwise_conv2d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
%5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
%6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
// expected-error @+1 {{op failed to verify that all of {bias, output} have same element type}}
- %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+ %7 = spirv.Tosa.DepthwiseConv2D pad = [1, 0, 0, 0], stride = [1, 2], dilation = [7, 1], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
}
@@ -331,10 +307,10 @@
// spirv.TOSA.MatMul
//===----------------------------------------------------------------------===//
-spirv.ARM.Graph @matmul_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x4x4xi8>, %arg1: !spirv.arm.tensor<1x4x4xi8>, %arg2: !spirv.arm.tensor<1xi8>, %arg3: !spirv.arm.tensor<1xi8>) -> (!spirv.arm.tensor<1x4x4xi16>) {
+spirv.ARM.Graph @matmul_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x4x4xi8>, %arg1: !spirv.arm.tensor<1x4x4xi8>, %arg2: !spirv.arm.tensor<1xi8>, %arg3: !spirv.arm.tensor<1xi8>) -> (!spirv.arm.tensor<1x4x4xi64>) {
// expected-error @+1 {{op failed to verify that if A has type 8-bit signless integer then output must have a type in [32-bit signless integer]}}
- %0 = spirv.Tosa.MatMul %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<1x4x4xi8>, !spirv.arm.tensor<1x4x4xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x4xi16>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x4x4xi16>
+ %0 = spirv.Tosa.MatMul %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<1x4x4xi8>, !spirv.arm.tensor<1x4x4xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x4x4xi64>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x4x4xi64>
}
spirv.ARM.Graph @matmul_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x4x4xi16>, %arg1: !spirv.arm.tensor<1x4x4xi16>, %arg2: !spirv.arm.tensor<1xi16>, %arg3: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x4x4xi32>) {
@@ -343,10 +319,10 @@
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x4x4xi32>
}
-spirv.ARM.Graph @matmul_mismatch_result_element_type_bf16_input(%arg0: !spirv.arm.tensor<1x4x4xbf16>, %arg1: !spirv.arm.tensor<1x4x4xbf16>, %arg2: !spirv.arm.tensor<1xbf16>, %arg3: !spirv.arm.tensor<1xbf16>) -> (!spirv.arm.tensor<1x4x4xbf16>) {
+spirv.ARM.Graph @matmul_mismatch_result_element_type_bf16_input(%arg0: !spirv.arm.tensor<1x4x4xbf16>, %arg1: !spirv.arm.tensor<1x4x4xbf16>, %arg2: !spirv.arm.tensor<1xbf16>, %arg3: !spirv.arm.tensor<1xbf16>) -> (!spirv.arm.tensor<1x4x4xf16>) {
// expected-error @+1 {{op failed to verify that if A has type bfloat16 type then output must have a type in [32-bit float]}}
- %0 = spirv.Tosa.MatMul %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<1x4x4xbf16>, !spirv.arm.tensor<1x4x4xbf16>, !spirv.arm.tensor<1xbf16>, !spirv.arm.tensor<1xbf16> -> !spirv.arm.tensor<1x4x4xbf16>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x4x4xbf16>
+ %0 = spirv.Tosa.MatMul %arg0, %arg1, %arg2, %arg3 : !spirv.arm.tensor<1x4x4xbf16>, !spirv.arm.tensor<1x4x4xbf16>, !spirv.arm.tensor<1xbf16>, !spirv.arm.tensor<1xbf16> -> !spirv.arm.tensor<1x4x4xf16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x4x4xf16>
}
spirv.ARM.Graph @matmul_mismatch_result_element_type_f16_input(%arg0: !spirv.arm.tensor<1x4x4xf16>, %arg1: !spirv.arm.tensor<1x4x4xf16>, %arg2: !spirv.arm.tensor<1xf16>, %arg3: !spirv.arm.tensor<1xf16>) -> (!spirv.arm.tensor<1x4x4xi32>) {
@@ -393,27 +369,19 @@
// spirv.TOSA.TransposeConv2D
//===----------------------------------------------------------------------===//
-spirv.ARM.Graph @transpose_conv2d_wrong_input_integer_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi32>, %arg1: !spirv.arm.tensor<7x1x1x1xi32>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) {
- %5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi32>
- %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi32>
- // expected-error @+1 {{op failed to verify that if input has type integer then input must have a type in [8-bit signless integer,16-bit signless integer]}}
- %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi32>, !spirv.arm.tensor<7x1x1x1xi32>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi32> -> !spirv.arm.tensor<1x65536x2x7xi64>
- spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64>
-}
-
-spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi16>) {
+spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_i8_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi64>) {
%5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
%6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
// expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [32-bit signless integer]}}
- %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi16>
- spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi16>
+ %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi64>
+ spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi64>
}
-spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi16>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+spirv.ARM.Graph @transpose_conv2d_mismatch_result_element_type_i16_input(%arg0: !spirv.arm.tensor<1x65535x3x1xi16>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi32>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
%5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi16>
- %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi16>
+ %6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
// expected-error @+1 {{op failed to verify that if input has type 16-bit signless integer then output must have a type in [64-bit signless integer]}}
- %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi16>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi16> -> !spirv.arm.tensor<1x65536x2x7xi32>
+ %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi16>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
}
@@ -433,11 +401,11 @@
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x34x18x11xf16>
}
-spirv.ARM.Graph @transpose_conv2d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi16>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
+spirv.ARM.Graph @transpose_conv2d_bias_element_type_must_be_same_as_result_element_type(%arg0: !spirv.arm.tensor<1x65535x3x1xi8>, %arg1: !spirv.arm.tensor<7x1x1x1xi8>, %arg2: !spirv.arm.tensor<1xi64>) -> (!spirv.arm.tensor<1x65536x2x7xi32>) {
%5 = spirv.Constant dense<35> : !spirv.arm.tensor<1xi8>
%6 = spirv.Constant dense<57> : !spirv.arm.tensor<1xi8>
// expected-error @+1 {{op failed to verify that all of {bias, output} have same element type}}
- %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
+ %7 = spirv.Tosa.TransposeConv2D out_pad = [1, 0, 0, 0], stride = [1, 2], acc_type = <INT32>, local_bound = false, %arg0, %arg1, %arg2, %5, %6 : !spirv.arm.tensor<1x65535x3x1xi8>, !spirv.arm.tensor<7x1x1x1xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<1x65536x2x7xi32>
spirv.ARM.GraphOutputs %7 : !spirv.arm.tensor<1x65536x2x7xi32>
}
@@ -547,16 +515,16 @@
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi32>
}
-spirv.ARM.Graph @add_input_element_types_not_matching(%arg0: !spirv.arm.tensor<6x10x6x6xi16>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xi16>) {
+spirv.ARM.Graph @add_input_element_types_not_matching(%arg0: !spirv.arm.tensor<6x10x6x6xf16>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xf16>) {
// expected-error @+1 {{op failed to verify that all of {input1, input2} have same element type}}
- %0 = spirv.Tosa.Add %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi16>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xi16>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi16>
+ %0 = spirv.Tosa.Add %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xf16>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xf16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xf16>
}
-spirv.ARM.Graph @add_input_output_element_types_not_matching(%arg0: !spirv.arm.tensor<6x10x6x6xi32>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xi16>) {
+spirv.ARM.Graph @add_input_output_element_types_not_matching(%arg0: !spirv.arm.tensor<6x10x6x6xi32>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xf16>) {
// expected-error @+1 {{op failed to verify that all of {input1, output} have same element type}}
- %0 = spirv.Tosa.Add %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi32>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xi16>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi16>
+ %0 = spirv.Tosa.Add %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi32>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xf16>
+ spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xf16>
}
spirv.ARM.Graph @add_inputs_not_broadcastable(%arg0: !spirv.arm.tensor<6x10x6x6xi32>, %arg1: !spirv.arm.tensor<2x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xi32>) {
@@ -945,24 +913,6 @@
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x10x6x6xi32>
}
-spirv.ARM.Graph @maximum_integer_input1_must_be_i32(%arg0: !spirv.arm.tensor<6x10x6x6xi8>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xi32>) {
- // expected-error @+1 {{op failed to verify that if input1 has type integer then input1 must have a type in [32-bit signless integer]}}
- %0 = spirv.Tosa.Maximum nan_mode = <Propagate>, %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi8>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xi32>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi32>
-}
-
-spirv.ARM.Graph @maximum_integer_input2_must_be_i32(%arg0: !spirv.arm.tensor<6x10x6x6xi32>, %arg1: !spirv.arm.tensor<1x10x6x6xi8>) -> (!spirv.arm.tensor<6x10x6x6xi32>) {
- // expected-error @+1 {{op failed to verify that if input2 has type integer then input2 must have a type in [32-bit signless integer]}}
- %0 = spirv.Tosa.Maximum nan_mode = <Propagate>, %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi32>, !spirv.arm.tensor<1x10x6x6xi8> -> !spirv.arm.tensor<6x10x6x6xi32>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi32>
-}
-
-spirv.ARM.Graph @maximum_integer_output_must_be_i32(%arg0: !spirv.arm.tensor<6x10x6x6xi32>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xi8>) {
- // expected-error @+1 {{op failed to verify that if output has type integer then output must have a type in [32-bit signless integer]}}
- %0 = spirv.Tosa.Maximum nan_mode = <Propagate>, %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi32>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xi8>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi8>
-}
-
//===----------------------------------------------------------------------===//
// spirv.TOSA.Minimum
//===----------------------------------------------------------------------===//
@@ -997,24 +947,6 @@
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x10x6x6xi32>
}
-spirv.ARM.Graph @minimum_integer_input1_must_be_i32(%arg0: !spirv.arm.tensor<6x10x6x6xi8>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xi32>) {
- // expected-error @+1 {{op failed to verify that if input1 has type integer then input1 must have a type in [32-bit signless integer]}}
- %0 = spirv.Tosa.Minimum nan_mode = <Propagate>, %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi8>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xi32>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi32>
-}
-
-spirv.ARM.Graph @minimum_integer_input2_must_be_i32(%arg0: !spirv.arm.tensor<6x10x6x6xi32>, %arg1: !spirv.arm.tensor<1x10x6x6xi8>) -> (!spirv.arm.tensor<6x10x6x6xi32>) {
- // expected-error @+1 {{op failed to verify that if input2 has type integer then input2 must have a type in [32-bit signless integer]}}
- %0 = spirv.Tosa.Minimum nan_mode = <Propagate>, %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi32>, !spirv.arm.tensor<1x10x6x6xi8> -> !spirv.arm.tensor<6x10x6x6xi32>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi32>
-}
-
-spirv.ARM.Graph @minimum_integer_output_must_be_i32(%arg0: !spirv.arm.tensor<6x10x6x6xi32>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xi8>) {
- // expected-error @+1 {{op failed to verify that if output has type integer then output must have a type in [32-bit signless integer]}}
- %0 = spirv.Tosa.Minimum nan_mode = <Propagate>, %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi32>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xi8>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi8>
-}
-
//===----------------------------------------------------------------------===//
// spirv.TOSA.Mul
//===----------------------------------------------------------------------===//
@@ -1047,11 +979,11 @@
spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<34x21x1xi32>
}
-spirv.ARM.Graph @mul_ouput_must_have_i32_as_element_type(%arg0: !spirv.arm.tensor<34x21x39xi8>, %arg1: !spirv.arm.tensor<34x21x1xi8>) -> (!spirv.arm.tensor<34x21x39xi16>) {
+spirv.ARM.Graph @mul_ouput_must_have_i32_as_element_type(%arg0: !spirv.arm.tensor<34x21x39xi8>, %arg1: !spirv.arm.tensor<34x21x1xi8>) -> (!spirv.arm.tensor<34x21x39xf16>) {
%0 = spirv.Constant dense<31> : !spirv.arm.tensor<1xi8>
// expected-error @+1 {{op failed to verify that if input1 has type integer then output must have a type in [32-bit signless integer]}}
- %1 = spirv.Tosa.Mul %arg0, %arg1, %0 : !spirv.arm.tensor<34x21x39xi8>, !spirv.arm.tensor<34x21x1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<34x21x39xi16>
- spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<34x21x39xi16>
+ %1 = spirv.Tosa.Mul %arg0, %arg1, %0 : !spirv.arm.tensor<34x21x39xi8>, !spirv.arm.tensor<34x21x1xi8>, !spirv.arm.tensor<1xi8> -> !spirv.arm.tensor<34x21x39xf16>
+ spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<34x21x39xf16>
}
spirv.ARM.Graph @mul_input_with_element_type_f16_must_produce_an_output_with_element_type_f16(%arg0: !spirv.arm.tensor<57x1x55xf16>, %arg1: !spirv.arm.tensor<57x37x55xf16>) -> (!spirv.arm.tensor<57x37x55xf32>) {
@@ -1143,24 +1075,6 @@
spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<1x10x6x6xi32>
}
-spirv.ARM.Graph @sub_integer_input1_must_be_i32(%arg0: !spirv.arm.tensor<6x10x6x6xi8>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xi32>) {
- // expected-error @+1 {{op failed to verify that if input1 has type integer then input1 must have a type in [32-bit signless integer]}}
- %0 = spirv.Tosa.Sub %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi8>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xi32>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi32>
-}
-
-spirv.ARM.Graph @sub_integer_input2_must_be_i32(%arg0: !spirv.arm.tensor<6x10x6x6xi32>, %arg1: !spirv.arm.tensor<1x10x6x6xi8>) -> (!spirv.arm.tensor<6x10x6x6xi32>) {
- // expected-error @+1 {{op failed to verify that if input2 has type integer then input2 must have a type in [32-bit signless integer]}}
- %0 = spirv.Tosa.Sub %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi32>, !spirv.arm.tensor<1x10x6x6xi8> -> !spirv.arm.tensor<6x10x6x6xi32>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi32>
-}
-
-spirv.ARM.Graph @sub_integer_output_must_be_i32(%arg0: !spirv.arm.tensor<6x10x6x6xi32>, %arg1: !spirv.arm.tensor<1x10x6x6xi32>) -> (!spirv.arm.tensor<6x10x6x6xi8>) {
- // expected-error @+1 {{op failed to verify that if output has type integer then output must have a type in [32-bit signless integer]}}
- %0 = spirv.Tosa.Sub %arg0, %arg1 : !spirv.arm.tensor<6x10x6x6xi32>, !spirv.arm.tensor<1x10x6x6xi32> -> !spirv.arm.tensor<6x10x6x6xi8>
- spirv.ARM.GraphOutputs %0 : !spirv.arm.tensor<6x10x6x6xi8>
-}
-
//===----------------------------------------------------------------------===//
// spirv.TOSA.Table
//===----------------------------------------------------------------------===//
@@ -1193,18 +1107,18 @@
spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<3x2x15x7xi32>
}
-spirv.ARM.Graph @table_input_with_element_type_i8_requires_an_output_with_element_type_i8(%arg0: !spirv.arm.tensor<3x2x15x7xi8>) -> (!spirv.arm.tensor<3x2x15x7xi16>) {
+spirv.ARM.Graph @table_input_with_element_type_i8_requires_an_output_with_element_type_i8(%arg0: !spirv.arm.tensor<3x2x15x7xi8>) -> (!spirv.arm.tensor<3x2x15x7xi32>) {
%0 = spirv.ARM.GraphConstant {graph_constant_id = 0 : i32} : !spirv.arm.tensor<256xi8>
// expected-error @+1 {{op failed to verify that if input1 has type 8-bit signless integer then output must have a type in [8-bit signless integer]}}
- %1 = spirv.Tosa.Table %arg0, %0 : !spirv.arm.tensor<3x2x15x7xi8>, !spirv.arm.tensor<256xi8> -> !spirv.arm.tensor<3x2x15x7xi16>
- spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<3x2x15x7xi16>
+ %1 = spirv.Tosa.Table %arg0, %0 : !spirv.arm.tensor<3x2x15x7xi8>, !spirv.arm.tensor<256xi8> -> !spirv.arm.tensor<3x2x15x7xi32>
+ spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<3x2x15x7xi32>
}
-spirv.ARM.Graph @table_input_with_element_type_i16_requires_an_output_with_element_type_i32(%arg0: !spirv.arm.tensor<3x2x15x7xi16>) -> (!spirv.arm.tensor<3x2x15x7xi16>) {
+spirv.ARM.Graph @table_input_with_element_type_i16_requires_an_output_with_element_type_i32(%arg0: !spirv.arm.tensor<3x2x15x7xi16>) -> (!spirv.arm.tensor<3x2x15x7xi8>) {
%0 = spirv.ARM.GraphConstant {graph_constant_id = 0 : i32} : !spirv.arm.tensor<513xi16>
// expected-error @+1 {{op failed to verify that if input1 has type 16-bit signless integer then output must have a type in [32-bit signless integer]}}
- %1 = spirv.Tosa.Table %arg0, %0 : !spirv.arm.tensor<3x2x15x7xi16>, !spirv.arm.tensor<513xi16> -> !spirv.arm.tensor<3x2x15x7xi16>
- spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<3x2x15x7xi16>
+ %1 = spirv.Tosa.Table %arg0, %0 : !spirv.arm.tensor<3x2x15x7xi16>, !spirv.arm.tensor<513xi16> -> !spirv.arm.tensor<3x2x15x7xi8>
+ spirv.ARM.GraphOutputs %1 : !spirv.arm.tensor<3x2x15x7xi8>
}
//===----------------------------------------------------------------------===//
@@ -2111,46 +2025,6 @@
spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi16>
}
-spirv.ARM.Graph @rescale_i8_input_requires_i8_i16_or_i32_output(%arg0: !spirv.arm.tensor<2x3x4xi8>) -> (!spirv.arm.tensor<2x3x4xi64>) {
- %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
- %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
- %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
- %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
- // expected-error @+1 {{op failed to verify that if input has type 8-bit signless integer then output must have a type in [8-bit signless integer,16-bit signless integer,32-bit signless integer]}}
- %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi64> -> !spirv.arm.tensor<2x3x4xi64>
- spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi64>
-}
-
-spirv.ARM.Graph @rescale_i16_input_requires_i8_i16_or_i32_output(%arg0: !spirv.arm.tensor<2x3x4xi16>) -> (!spirv.arm.tensor<2x3x4xi64>) {
- %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
- %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
- %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi16>
- %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
- // expected-error @+1 {{op failed to verify that if input has type 16-bit signless integer then output must have a type in [8-bit signless integer,16-bit signless integer,32-bit signless integer]}}
- %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi16>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi64> -> !spirv.arm.tensor<2x3x4xi64>
- spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi64>
-}
-
-spirv.ARM.Graph @rescale_i32_input_requires_i8_i16_or_i32_output(%arg0: !spirv.arm.tensor<2x3x4xi32>) -> (!spirv.arm.tensor<2x3x4xi64>) {
- %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
- %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
- %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi32>
- %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
- // expected-error @+1 {{op failed to verify that if input has type 32-bit signless integer then output must have a type in [8-bit signless integer,16-bit signless integer,32-bit signless integer]}}
- %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi32>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi32>, !spirv.arm.tensor<1xi64> -> !spirv.arm.tensor<2x3x4xi64>
- spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi64>
-}
-
-spirv.ARM.Graph @rescale_i64_input_requires_i8_i16_or_i32_output(%arg0: !spirv.arm.tensor<2x3x4xi64>) -> (!spirv.arm.tensor<2x3x4xi64>) {
- %1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
- %2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
- %3 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
- %4 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi64>
- // expected-error @+1 {{op failed to verify that if input has type 64-bit signless integer then output must have a type in [8-bit signless integer,16-bit signless integer,32-bit signless integer]}}
- %5 = spirv.Tosa.Rescale scale32 = false, rounding_mode = <SingleRound>, per_channel = false, input_unsigned = false, output_unsigned = false, %arg0, %1, %2, %3, %4 : !spirv.arm.tensor<2x3x4xi64>, !spirv.arm.tensor<1xi16>, !spirv.arm.tensor<1xi8>, !spirv.arm.tensor<1xi64>, !spirv.arm.tensor<1xi64> -> !spirv.arm.tensor<2x3x4xi64>
- spirv.ARM.GraphOutputs %5 : !spirv.arm.tensor<2x3x4xi64>
-}
-
spirv.ARM.Graph @rescale_input_unsigned_true_requires_i8_or_i16_input(%arg0: !spirv.arm.tensor<2x3x4xi32>) -> (!spirv.arm.tensor<2x3x4xi16>) {
%1 = spirv.Constant dense<1> : !spirv.arm.tensor<1xi16>
%2 = spirv.Constant dense<0> : !spirv.arm.tensor<1xi8>
diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
index 8caeb07..4de8c60 100644
--- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
+++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp
@@ -562,7 +562,7 @@
os << tabs << formatv(" {0}.push_back(attrTypeID);\n", operandList);
} else if (llvm::is_contained({"SPIRV_BoolConstAttr",
"SPIRV_TensorArmAxisAttr",
- "SPIRV_TosaNumericalAttr"},
+ "SPIRV_I8OrI16OrF16OrF32OrBF16ConstAttr"},
attr.getAttrDefName())) {
os << tabs
<< formatv(
@@ -867,9 +867,9 @@
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"TypeAttr::get(getType({2}[{3}++]))));\n",
attrList, attrName, words, wordIndex);
- } else if (llvm::is_contained(
- {"SPIRV_BoolConstAttr", "SPIRV_TosaNumericalAttr"},
- attr.getAttrDefName()) ||
+ } else if (llvm::is_contained({"SPIRV_BoolConstAttr",
+ "SPIRV_I8OrI16OrF16OrF32OrBF16ConstAttr"},
+ attr.getAttrDefName()) ||
attr.getAttrDefName().contains("TensorArm")) {
os << tabs
<< formatv("std::optional<std::pair<Attribute, Type>> c = "