blob: a461ebed967a8a5c35ae097ac067c9ce13624aa1 [file] [log] [blame]
//===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for the MLIR Types.
//
//===----------------------------------------------------------------------===//
#include "Parser.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>
using namespace mlir;
using namespace mlir::detail;
/// Try to parse a type if the current token is able to begin one.
///
/// When the current token cannot start a type, nothing is consumed and
/// std::nullopt is returned. Otherwise the type is parsed into `type` and the
/// result reports whether that parse succeeded. The set of accepted starting
/// tokens must stay in sync with the cases handled by parseNonFunctionType()
/// (plus `(` for function types).
OptionalParseResult Parser::parseOptionalType(Type &type) {
  const bool canStartType = getToken().isAny(
      Token::l_paren, Token::kw_memref, Token::kw_tensor, Token::kw_complex,
      Token::kw_tuple, Token::kw_vector, Token::inttype, Token::kw_f4E2M1FN,
      Token::kw_f6E2M3FN, Token::kw_f6E3M2FN, Token::kw_f8E5M2,
      Token::kw_f8E4M3, Token::kw_f8E4M3FN, Token::kw_f8E5M2FNUZ,
      Token::kw_f8E4M3FNUZ, Token::kw_f8E4M3B11FNUZ, Token::kw_f8E3M4,
      Token::kw_f8E8M0FNU, Token::kw_bf16, Token::kw_f16, Token::kw_tf32,
      Token::kw_f32, Token::kw_f64, Token::kw_f80, Token::kw_f128,
      Token::kw_index, Token::kw_none, Token::exclamation_identifier);
  if (!canStartType)
    return std::nullopt;

  type = parseType();
  return failure(!type);
}
/// Parse any type, dispatching on whether it is a function type.
///
///   type ::= function-type
///          | non-function-type
///
Type Parser::parseType() {
  // A leading '(' can only begin the input list of a function type.
  const bool startsFunctionType = getToken().is(Token::l_paren);
  return startsFunctionType ? parseFunctionType() : parseNonFunctionType();
}
/// Parse the result portion of a function type, appending to `elements`.
///
///   function-result-type ::= type-list-parens
///                          | non-function-type
///
/// A bare (unparenthesized) result is a single non-function type.
ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
  if (getToken().isNot(Token::l_paren)) {
    Type singleResult = parseNonFunctionType();
    if (!singleResult)
      return failure();
    elements.push_back(singleResult);
    return success();
  }
  return parseTypeListParens(elements);
}
/// Parse a non-empty, comma-separated list of types with no surrounding
/// parentheses, appending each parsed type to `elements`.
///
///   type-list-no-parens ::= type (`,` type)*
///
/// Note: the result of every parse attempt is appended, so on failure the
/// last element of `elements` is a null type.
ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
  return parseCommaSeparatedList([&]() -> ParseResult {
    Type element = parseType();
    elements.push_back(element);
    return failure(!element);
  });
}
/// Parse a parenthesized, possibly empty, list of types into `elements`.
///
///   type-list-parens ::= `(` `)`
///                      | `(` type-list-no-parens `)`
///
ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
  if (parseToken(Token::l_paren, "expected '('"))
    return failure();

  // `()` is an empty list: nothing to append.
  if (consumeIf(Token::r_paren))
    return success();

  if (failed(parseTypeListNoParens(elements)))
    return failure();
  return parseToken(Token::r_paren, "expected ')'");
}
/// Parse a complex type whose element type must be a float or integer type.
///
///   complex-type ::= `complex` `<` type `>`
///
Type Parser::parseComplexType() {
  consumeToken(Token::kw_complex);
  if (parseToken(Token::less, "expected '<' in complex type"))
    return nullptr;

  // Remember where the element type starts so errors point at it.
  SMLoc elementLoc = getToken().getLoc();
  Type element = parseType();
  if (!element)
    return nullptr;
  if (parseToken(Token::greater, "expected '>' in complex type"))
    return nullptr;

  if (!isa<FloatType, IntegerType>(element)) {
    emitError(elementLoc, "invalid element type for complex");
    return nullptr;
  }
  return ComplexType::get(element);
}
/// Parse a function type; the current token must be the opening `(`.
///
///   function-type ::= type-list-parens `->` function-result-type
///
Type Parser::parseFunctionType() {
  assert(getToken().is(Token::l_paren));

  SmallVector<Type, 4> inputs;
  if (failed(parseTypeListParens(inputs)))
    return nullptr;
  if (failed(parseToken(Token::arrow, "expected '->' in function type")))
    return nullptr;

  SmallVector<Type, 4> outputs;
  if (failed(parseFunctionResultTypes(outputs)))
    return nullptr;
  return builder.getFunctionType(inputs, outputs);
}
/// Parse a memref type.
///
///   memref-type ::= ranked-memref-type | unranked-memref-type
///
///   ranked-memref-type ::= `memref` `<` dimension-list-ranked type
///                          (`,` layout-specification)? (`,` memory-space)? `>`
///
///   unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
///
///   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
///   layout-specification ::= semi-affine-map | strided-layout | attribute
///   memory-space ::= integer-literal | attribute
///
/// Each trailing attribute is classified by kind. An attribute implementing
/// MemRefLayoutAttrInterface is the layout; any other attribute is the memory
/// space. At most one of each is accepted, the layout must precede the
/// memory space, and unranked memrefs cannot carry a layout.
Type Parser::parseMemRefType() {
  SMLoc loc = getToken().getLoc();
  consumeToken(Token::kw_memref);
  if (parseToken(Token::less, "expected '<' in memref type"))
    return nullptr;

  SmallVector<int64_t, 4> dimensions;
  const bool isUnranked = consumeIf(Token::star);
  if (isUnranked) {
    // `*` must be followed by `x`, e.g. `*xf32`.
    if (parseXInDimensionList())
      return nullptr;
  } else if (parseDimensionListRanked(dimensions)) {
    return nullptr;
  }

  // Parse the element type and check that memrefs may be formed from it.
  auto typeLoc = getToken().getLoc();
  auto elementType = parseType();
  if (!elementType)
    return nullptr;
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError(typeLoc, "invalid memref element type"), nullptr;

  MemRefLayoutAttrInterface layout;
  Attribute memorySpace;
  auto parseElt = [&]() -> ParseResult {
    // Either it is a MemRefLayoutAttrInterface or a memory space attribute.
    Attribute attr = parseAttribute();
    if (!attr)
      return failure();

    auto layoutAttr = dyn_cast<MemRefLayoutAttrInterface>(attr);
    if (!layoutAttr) {
      if (memorySpace)
        return emitError("multiple memory spaces specified in memref type");
      memorySpace = attr;
      return success();
    }

    if (isUnranked)
      return emitError("cannot have affine map for unranked memref type");
    if (memorySpace)
      return emitError("expected memory space to be last in memref type");
    // Reject a second layout rather than letting it silently replace the
    // first one.
    if (layout)
      return emitError("multiple layouts specified in memref type");
    layout = layoutAttr;
    return success();
  };

  // Parse the optional layout and memory space, if present.
  if (!consumeIf(Token::greater)) {
    if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
        parseCommaSeparatedListUntil(Token::greater, parseElt,
                                     /*allowEmptyList=*/false)) {
      return nullptr;
    }
  }

  if (isUnranked)
    return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
  return getChecked<MemRefType>(loc, dimensions, elementType, layout,
                                memorySpace);
}
/// Parse any type except a function type.
///
///   non-function-type ::= integer-type
///                       | index-type
///                       | float-type
///                       | extended-type
///                       | vector-type
///                       | tensor-type
///                       | memref-type
///                       | complex-type
///                       | tuple-type
///                       | none-type
///
///   index-type ::= `index`
///   float-type ::= `f4E2M1FN` | `f6E2M3FN` | `f6E3M2FN` | `f8E5M2` | `f8E4M3`
///                | `f8E4M3FN` | `f8E5M2FNUZ` | `f8E4M3FNUZ` | `f8E4M3B11FNUZ`
///                | `f8E3M4` | `f8E8M0FNU` | `bf16` | `f16` | `tf32` | `f32`
///                | `f64` | `f80` | `f128`
///   none-type ::= `none`
///
/// Keyword types are produced by consuming their keyword token. Aggregate
/// and dialect (`!`-prefixed) types are delegated to their own parsers. A
/// code-completion token goes to the dialect-type parser or type completion.
Type Parser::parseNonFunctionType() {
  // Consume the current keyword token and yield `result`.
  auto takeKeyword = [&](Type result) -> Type {
    consumeToken();
    return result;
  };

  switch (getToken().getKind()) {
  // Aggregate types, each with a dedicated parser.
  case Token::kw_memref:
    return parseMemRefType();
  case Token::kw_tensor:
    return parseTensorType();
  case Token::kw_complex:
    return parseComplexType();
  case Token::kw_tuple:
    return parseTupleType();
  case Token::kw_vector:
    return parseVectorType();

  // integer-type: width and signedness are encoded in the token itself.
  case Token::inttype: {
    auto bitwidth = getToken().getIntTypeBitwidth();
    if (!bitwidth)
      return (emitError("invalid integer width"), nullptr);
    if (*bitwidth > IntegerType::kMaxWidth) {
      emitError(getToken().getLoc(), "integer bitwidth is limited to ")
          << IntegerType::kMaxWidth << " bits";
      return nullptr;
    }
    // An absent signedness means signless.
    std::optional<bool> isSigned = getToken().getIntTypeSignedness();
    IntegerType::SignednessSemantics semantics =
        !isSigned    ? IntegerType::Signless
        : *isSigned ? IntegerType::Signed
                    : IntegerType::Unsigned;
    consumeToken(Token::inttype);
    return IntegerType::get(getContext(), *bitwidth, semantics);
  }

  // float-type
  case Token::kw_f4E2M1FN:
    return takeKeyword(builder.getType<Float4E2M1FNType>());
  case Token::kw_f6E2M3FN:
    return takeKeyword(builder.getType<Float6E2M3FNType>());
  case Token::kw_f6E3M2FN:
    return takeKeyword(builder.getType<Float6E3M2FNType>());
  case Token::kw_f8E5M2:
    return takeKeyword(builder.getType<Float8E5M2Type>());
  case Token::kw_f8E4M3:
    return takeKeyword(builder.getType<Float8E4M3Type>());
  case Token::kw_f8E4M3FN:
    return takeKeyword(builder.getType<Float8E4M3FNType>());
  case Token::kw_f8E5M2FNUZ:
    return takeKeyword(builder.getType<Float8E5M2FNUZType>());
  case Token::kw_f8E4M3FNUZ:
    return takeKeyword(builder.getType<Float8E4M3FNUZType>());
  case Token::kw_f8E4M3B11FNUZ:
    return takeKeyword(builder.getType<Float8E4M3B11FNUZType>());
  case Token::kw_f8E3M4:
    return takeKeyword(builder.getType<Float8E3M4Type>());
  case Token::kw_f8E8M0FNU:
    return takeKeyword(builder.getType<Float8E8M0FNUType>());
  case Token::kw_bf16:
    return takeKeyword(builder.getType<BFloat16Type>());
  case Token::kw_f16:
    return takeKeyword(builder.getType<Float16Type>());
  case Token::kw_tf32:
    return takeKeyword(builder.getType<FloatTF32Type>());
  case Token::kw_f32:
    return takeKeyword(builder.getType<Float32Type>());
  case Token::kw_f64:
    return takeKeyword(builder.getType<Float64Type>());
  case Token::kw_f80:
    return takeKeyword(builder.getType<Float80Type>());
  case Token::kw_f128:
    return takeKeyword(builder.getType<Float128Type>());

  // index-type and none-type
  case Token::kw_index:
    return takeKeyword(builder.getIndexType());
  case Token::kw_none:
    return takeKeyword(builder.getNoneType());

  // extended (dialect) type
  case Token::exclamation_identifier:
    return parseExtendedType();

  // Completion request: either inside a dialect type or for a type in general.
  case Token::code_complete:
    if (getToken().isCodeCompletionFor(Token::exclamation_identifier))
      return parseExtendedType();
    return codeCompleteType();

  default:
    return (emitWrongTokenError("expected non-function type"), nullptr);
  }
}
/// Parse a tensor type.
///
///   tensor-type ::= `tensor` `<` dimension-list type (`,` encoding)? `>`
///   dimension-list ::= dimension-list-ranked | `*x`
///
/// An encoding attribute that implements VerifiableTensorEncoding is verified
/// against the parsed shape and element type. Encodings are rejected on
/// unranked tensors.
Type Parser::parseTensorType() {
  consumeToken(Token::kw_tensor);
  if (parseToken(Token::less, "expected '<' in tensor type"))
    return nullptr;

  SmallVector<int64_t, 4> dimensions;
  const bool isUnranked = consumeIf(Token::star);
  if (isUnranked) {
    // `*` must be followed by `x`, e.g. `*xf32`.
    if (parseXInDimensionList())
      return nullptr;
  } else if (parseDimensionListRanked(dimensions)) {
    return nullptr;
  }

  // Parse the element type. Bail out immediately on failure: the encoding
  // verifier below must never receive a null element type, and continuing
  // would only produce cascading diagnostics.
  auto elementTypeLoc = getToken().getLoc();
  auto elementType = parseType();
  if (!elementType)
    return nullptr;

  // Parse an optional encoding attribute.
  Attribute encoding;
  if (consumeIf(Token::comma)) {
    auto parseResult = parseOptionalAttribute(encoding);
    if (parseResult.has_value()) {
      if (failed(parseResult.value()))
        return nullptr;
      // Any encoding on an unranked tensor is rejected below. Verifying it
      // against an empty shape would only risk a misleading diagnostic, so
      // verify ranked tensors only.
      auto verifiable = dyn_cast_or_null<VerifiableTensorEncoding>(encoding);
      if (verifiable && !isUnranked &&
          failed(verifiable.verifyEncoding(dimensions, elementType,
                                           [&] { return emitError(); })))
        return nullptr;
    }
  }

  if (parseToken(Token::greater, "expected '>' in tensor type"))
    return nullptr;
  if (!TensorType::isValidElementType(elementType))
    return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;

  if (isUnranked) {
    if (encoding)
      return emitError("cannot apply encoding to unranked tensor"), nullptr;
    return UnrankedTensorType::get(elementType);
  }
  return RankedTensorType::get(dimensions, elementType, encoding);
}
/// Parse a tuple type with zero or more element types.
///
///   tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
///
Type Parser::parseTupleType() {
  consumeToken(Token::kw_tuple);
  if (parseToken(Token::less, "expected '<' in tuple type"))
    return nullptr;

  // `tuple<>`: an immediate '>' denotes the empty tuple.
  if (getToken().is(Token::greater)) {
    consumeToken(Token::greater);
    return TupleType::get(getContext());
  }

  SmallVector<Type, 4> members;
  if (failed(parseTypeListNoParens(members)))
    return nullptr;
  if (failed(parseToken(Token::greater, "expected '>' in tuple type")))
    return nullptr;
  return TupleType::get(getContext(), members);
}
/// Parse a vector type, including optional scalable (`[N]`) dimensions.
///
///   vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
///   vector-dim-list ::= (vector-dim `x`)*
///   vector-dim ::= decimal-literal | `[` decimal-literal `]`
///
/// The resulting type is built with getChecked so invalid combinations are
/// reported at the `vector` keyword.
VectorType Parser::parseVectorType() {
  const SMLoc typeLoc = getToken().getLoc();
  consumeToken(Token::kw_vector);
  if (parseToken(Token::less, "expected '<' in vector type"))
    return nullptr;

  SmallVector<int64_t, 4> shape;
  SmallVector<bool, 4> scalableMask;
  if (parseVectorDimensionList(shape, scalableMask))
    return nullptr;

  Type elemTy = parseType();
  if (!elemTy)
    return nullptr;
  if (parseToken(Token::greater, "expected '>' in vector type"))
    return nullptr;
  return getChecked<VectorType>(typeLoc, shape, elemTy, scalableMask);
}
/// Parse the dimension list of a vector type, filling `dimensions` with the
/// sizes and `scalableDims` with a parallel flag per dimension:
///   * `false` for a fixed dimension (e.g. `4`),
///   * `true` for a scalable dimension (e.g. `[4]`).
/// Every dimension must be followed by an `x` (which may be glued to the
/// element type, as in `4xf32`).
///
///   vector-dim-list := (static-dim-list `x`)?
///   static-dim-list ::= static-dim (`x` static-dim)*
///   static-dim ::= (decimal-literal | `[` decimal-literal `]`)
///
ParseResult
Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                 SmallVectorImpl<bool> &scalableDims) {
  while (getToken().isAny(Token::integer, Token::l_square)) {
    const bool isScalable = consumeIf(Token::l_square);

    int64_t dimSize;
    if (parseIntegerInDimensionList(dimSize))
      return failure();
    dimensions.push_back(dimSize);

    if (isScalable && !consumeIf(Token::r_square))
      return emitWrongTokenError("missing ']' closing scalable dimension");
    scalableDims.push_back(isScalable);

    // Accept either a standalone 'x' or one fused with what follows ('xbf32').
    if (parseXInDimensionList())
      return failure();
  }
  return success();
}
/// Parse the dimension list of a ranked tensor or memref type into
/// `dimensions`. A `?` becomes ShapedType::kDynamic when `allowDynamic` is
/// set and is an error ("expected static shape") otherwise. `withTrailingX`
/// selects whether every dimension is followed by an `x`, or whether `x` only
/// separates dimensions.
///
///   dimension-list ::= eps | dimension (`x` dimension)*
///   dimension-list-with-trailing-x ::= (dimension `x`)*
///   dimension ::= `?` | decimal-literal
///
/// When `allowDynamic` is not set, this is used to parse:
///
///   static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
///   static-dimension-list-with-trailing-x ::= (decimal-literal `x`)*
ParseResult
Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
                                 bool allowDynamic, bool withTrailingX) {
  // Parse one `?` or integer dimension and append it to `dimensions`.
  auto parseOneDimension = [&]() -> LogicalResult {
    SMLoc dimLoc = getToken().getLoc();
    if (!consumeIf(Token::question)) {
      int64_t size;
      if (failed(parseIntegerInDimensionList(size)))
        return failure();
      dimensions.push_back(size);
      return success();
    }
    if (!allowDynamic)
      return emitError(dimLoc, "expected static shape");
    dimensions.push_back(ShapedType::kDynamic);
    return success();
  };
  auto atDimension = [&] {
    return getToken().isAny(Token::integer, Token::question);
  };

  if (withTrailingX) {
    // Each dimension is immediately followed by its own 'x'.
    for (; atDimension();)
      if (failed(parseOneDimension()) || failed(parseXInDimensionList()))
        return failure();
    return success();
  }

  // Without a trailing 'x', 'x' only separates consecutive dimensions.
  if (!atDimension())
    return success();
  if (failed(parseOneDimension()))
    return failure();
  while (getToken().is(Token::bare_identifier) &&
         getTokenSpelling().front() == 'x')
    if (failed(parseXInDimensionList()) || failed(parseOneDimension()))
      return failure();
  return success();
}
/// Parse a single integer dimension size into `value`.
///
/// Hexadecimal literals are not valid dimensions. A token such as `0xf32` is
/// split so that only the leading `0` is taken as the dimension and lexing
/// resumes at the `x`. Values larger than INT64_MAX are rejected.
ParseResult Parser::parseIntegerInDimensionList(int64_t &value) {
  // Only an integer literal can denote a dimension. Some callers (e.g. the
  // scalable `[...]` vector-dimension path) reach here without checking the
  // token kind. The `0x` split below assumes an integer token, so tokens such
  // as `@x` or `ax` would otherwise trip its assertion (or be mis-lexed in
  // release builds). Reject them with the same diagnostic as any other
  // non-integer token.
  if (getToken().isNot(Token::integer))
    return emitError("invalid dimension");

  // Hexadecimal integer literals (starting with `0x`) are not allowed in
  // aggregate type declarations. Therefore, `0xf32` should be processed as
  // a sequence of separate elements `0`, `x`, `f32`.
  if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
    // Hexadecimal integer literals can only start with `0x` (`1x` wouldn't
    // lex as a literal, just `1` would, at which point we don't get into
    // this branch).
    assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
    value = 0;
    // Re-lex starting at the 'x' so it becomes the next token.
    state.lex.resetPointer(getTokenSpelling().data() + 1);
    consumeToken();
  } else {
    // Make sure this integer value is in bound and valid.
    std::optional<uint64_t> dimension = getToken().getUInt64IntegerValue();
    if (!dimension ||
        *dimension > (uint64_t)std::numeric_limits<int64_t>::max())
      return emitError("invalid dimension");
    value = (int64_t)*dimension;
    consumeToken(Token::integer);
  }
  return success();
}
/// Consume the `x` separator of a dimension list.
///
/// The `x` is often lexed fused with what follows it (e.g. `xf32` as one
/// bare identifier). In that case only the `x` is consumed and the lexer is
/// rewound so the remainder (`f32`) becomes the next token.
ParseResult Parser::parseXInDimensionList() {
  const bool startsWithX = getToken().is(Token::bare_identifier) &&
                           getTokenSpelling().front() == 'x';
  if (!startsWithX)
    return emitWrongTokenError("expected 'x' in dimension list");

  // Anything glued after the 'x' is re-lexed as a separate token.
  auto spelling = getTokenSpelling();
  if (spelling.size() > 1)
    state.lex.resetPointer(spelling.data() + 1);

  consumeToken(Token::bare_identifier);
  return success();
}