[mlir][vector] Fix parser of `vector.contract` (#133434)
This PR adds a check in the parser to prevent a crash when
`vector.contract` lacks the `iterator_types` attribute.
Fixes #132886.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eccb3e5..5a39836 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -787,8 +787,13 @@
// because tests still use the old format when 'iterator_types' attribute is
// represented as an array of strings.
// TODO: Remove this conversion once tests are fixed.
- ArrayAttr iteratorTypes = llvm::cast<ArrayAttr>(
+ auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
result.attributes.get(getIteratorTypesAttrName(result.name)));
+ if (!iteratorTypes) {
+ return parser.emitError(loc)
+ << "expected " << getIteratorTypesAttrName(result.name)
+ << " array attribute";
+ }
SmallVector<Attribute> iteratorTypeAttrs;
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 1b89e8e..ea6d002 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1015,6 +1015,14 @@
// -----
+func.func @contract_missing_iterator_types(%arg0: vector<1x2xi32>, %arg1: vector<2xi32>, %arg2: vector<1xi32>) -> vector<1xi32> {
+ // expected-error@+1 {{'vector.contract' expected "iterator_types" array attribute}}
+ %0 = vector.contract {} %arg0, %arg1, %arg2 : vector<1x2xi32>, vector<2xi32> into vector<1xi32>
+ return %0 : vector<1xi32>
+}
+
+// -----
+
func.func @create_mask_0d_no_operands() {
%c1 = arith.constant 1 : index
// expected-error@+1 {{must specify exactly one operand for 0-D create_mask}}