| //===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements the parser for the MLIR Types. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "Parser.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/BuiltinAttributeInterfaces.h" |
| #include "mlir/IR/BuiltinTypeInterfaces.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/OpDefinition.h" |
| #include "mlir/IR/TensorEncoding.h" |
| #include "mlir/IR/Types.h" |
| #include "mlir/Support/LLVM.h" |
| #include <cassert> |
| #include <cstdint> |
| #include <limits> |
| #include <optional> |
| |
| using namespace mlir; |
| using namespace mlir::detail; |
| |
| /// Optionally parse a type. |
| OptionalParseResult Parser::parseOptionalType(Type &type) { |
| // There are many different starting tokens for a type, check them here. |
| switch (getToken().getKind()) { |
| case Token::l_paren: |
| case Token::kw_memref: |
| case Token::kw_tensor: |
| case Token::kw_complex: |
| case Token::kw_tuple: |
| case Token::kw_vector: |
| case Token::inttype: |
| case Token::kw_f4E2M1FN: |
| case Token::kw_f6E2M3FN: |
| case Token::kw_f6E3M2FN: |
| case Token::kw_f8E5M2: |
| case Token::kw_f8E4M3: |
| case Token::kw_f8E4M3FN: |
| case Token::kw_f8E5M2FNUZ: |
| case Token::kw_f8E4M3FNUZ: |
| case Token::kw_f8E4M3B11FNUZ: |
| case Token::kw_f8E3M4: |
| case Token::kw_f8E8M0FNU: |
| case Token::kw_bf16: |
| case Token::kw_f16: |
| case Token::kw_tf32: |
| case Token::kw_f32: |
| case Token::kw_f64: |
| case Token::kw_f80: |
| case Token::kw_f128: |
| case Token::kw_index: |
| case Token::kw_none: |
| case Token::exclamation_identifier: |
| return failure(!(type = parseType())); |
| |
| default: |
| return std::nullopt; |
| } |
| } |
| |
| /// Parse an arbitrary type. |
| /// |
| /// type ::= function-type |
| /// | non-function-type |
| /// |
| Type Parser::parseType() { |
| if (getToken().is(Token::l_paren)) |
| return parseFunctionType(); |
| return parseNonFunctionType(); |
| } |
| |
| /// Parse a function result type. |
| /// |
| /// function-result-type ::= type-list-parens |
| /// | non-function-type |
| /// |
| ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { |
| if (getToken().is(Token::l_paren)) |
| return parseTypeListParens(elements); |
| |
| Type t = parseNonFunctionType(); |
| if (!t) |
| return failure(); |
| elements.push_back(t); |
| return success(); |
| } |
| |
| /// Parse a list of types without an enclosing parenthesis. The list must have |
| /// at least one member. |
| /// |
| /// type-list-no-parens ::= type (`,` type)* |
| /// |
| ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { |
| auto parseElt = [&]() -> ParseResult { |
| auto elt = parseType(); |
| elements.push_back(elt); |
| return elt ? success() : failure(); |
| }; |
| |
| return parseCommaSeparatedList(parseElt); |
| } |
| |
| /// Parse a parenthesized list of types. |
| /// |
| /// type-list-parens ::= `(` `)` |
| /// | `(` type-list-no-parens `)` |
| /// |
| ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { |
| if (parseToken(Token::l_paren, "expected '('")) |
| return failure(); |
| |
| // Handle empty lists. |
| if (getToken().is(Token::r_paren)) |
| return consumeToken(), success(); |
| |
| if (parseTypeListNoParens(elements) || |
| parseToken(Token::r_paren, "expected ')'")) |
| return failure(); |
| return success(); |
| } |
| |
| /// Parse a complex type. |
| /// |
| /// complex-type ::= `complex` `<` type `>` |
| /// |
| Type Parser::parseComplexType() { |
| consumeToken(Token::kw_complex); |
| |
| // Parse the '<'. |
| if (parseToken(Token::less, "expected '<' in complex type")) |
| return nullptr; |
| |
| SMLoc elementTypeLoc = getToken().getLoc(); |
| auto elementType = parseType(); |
| if (!elementType || |
| parseToken(Token::greater, "expected '>' in complex type")) |
| return nullptr; |
| if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType)) |
| return emitError(elementTypeLoc, "invalid element type for complex"), |
| nullptr; |
| |
| return ComplexType::get(elementType); |
| } |
| |
| /// Parse a function type. |
| /// |
| /// function-type ::= type-list-parens `->` function-result-type |
| /// |
| Type Parser::parseFunctionType() { |
| assert(getToken().is(Token::l_paren)); |
| |
| SmallVector<Type, 4> arguments, results; |
| if (parseTypeListParens(arguments) || |
| parseToken(Token::arrow, "expected '->' in function type") || |
| parseFunctionResultTypes(results)) |
| return nullptr; |
| |
| return builder.getFunctionType(arguments, results); |
| } |
| |
| /// Parse a memref type. |
| /// |
| /// memref-type ::= ranked-memref-type | unranked-memref-type |
| /// |
| /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type |
| /// (`,` layout-specification)? (`,` memory-space)? `>` |
| /// |
| /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` |
| /// |
| /// stride-list ::= `[` (dimension (`,` dimension)*)? `]` |
| /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list |
| /// layout-specification ::= semi-affine-map | strided-layout | attribute |
| /// memory-space ::= integer-literal | attribute |
| /// |
| Type Parser::parseMemRefType() { |
| SMLoc loc = getToken().getLoc(); |
| consumeToken(Token::kw_memref); |
| |
| if (parseToken(Token::less, "expected '<' in memref type")) |
| return nullptr; |
| |
| bool isUnranked; |
| SmallVector<int64_t, 4> dimensions; |
| |
| if (consumeIf(Token::star)) { |
| // This is an unranked memref type. |
| isUnranked = true; |
| if (parseXInDimensionList()) |
| return nullptr; |
| |
| } else { |
| isUnranked = false; |
| if (parseDimensionListRanked(dimensions)) |
| return nullptr; |
| } |
| |
| // Parse the element type. |
| auto typeLoc = getToken().getLoc(); |
| auto elementType = parseType(); |
| if (!elementType) |
| return nullptr; |
| |
| // Check that memref is formed from allowed types. |
| if (!BaseMemRefType::isValidElementType(elementType)) |
| return emitError(typeLoc, "invalid memref element type"), nullptr; |
| |
| MemRefLayoutAttrInterface layout; |
| Attribute memorySpace; |
| |
| auto parseElt = [&]() -> ParseResult { |
| // Either it is MemRefLayoutAttrInterface or memory space attribute. |
| Attribute attr = parseAttribute(); |
| if (!attr) |
| return failure(); |
| |
| if (isa<MemRefLayoutAttrInterface>(attr)) { |
| layout = cast<MemRefLayoutAttrInterface>(attr); |
| } else if (memorySpace) { |
| return emitError("multiple memory spaces specified in memref type"); |
| } else { |
| memorySpace = attr; |
| return success(); |
| } |
| |
| if (isUnranked) |
| return emitError("cannot have affine map for unranked memref type"); |
| if (memorySpace) |
| return emitError("expected memory space to be last in memref type"); |
| |
| return success(); |
| }; |
| |
| // Parse a list of mappings and address space if present. |
| if (!consumeIf(Token::greater)) { |
| // Parse comma separated list of affine maps, followed by memory space. |
| if (parseToken(Token::comma, "expected ',' or '>' in memref type") || |
| parseCommaSeparatedListUntil(Token::greater, parseElt, |
| /*allowEmptyList=*/false)) { |
| return nullptr; |
| } |
| } |
| |
| if (isUnranked) |
| return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace); |
| |
| return getChecked<MemRefType>(loc, dimensions, elementType, layout, |
| memorySpace); |
| } |
| |
| /// Parse any type except the function type. |
| /// |
| /// non-function-type ::= integer-type |
| /// | index-type |
| /// | float-type |
| /// | extended-type |
| /// | vector-type |
| /// | tensor-type |
| /// | memref-type |
| /// | complex-type |
| /// | tuple-type |
| /// | none-type |
| /// |
| /// index-type ::= `index` |
| /// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128` |
| /// none-type ::= `none` |
| /// |
| Type Parser::parseNonFunctionType() { |
| switch (getToken().getKind()) { |
| default: |
| return (emitWrongTokenError("expected non-function type"), nullptr); |
| case Token::kw_memref: |
| return parseMemRefType(); |
| case Token::kw_tensor: |
| return parseTensorType(); |
| case Token::kw_complex: |
| return parseComplexType(); |
| case Token::kw_tuple: |
| return parseTupleType(); |
| case Token::kw_vector: |
| return parseVectorType(); |
| // integer-type |
| case Token::inttype: { |
| auto width = getToken().getIntTypeBitwidth(); |
| if (!width.has_value()) |
| return (emitError("invalid integer width"), nullptr); |
| if (*width > IntegerType::kMaxWidth) { |
| emitError(getToken().getLoc(), "integer bitwidth is limited to ") |
| << IntegerType::kMaxWidth << " bits"; |
| return nullptr; |
| } |
| |
| IntegerType::SignednessSemantics signSemantics = IntegerType::Signless; |
| if (std::optional<bool> signedness = getToken().getIntTypeSignedness()) |
| signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned; |
| |
| consumeToken(Token::inttype); |
| return IntegerType::get(getContext(), *width, signSemantics); |
| } |
| |
| // float-type |
| case Token::kw_f4E2M1FN: |
| consumeToken(Token::kw_f4E2M1FN); |
| return builder.getType<Float4E2M1FNType>(); |
| case Token::kw_f6E2M3FN: |
| consumeToken(Token::kw_f6E2M3FN); |
| return builder.getType<Float6E2M3FNType>(); |
| case Token::kw_f6E3M2FN: |
| consumeToken(Token::kw_f6E3M2FN); |
| return builder.getType<Float6E3M2FNType>(); |
| case Token::kw_f8E5M2: |
| consumeToken(Token::kw_f8E5M2); |
| return builder.getType<Float8E5M2Type>(); |
| case Token::kw_f8E4M3: |
| consumeToken(Token::kw_f8E4M3); |
| return builder.getType<Float8E4M3Type>(); |
| case Token::kw_f8E4M3FN: |
| consumeToken(Token::kw_f8E4M3FN); |
| return builder.getType<Float8E4M3FNType>(); |
| case Token::kw_f8E5M2FNUZ: |
| consumeToken(Token::kw_f8E5M2FNUZ); |
| return builder.getType<Float8E5M2FNUZType>(); |
| case Token::kw_f8E4M3FNUZ: |
| consumeToken(Token::kw_f8E4M3FNUZ); |
| return builder.getType<Float8E4M3FNUZType>(); |
| case Token::kw_f8E4M3B11FNUZ: |
| consumeToken(Token::kw_f8E4M3B11FNUZ); |
| return builder.getType<Float8E4M3B11FNUZType>(); |
| case Token::kw_f8E3M4: |
| consumeToken(Token::kw_f8E3M4); |
| return builder.getType<Float8E3M4Type>(); |
| case Token::kw_f8E8M0FNU: |
| consumeToken(Token::kw_f8E8M0FNU); |
| return builder.getType<Float8E8M0FNUType>(); |
| case Token::kw_bf16: |
| consumeToken(Token::kw_bf16); |
| return builder.getType<BFloat16Type>(); |
| case Token::kw_f16: |
| consumeToken(Token::kw_f16); |
| return builder.getType<Float16Type>(); |
| case Token::kw_tf32: |
| consumeToken(Token::kw_tf32); |
| return builder.getType<FloatTF32Type>(); |
| case Token::kw_f32: |
| consumeToken(Token::kw_f32); |
| return builder.getType<Float32Type>(); |
| case Token::kw_f64: |
| consumeToken(Token::kw_f64); |
| return builder.getType<Float64Type>(); |
| case Token::kw_f80: |
| consumeToken(Token::kw_f80); |
| return builder.getType<Float80Type>(); |
| case Token::kw_f128: |
| consumeToken(Token::kw_f128); |
| return builder.getType<Float128Type>(); |
| |
| // index-type |
| case Token::kw_index: |
| consumeToken(Token::kw_index); |
| return builder.getIndexType(); |
| |
| // none-type |
| case Token::kw_none: |
| consumeToken(Token::kw_none); |
| return builder.getNoneType(); |
| |
| // extended type |
| case Token::exclamation_identifier: |
| return parseExtendedType(); |
| |
| // Handle completion of a dialect type. |
| case Token::code_complete: |
| if (getToken().isCodeCompletionFor(Token::exclamation_identifier)) |
| return parseExtendedType(); |
| return codeCompleteType(); |
| } |
| } |
| |
| /// Parse a tensor type. |
| /// |
| /// tensor-type ::= `tensor` `<` dimension-list type `>` |
| /// dimension-list ::= dimension-list-ranked | `*x` |
| /// |
| Type Parser::parseTensorType() { |
| consumeToken(Token::kw_tensor); |
| |
| if (parseToken(Token::less, "expected '<' in tensor type")) |
| return nullptr; |
| |
| bool isUnranked; |
| SmallVector<int64_t, 4> dimensions; |
| |
| if (consumeIf(Token::star)) { |
| // This is an unranked tensor type. |
| isUnranked = true; |
| |
| if (parseXInDimensionList()) |
| return nullptr; |
| |
| } else { |
| isUnranked = false; |
| if (parseDimensionListRanked(dimensions)) |
| return nullptr; |
| } |
| |
| // Parse the element type. |
| auto elementTypeLoc = getToken().getLoc(); |
| auto elementType = parseType(); |
| |
| // Parse an optional encoding attribute. |
| Attribute encoding; |
| if (consumeIf(Token::comma)) { |
| auto parseResult = parseOptionalAttribute(encoding); |
| if (parseResult.has_value()) { |
| if (failed(parseResult.value())) |
| return nullptr; |
| if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) { |
| if (failed(v.verifyEncoding(dimensions, elementType, |
| [&] { return emitError(); }))) |
| return nullptr; |
| } |
| } |
| } |
| |
| if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) |
| return nullptr; |
| if (!TensorType::isValidElementType(elementType)) |
| return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; |
| |
| if (isUnranked) { |
| if (encoding) |
| return emitError("cannot apply encoding to unranked tensor"), nullptr; |
| return UnrankedTensorType::get(elementType); |
| } |
| return RankedTensorType::get(dimensions, elementType, encoding); |
| } |
| |
| /// Parse a tuple type. |
| /// |
| /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` |
| /// |
| Type Parser::parseTupleType() { |
| consumeToken(Token::kw_tuple); |
| |
| // Parse the '<'. |
| if (parseToken(Token::less, "expected '<' in tuple type")) |
| return nullptr; |
| |
| // Check for an empty tuple by directly parsing '>'. |
| if (consumeIf(Token::greater)) |
| return TupleType::get(getContext()); |
| |
| // Parse the element types and the '>'. |
| SmallVector<Type, 4> types; |
| if (parseTypeListNoParens(types) || |
| parseToken(Token::greater, "expected '>' in tuple type")) |
| return nullptr; |
| |
| return TupleType::get(getContext(), types); |
| } |
| |
| /// Parse a vector type. |
| /// |
| /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>` |
| /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? |
| /// static-dim-list ::= decimal-literal (`x` decimal-literal)* |
| /// |
| VectorType Parser::parseVectorType() { |
| SMLoc loc = getToken().getLoc(); |
| consumeToken(Token::kw_vector); |
| |
| if (parseToken(Token::less, "expected '<' in vector type")) |
| return nullptr; |
| |
| // Parse the dimensions. |
| SmallVector<int64_t, 4> dimensions; |
| SmallVector<bool, 4> scalableDims; |
| if (parseVectorDimensionList(dimensions, scalableDims)) |
| return nullptr; |
| |
| // Parse the element type. |
| auto elementType = parseType(); |
| if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) |
| return nullptr; |
| |
| return getChecked<VectorType>(loc, dimensions, elementType, scalableDims); |
| } |
| |
| /// Parse a dimension list in a vector type. This populates the dimension list. |
| /// For i-th dimension, `scalableDims[i]` contains either: |
| /// * `false` for a non-scalable dimension (e.g. `4`), |
| /// * `true` for a scalable dimension (e.g. `[4]`). |
| /// |
| /// vector-dim-list := (static-dim-list `x`)? |
| /// static-dim-list ::= static-dim (`x` static-dim)* |
| /// static-dim ::= (decimal-literal | `[` decimal-literal `]`) |
| /// |
| ParseResult |
| Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions, |
| SmallVectorImpl<bool> &scalableDims) { |
| // If there is a set of fixed-length dimensions, consume it |
| while (getToken().is(Token::integer) || getToken().is(Token::l_square)) { |
| int64_t value; |
| bool scalable = consumeIf(Token::l_square); |
| if (parseIntegerInDimensionList(value)) |
| return failure(); |
| dimensions.push_back(value); |
| if (scalable) { |
| if (!consumeIf(Token::r_square)) |
| return emitWrongTokenError("missing ']' closing scalable dimension"); |
| } |
| scalableDims.push_back(scalable); |
| // Make sure we have an 'x' or something like 'xbf32'. |
| if (parseXInDimensionList()) |
| return failure(); |
| } |
| |
| return success(); |
| } |
| |
| /// Parse a dimension list of a tensor or memref type. This populates the |
| /// dimension list, using ShapedType::kDynamic for the `?` dimensions if |
| /// `allowDynamic` is set and errors out on `?` otherwise. Parsing the trailing |
| /// `x` is configurable. |
| /// |
| /// dimension-list ::= eps | dimension (`x` dimension)* |
| /// dimension-list-with-trailing-x ::= (dimension `x`)* |
| /// dimension ::= `?` | decimal-literal |
| /// |
| /// When `allowDynamic` is not set, this is used to parse: |
| /// |
| /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)* |
| /// static-dimension-list-with-trailing-x ::= (dimension `x`)* |
| ParseResult |
| Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, |
| bool allowDynamic, bool withTrailingX) { |
| auto parseDim = [&]() -> LogicalResult { |
| auto loc = getToken().getLoc(); |
| if (consumeIf(Token::question)) { |
| if (!allowDynamic) |
| return emitError(loc, "expected static shape"); |
| dimensions.push_back(ShapedType::kDynamic); |
| } else { |
| int64_t value; |
| if (failed(parseIntegerInDimensionList(value))) |
| return failure(); |
| dimensions.push_back(value); |
| } |
| return success(); |
| }; |
| |
| if (withTrailingX) { |
| while (getToken().isAny(Token::integer, Token::question)) { |
| if (failed(parseDim()) || failed(parseXInDimensionList())) |
| return failure(); |
| } |
| return success(); |
| } |
| |
| if (getToken().isAny(Token::integer, Token::question)) { |
| if (failed(parseDim())) |
| return failure(); |
| while (getToken().is(Token::bare_identifier) && |
| getTokenSpelling()[0] == 'x') { |
| if (failed(parseXInDimensionList()) || failed(parseDim())) |
| return failure(); |
| } |
| } |
| return success(); |
| } |
| |
| ParseResult Parser::parseIntegerInDimensionList(int64_t &value) { |
| // Hexadecimal integer literals (starting with `0x`) are not allowed in |
| // aggregate type declarations. Therefore, `0xf32` should be processed as |
| // a sequence of separate elements `0`, `x`, `f32`. |
| if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { |
| // We can get here only if the token is an integer literal. Hexadecimal |
| // integer literals can only start with `0x` (`1x` wouldn't lex as a |
| // literal, just `1` would, at which point we don't get into this |
| // branch). |
| assert(getTokenSpelling()[0] == '0' && "invalid integer literal"); |
| value = 0; |
| state.lex.resetPointer(getTokenSpelling().data() + 1); |
| consumeToken(); |
| } else { |
| // Make sure this integer value is in bound and valid. |
| std::optional<uint64_t> dimension = getToken().getUInt64IntegerValue(); |
| if (!dimension || |
| *dimension > (uint64_t)std::numeric_limits<int64_t>::max()) |
| return emitError("invalid dimension"); |
| value = (int64_t)*dimension; |
| consumeToken(Token::integer); |
| } |
| return success(); |
| } |
| |
| /// Parse an 'x' token in a dimension list, handling the case where the x is |
| /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next |
| /// token. |
| ParseResult Parser::parseXInDimensionList() { |
| if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') |
| return emitWrongTokenError("expected 'x' in dimension list"); |
| |
| // If we had a prefix of 'x', lex the next token immediately after the 'x'. |
| if (getTokenSpelling().size() != 1) |
| state.lex.resetPointer(getTokenSpelling().data() + 1); |
| |
| // Consume the 'x'. |
| consumeToken(Token::bare_identifier); |
| |
| return success(); |
| } |