| //===- LLVMTypeSyntax.cpp - Parsing/printing for MLIR LLVM Dialect types --===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/DialectImplementation.h" |
| #include "llvm/ADT/ScopeExit.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| |
| using namespace mlir; |
| using namespace mlir::LLVM; |
| |
| //===----------------------------------------------------------------------===// |
| // Printing. |
| //===----------------------------------------------------------------------===// |
| |
| /// If the given type is compatible with the LLVM dialect, prints it using |
| /// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise |
| /// prints it as usual. |
| static void dispatchPrint(AsmPrinter &printer, Type type) { |
| if (isCompatibleType(type) && !type.isa<IntegerType, FloatType, VectorType>()) |
| return mlir::LLVM::detail::printType(type, printer); |
| printer.printType(type); |
| } |
| |
| /// Returns the keyword to use for the given type. |
| static StringRef getTypeKeyword(Type type) { |
| return TypeSwitch<Type, StringRef>(type) |
| .Case<LLVMVoidType>([&](Type) { return "void"; }) |
| .Case<LLVMPPCFP128Type>([&](Type) { return "ppc_fp128"; }) |
| .Case<LLVMX86MMXType>([&](Type) { return "x86_mmx"; }) |
| .Case<LLVMTokenType>([&](Type) { return "token"; }) |
| .Case<LLVMLabelType>([&](Type) { return "label"; }) |
| .Case<LLVMMetadataType>([&](Type) { return "metadata"; }) |
| .Case<LLVMFunctionType>([&](Type) { return "func"; }) |
| .Case<LLVMPointerType>([&](Type) { return "ptr"; }) |
| .Case<LLVMFixedVectorType, LLVMScalableVectorType>( |
| [&](Type) { return "vec"; }) |
| .Case<LLVMArrayType>([&](Type) { return "array"; }) |
| .Case<LLVMStructType>([&](Type) { return "struct"; }) |
| .Default([](Type) -> StringRef { |
| llvm_unreachable("unexpected 'llvm' type kind"); |
| }); |
| } |
| |
| /// Prints a structure type. Keeps track of known struct names to handle self- |
| /// or mutually-referring structs without falling into infinite recursion. |
| static void printStructType(AsmPrinter &printer, LLVMStructType type) { |
| // This keeps track of the names of identified structure types that are |
| // currently being printed. Since such types can refer themselves, this |
| // tracking is necessary to stop the recursion: the current function may be |
| // called recursively from AsmPrinter::printType after the appropriate |
| // dispatch. We maintain the invariant of this storage being modified |
| // exclusively in this function, and at most one name being added per call. |
| // TODO: consider having such functionality inside AsmPrinter. |
| thread_local SetVector<StringRef> knownStructNames; |
| unsigned stackSize = knownStructNames.size(); |
| (void)stackSize; |
| auto guard = llvm::make_scope_exit([&]() { |
| assert(knownStructNames.size() == stackSize && |
| "malformed identified stack when printing recursive structs"); |
| }); |
| |
| printer << "<"; |
| if (type.isIdentified()) { |
| printer << '"' << type.getName() << '"'; |
| // If we are printing a reference to one of the enclosing structs, just |
| // print the name and stop to avoid infinitely long output. |
| if (knownStructNames.count(type.getName())) { |
| printer << '>'; |
| return; |
| } |
| printer << ", "; |
| } |
| |
| if (type.isIdentified() && type.isOpaque()) { |
| printer << "opaque>"; |
| return; |
| } |
| |
| if (type.isPacked()) |
| printer << "packed "; |
| |
| // Put the current type on stack to avoid infinite recursion. |
| printer << '('; |
| if (type.isIdentified()) |
| knownStructNames.insert(type.getName()); |
| llvm::interleaveComma(type.getBody(), printer.getStream(), |
| [&](Type subtype) { dispatchPrint(printer, subtype); }); |
| if (type.isIdentified()) |
| knownStructNames.pop_back(); |
| printer << ')'; |
| printer << '>'; |
| } |
| |
| /// Prints a type containing a fixed number of elements. |
| template <typename TypeTy> |
| static void printArrayOrVectorType(AsmPrinter &printer, TypeTy type) { |
| printer << '<' << type.getNumElements() << " x "; |
| dispatchPrint(printer, type.getElementType()); |
| printer << '>'; |
| } |
| |
| /// Prints a function type. |
| static void printFunctionType(AsmPrinter &printer, LLVMFunctionType funcType) { |
| printer << '<'; |
| dispatchPrint(printer, funcType.getReturnType()); |
| printer << " ("; |
| llvm::interleaveComma( |
| funcType.getParams(), printer.getStream(), |
| [&printer](Type subtype) { dispatchPrint(printer, subtype); }); |
| if (funcType.isVarArg()) { |
| if (funcType.getNumParams() != 0) |
| printer << ", "; |
| printer << "..."; |
| } |
| printer << ")>"; |
| } |
| |
| /// Prints the given LLVM dialect type recursively. This leverages closedness of |
| /// the LLVM dialect type system to avoid printing the dialect prefix |
| /// repeatedly. For recursive structures, only prints the name of the structure |
| /// when printing a self-reference. Note that this does not apply to sibling |
| /// references. For example, |
| /// struct<"a", (ptr<struct<"a">>)> |
| /// struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>, |
| /// ptr<struct<"b", (ptr<struct<"c">>)>>)> |
| /// note that "b" is printed twice. |
| void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) { |
| if (!type) { |
| printer << "<<NULL-TYPE>>"; |
| return; |
| } |
| |
| printer << getTypeKeyword(type); |
| |
| if (auto ptrType = type.dyn_cast<LLVMPointerType>()) { |
| printer << '<'; |
| dispatchPrint(printer, ptrType.getElementType()); |
| if (ptrType.getAddressSpace() != 0) |
| printer << ", " << ptrType.getAddressSpace(); |
| printer << '>'; |
| return; |
| } |
| |
| if (auto arrayType = type.dyn_cast<LLVMArrayType>()) |
| return printArrayOrVectorType(printer, arrayType); |
| if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>()) |
| return printArrayOrVectorType(printer, vectorType); |
| |
| if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) { |
| printer << "<? x " << vectorType.getMinNumElements() << " x "; |
| dispatchPrint(printer, vectorType.getElementType()); |
| printer << '>'; |
| return; |
| } |
| |
| if (auto structType = type.dyn_cast<LLVMStructType>()) |
| return printStructType(printer, structType); |
| |
| if (auto funcType = type.dyn_cast<LLVMFunctionType>()) |
| return printFunctionType(printer, funcType); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Parsing. |
| //===----------------------------------------------------------------------===// |
| |
| static ParseResult dispatchParse(AsmParser &parser, Type &type); |
| |
| /// Parses an LLVM dialect function type. |
| /// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>` |
| static LLVMFunctionType parseFunctionType(AsmParser &parser) { |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| Type returnType; |
| if (parser.parseLess() || dispatchParse(parser, returnType) || |
| parser.parseLParen()) |
| return LLVMFunctionType(); |
| |
| // Function type without arguments. |
| if (succeeded(parser.parseOptionalRParen())) { |
| if (succeeded(parser.parseGreater())) |
| return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None, |
| /*isVarArg=*/false); |
| return LLVMFunctionType(); |
| } |
| |
| // Parse arguments. |
| SmallVector<Type, 8> argTypes; |
| do { |
| if (succeeded(parser.parseOptionalEllipsis())) { |
| if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) |
| return LLVMFunctionType(); |
| return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes, |
| /*isVarArg=*/true); |
| } |
| |
| Type arg; |
| if (dispatchParse(parser, arg)) |
| return LLVMFunctionType(); |
| argTypes.push_back(arg); |
| } while (succeeded(parser.parseOptionalComma())); |
| |
| if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) |
| return LLVMFunctionType(); |
| return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes, |
| /*isVarArg=*/false); |
| } |
| |
| /// Parses an LLVM dialect pointer type. |
| /// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>` |
| static LLVMPointerType parsePointerType(AsmParser &parser) { |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| Type elementType; |
| if (parser.parseLess() || dispatchParse(parser, elementType)) |
| return LLVMPointerType(); |
| |
| unsigned addressSpace = 0; |
| if (succeeded(parser.parseOptionalComma()) && |
| failed(parser.parseInteger(addressSpace))) |
| return LLVMPointerType(); |
| if (failed(parser.parseGreater())) |
| return LLVMPointerType(); |
| return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace); |
| } |
| |
| /// Parses an LLVM dialect vector type. |
| /// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>` |
| /// Supports both fixed and scalable vectors. |
| static Type parseVectorType(AsmParser &parser) { |
| SmallVector<int64_t, 2> dims; |
| llvm::SMLoc dimPos, typePos; |
| Type elementType; |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| if (parser.parseLess() || parser.getCurrentLocation(&dimPos) || |
| parser.parseDimensionList(dims, /*allowDynamic=*/true) || |
| parser.getCurrentLocation(&typePos) || |
| dispatchParse(parser, elementType) || parser.parseGreater()) |
| return Type(); |
| |
| // We parsed a generic dimension list, but vectors only support two forms: |
| // - single non-dynamic entry in the list (fixed vector); |
| // - two elements, the first dynamic (indicated by -1) and the second |
| // non-dynamic (scalable vector). |
| if (dims.empty() || dims.size() > 2 || |
| ((dims.size() == 2) ^ (dims[0] == -1)) || |
| (dims.size() == 2 && dims[1] == -1)) { |
| parser.emitError(dimPos) |
| << "expected '? x <integer> x <type>' or '<integer> x <type>'"; |
| return Type(); |
| } |
| |
| bool isScalable = dims.size() == 2; |
| if (isScalable) |
| return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]); |
| if (elementType.isSignlessIntOrFloat()) { |
| parser.emitError(typePos) |
| << "cannot use !llvm.vec for built-in primitives, use 'vector' instead"; |
| return Type(); |
| } |
| return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]); |
| } |
| |
| /// Parses an LLVM dialect array type. |
| /// llvm-type ::= `array<` integer `x` llvm-type `>` |
| static LLVMArrayType parseArrayType(AsmParser &parser) { |
| SmallVector<int64_t, 1> dims; |
| llvm::SMLoc sizePos; |
| Type elementType; |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| if (parser.parseLess() || parser.getCurrentLocation(&sizePos) || |
| parser.parseDimensionList(dims, /*allowDynamic=*/false) || |
| dispatchParse(parser, elementType) || parser.parseGreater()) |
| return LLVMArrayType(); |
| |
| if (dims.size() != 1) { |
| parser.emitError(sizePos) << "expected ? x <type>"; |
| return LLVMArrayType(); |
| } |
| |
| return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]); |
| } |
| |
| /// Attempts to set the body of an identified structure type. Reports a parsing |
| /// error at `subtypesLoc` in case of failure. |
| static LLVMStructType trySetStructBody(LLVMStructType type, |
| ArrayRef<Type> subtypes, bool isPacked, |
| AsmParser &parser, |
| llvm::SMLoc subtypesLoc) { |
| for (Type t : subtypes) { |
| if (!LLVMStructType::isValidElementType(t)) { |
| parser.emitError(subtypesLoc) |
| << "invalid LLVM structure element type: " << t; |
| return LLVMStructType(); |
| } |
| } |
| |
| if (succeeded(type.setBody(subtypes, isPacked))) |
| return type; |
| |
| parser.emitError(subtypesLoc) |
| << "identified type already used with a different body"; |
| return LLVMStructType(); |
| } |
| |
| /// Parses an LLVM dialect structure type. |
| /// llvm-type ::= `struct<` (string-literal `,`)? `packed`? |
| /// `(` llvm-type-list `)` `>` |
| /// | `struct<` string-literal `>` |
| /// | `struct<` string-literal `, opaque>` |
| static LLVMStructType parseStructType(AsmParser &parser) { |
| // This keeps track of the names of identified structure types that are |
| // currently being parsed. Since such types can refer themselves, this |
| // tracking is necessary to stop the recursion: the current function may be |
| // called recursively from AsmParser::parseType after the appropriate |
| // dispatch. We maintain the invariant of this storage being modified |
| // exclusively in this function, and at most one name being added per call. |
| // TODO: consider having such functionality inside AsmParser. |
| thread_local SetVector<StringRef> knownStructNames; |
| unsigned stackSize = knownStructNames.size(); |
| (void)stackSize; |
| auto guard = llvm::make_scope_exit([&]() { |
| assert(knownStructNames.size() == stackSize && |
| "malformed identified stack when parsing recursive structs"); |
| }); |
| |
| Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); |
| |
| if (failed(parser.parseLess())) |
| return LLVMStructType(); |
| |
| // If we are parsing a self-reference to a recursive struct, i.e. the parsing |
| // stack already contains a struct with the same identifier, bail out after |
| // the name. |
| std::string name; |
| bool isIdentified = succeeded(parser.parseOptionalString(&name)); |
| if (isIdentified) { |
| if (knownStructNames.count(name)) { |
| if (failed(parser.parseGreater())) |
| return LLVMStructType(); |
| return LLVMStructType::getIdentifiedChecked( |
| [loc] { return emitError(loc); }, loc.getContext(), name); |
| } |
| if (failed(parser.parseComma())) |
| return LLVMStructType(); |
| } |
| |
| // Handle intentionally opaque structs. |
| llvm::SMLoc kwLoc = parser.getCurrentLocation(); |
| if (succeeded(parser.parseOptionalKeyword("opaque"))) { |
| if (!isIdentified) |
| return parser.emitError(kwLoc, "only identified structs can be opaque"), |
| LLVMStructType(); |
| if (failed(parser.parseGreater())) |
| return LLVMStructType(); |
| auto type = LLVMStructType::getOpaqueChecked( |
| [loc] { return emitError(loc); }, loc.getContext(), name); |
| if (!type.isOpaque()) { |
| parser.emitError(kwLoc, "redeclaring defined struct as opaque"); |
| return LLVMStructType(); |
| } |
| return type; |
| } |
| |
| // Check for packedness. |
| bool isPacked = succeeded(parser.parseOptionalKeyword("packed")); |
| if (failed(parser.parseLParen())) |
| return LLVMStructType(); |
| |
| // Fast pass for structs with zero subtypes. |
| if (succeeded(parser.parseOptionalRParen())) { |
| if (failed(parser.parseGreater())) |
| return LLVMStructType(); |
| if (!isIdentified) |
| return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); }, |
| loc.getContext(), {}, isPacked); |
| auto type = LLVMStructType::getIdentifiedChecked( |
| [loc] { return emitError(loc); }, loc.getContext(), name); |
| return trySetStructBody(type, {}, isPacked, parser, kwLoc); |
| } |
| |
| // Parse subtypes. For identified structs, put the identifier of the struct on |
| // the stack to support self-references in the recursive calls. |
| SmallVector<Type, 4> subtypes; |
| llvm::SMLoc subtypesLoc = parser.getCurrentLocation(); |
| do { |
| if (isIdentified) |
| knownStructNames.insert(name); |
| Type type; |
| if (dispatchParse(parser, type)) |
| return LLVMStructType(); |
| subtypes.push_back(type); |
| if (isIdentified) |
| knownStructNames.pop_back(); |
| } while (succeeded(parser.parseOptionalComma())); |
| |
| if (parser.parseRParen() || parser.parseGreater()) |
| return LLVMStructType(); |
| |
| // Construct the struct with body. |
| if (!isIdentified) |
| return LLVMStructType::getLiteralChecked( |
| [loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked); |
| auto type = LLVMStructType::getIdentifiedChecked( |
| [loc] { return emitError(loc); }, loc.getContext(), name); |
| return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc); |
| } |
| |
| /// Parses a type appearing inside another LLVM dialect-compatible type. This |
| /// will try to parse any type in full form (including types with the `!llvm` |
| /// prefix), and on failure fall back to parsing the short-hand version of the |
| /// LLVM dialect types without the `!llvm` prefix. |
| static Type dispatchParse(AsmParser &parser, bool allowAny = true) { |
| llvm::SMLoc keyLoc = parser.getCurrentLocation(); |
| |
| // Try parsing any MLIR type. |
| Type type; |
| OptionalParseResult result = parser.parseOptionalType(type); |
| if (result.hasValue()) { |
| if (failed(result.getValue())) |
| return nullptr; |
| if (!allowAny) { |
| parser.emitError(keyLoc) << "unexpected type, expected keyword"; |
| return nullptr; |
| } |
| return type; |
| } |
| |
| // If no type found, fallback to the shorthand form. |
| StringRef key; |
| if (failed(parser.parseKeyword(&key))) |
| return Type(); |
| |
| MLIRContext *ctx = parser.getContext(); |
| return StringSwitch<function_ref<Type()>>(key) |
| .Case("void", [&] { return LLVMVoidType::get(ctx); }) |
| .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); }) |
| .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); }) |
| .Case("token", [&] { return LLVMTokenType::get(ctx); }) |
| .Case("label", [&] { return LLVMLabelType::get(ctx); }) |
| .Case("metadata", [&] { return LLVMMetadataType::get(ctx); }) |
| .Case("func", [&] { return parseFunctionType(parser); }) |
| .Case("ptr", [&] { return parsePointerType(parser); }) |
| .Case("vec", [&] { return parseVectorType(parser); }) |
| .Case("array", [&] { return parseArrayType(parser); }) |
| .Case("struct", [&] { return parseStructType(parser); }) |
| .Default([&] { |
| parser.emitError(keyLoc) << "unknown LLVM type: " << key; |
| return Type(); |
| })(); |
| } |
| |
| /// Helper to use in parse lists. |
| static ParseResult dispatchParse(AsmParser &parser, Type &type) { |
| type = dispatchParse(parser); |
| return success(type != nullptr); |
| } |
| |
| /// Parses one of the LLVM dialect types. |
| Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) { |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| Type type = dispatchParse(parser, /*allowAny=*/false); |
| if (!type) |
| return type; |
| if (!isCompatibleType(type)) { |
| parser.emitError(loc) << "unexpected type, expected keyword"; |
| return nullptr; |
| } |
| return type; |
| } |