| //===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements the parser for the dialect symbols, such as extended |
| // attributes and types. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "AsmParserImpl.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Dialect.h" |
| #include "mlir/IR/DialectImplementation.h" |
| #include "llvm/Support/SourceMgr.h" |
| |
| using namespace mlir; |
| using namespace mlir::detail; |
| using llvm::MemoryBuffer; |
| using llvm::SMLoc; |
| using llvm::SourceMgr; |
| |
| namespace { |
| /// This class provides the main implementation of the DialectAsmParser that |
| /// allows for dialects to parse attributes and types. This allows for dialect |
| /// hooking into the main MLIR parsing logic. |
| class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> { |
| public: |
| CustomDialectAsmParser(StringRef fullSpec, Parser &parser) |
| : AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser), |
| fullSpec(fullSpec) {} |
| ~CustomDialectAsmParser() override {} |
| |
| /// Returns the full specification of the symbol being parsed. This allows |
| /// for using a separate parser if necessary. |
| StringRef getFullSymbolSpec() const override { return fullSpec; } |
| |
| private: |
| /// The full symbol specification. |
| StringRef fullSpec; |
| }; |
| } // namespace |
| |
| /// Parse the body of a pretty dialect symbol, which starts and ends with <>'s, |
| /// and may be recursive. Return with the 'prettyName' StringRef encompassing |
| /// the entire pretty name. |
| /// |
| /// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>' |
| /// pretty-dialect-sym-contents ::= pretty-dialect-sym-body |
| /// | '(' pretty-dialect-sym-contents+ ')' |
| /// | '[' pretty-dialect-sym-contents+ ']' |
| /// | '{' pretty-dialect-sym-contents+ '}' |
| /// | '[^[<({>\])}\0]+' |
| /// |
| ParseResult Parser::parsePrettyDialectSymbolName(StringRef &prettyName) { |
| // Pretty symbol names are a relatively unstructured format that contains a |
| // series of properly nested punctuation, with anything else in the middle. |
| // Scan ahead to find it and consume it if successful, otherwise emit an |
| // error. |
| auto *curPtr = getTokenSpelling().data(); |
| |
| SmallVector<char, 8> nestedPunctuation; |
| |
| // Scan over the nested punctuation, bailing out on error and consuming until |
| // we find the end. We know that we're currently looking at the '<', so we |
| // can go until we find the matching '>' character. |
| assert(*curPtr == '<'); |
| do { |
| char c = *curPtr++; |
| switch (c) { |
| case '\0': |
| // This also handles the EOF case. |
| return emitError("unexpected nul or EOF in pretty dialect name"); |
| case '<': |
| case '[': |
| case '(': |
| case '{': |
| nestedPunctuation.push_back(c); |
| continue; |
| |
| case '-': |
| // The sequence `->` is treated as special token. |
| if (*curPtr == '>') |
| ++curPtr; |
| continue; |
| |
| case '>': |
| if (nestedPunctuation.pop_back_val() != '<') |
| return emitError("unbalanced '>' character in pretty dialect name"); |
| break; |
| case ']': |
| if (nestedPunctuation.pop_back_val() != '[') |
| return emitError("unbalanced ']' character in pretty dialect name"); |
| break; |
| case ')': |
| if (nestedPunctuation.pop_back_val() != '(') |
| return emitError("unbalanced ')' character in pretty dialect name"); |
| break; |
| case '}': |
| if (nestedPunctuation.pop_back_val() != '{') |
| return emitError("unbalanced '}' character in pretty dialect name"); |
| break; |
| |
| default: |
| continue; |
| } |
| } while (!nestedPunctuation.empty()); |
| |
| // Ok, we succeeded, remember where we stopped, reset the lexer to know it is |
| // consuming all this stuff, and return. |
| state.lex.resetPointer(curPtr); |
| |
| unsigned length = curPtr - prettyName.begin(); |
| prettyName = StringRef(prettyName.begin(), length); |
| consumeToken(); |
| return success(); |
| } |
| |
| /// Parse an extended dialect symbol. |
| template <typename Symbol, typename SymbolAliasMap, typename CreateFn> |
| static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok, |
| SymbolAliasMap &aliases, |
| CreateFn &&createSymbol) { |
| // Parse the dialect namespace. |
| StringRef identifier = p.getTokenSpelling().drop_front(); |
| auto loc = p.getToken().getLoc(); |
| p.consumeToken(identifierTok); |
| |
| // If there is no '<' token following this, and if the typename contains no |
| // dot, then we are parsing a symbol alias. |
| if (p.getToken().isNot(Token::less) && !identifier.contains('.')) { |
| // Check for an alias for this type. |
| auto aliasIt = aliases.find(identifier); |
| if (aliasIt == aliases.end()) |
| return (p.emitError("undefined symbol alias id '" + identifier + "'"), |
| nullptr); |
| return aliasIt->second; |
| } |
| |
| // Otherwise, we are parsing a dialect-specific symbol. If the name contains |
| // a dot, then this is the "pretty" form. If not, it is the verbose form that |
| // looks like <"...">. |
| std::string symbolData; |
| auto dialectName = identifier; |
| |
| // Handle the verbose form, where "identifier" is a simple dialect name. |
| if (!identifier.contains('.')) { |
| // Consume the '<'. |
| if (p.parseToken(Token::less, "expected '<' in dialect type")) |
| return nullptr; |
| |
| // Parse the symbol specific data. |
| if (p.getToken().isNot(Token::string)) |
| return (p.emitError("expected string literal data in dialect symbol"), |
| nullptr); |
| symbolData = p.getToken().getStringValue(); |
| loc = llvm::SMLoc::getFromPointer(p.getToken().getLoc().getPointer() + 1); |
| p.consumeToken(Token::string); |
| |
| // Consume the '>'. |
| if (p.parseToken(Token::greater, "expected '>' in dialect symbol")) |
| return nullptr; |
| } else { |
| // Ok, the dialect name is the part of the identifier before the dot, the |
| // part after the dot is the dialect's symbol, or the start thereof. |
| auto dotHalves = identifier.split('.'); |
| dialectName = dotHalves.first; |
| auto prettyName = dotHalves.second; |
| loc = llvm::SMLoc::getFromPointer(prettyName.data()); |
| |
| // If the dialect's symbol is followed immediately by a <, then lex the body |
| // of it into prettyName. |
| if (p.getToken().is(Token::less) && |
| prettyName.bytes_end() == p.getTokenSpelling().bytes_begin()) { |
| if (p.parsePrettyDialectSymbolName(prettyName)) |
| return nullptr; |
| } |
| |
| symbolData = prettyName.str(); |
| } |
| |
| // Record the name location of the type remapped to the top level buffer. |
| llvm::SMLoc locInTopLevelBuffer = p.remapLocationToTopLevelBuffer(loc); |
| p.getState().symbols.nestedParserLocs.push_back(locInTopLevelBuffer); |
| |
| // Call into the provided symbol construction function. |
| Symbol sym = createSymbol(dialectName, symbolData, loc); |
| |
| // Pop the last parser location. |
| p.getState().symbols.nestedParserLocs.pop_back(); |
| return sym; |
| } |
| |
| /// Parses a symbol, of type 'T', and returns it if parsing was successful. If |
| /// parsing failed, nullptr is returned. The number of bytes read from the input |
| /// string is returned in 'numRead'. |
| template <typename T, typename ParserFn> |
| static T parseSymbol(StringRef inputStr, MLIRContext *context, |
| SymbolState &symbolState, ParserFn &&parserFn, |
| size_t *numRead = nullptr) { |
| SourceMgr sourceMgr; |
| auto memBuffer = MemoryBuffer::getMemBuffer( |
| inputStr, /*BufferName=*/"<mlir_parser_buffer>", |
| /*RequiresNullTerminator=*/false); |
| sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); |
| ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr); |
| Parser parser(state); |
| |
| Token startTok = parser.getToken(); |
| T symbol = parserFn(parser); |
| if (!symbol) |
| return T(); |
| |
| // If 'numRead' is valid, then provide the number of bytes that were read. |
| Token endTok = parser.getToken(); |
| if (numRead) { |
| *numRead = static_cast<size_t>(endTok.getLoc().getPointer() - |
| startTok.getLoc().getPointer()); |
| |
| // Otherwise, ensure that all of the tokens were parsed. |
| } else if (startTok.getLoc() != endTok.getLoc() && endTok.isNot(Token::eof)) { |
| parser.emitError(endTok.getLoc(), "encountered unexpected token"); |
| return T(); |
| } |
| return symbol; |
| } |
| |
| /// Parse an extended attribute. |
| /// |
| /// extended-attribute ::= (dialect-attribute | attribute-alias) |
| /// dialect-attribute ::= `#` dialect-namespace `<` `"` attr-data `"` `>` |
| /// dialect-attribute ::= `#` alias-name pretty-dialect-sym-body? |
| /// attribute-alias ::= `#` alias-name |
| /// |
| Attribute Parser::parseExtendedAttr(Type type) { |
| Attribute attr = parseExtendedSymbol<Attribute>( |
| *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, |
| [&](StringRef dialectName, StringRef symbolData, |
| llvm::SMLoc loc) -> Attribute { |
| // Parse an optional trailing colon type. |
| Type attrType = type; |
| if (consumeIf(Token::colon) && !(attrType = parseType())) |
| return Attribute(); |
| |
| // If we found a registered dialect, then ask it to parse the attribute. |
| if (Dialect *dialect = |
| builder.getContext()->getOrLoadDialect(dialectName)) { |
| return parseSymbol<Attribute>( |
| symbolData, state.context, state.symbols, [&](Parser &parser) { |
| CustomDialectAsmParser customParser(symbolData, parser); |
| return dialect->parseAttribute(customParser, attrType); |
| }); |
| } |
| |
| // Otherwise, form a new opaque attribute. |
| return OpaqueAttr::getChecked( |
| [&] { return emitError(loc); }, |
| StringAttr::get(state.context, dialectName), symbolData, |
| attrType ? attrType : NoneType::get(state.context)); |
| }); |
| |
| // Ensure that the attribute has the same type as requested. |
| if (attr && type && attr.getType() != type) { |
| emitError("attribute type different than expected: expected ") |
| << type << ", but got " << attr.getType(); |
| return nullptr; |
| } |
| return attr; |
| } |
| |
| /// Parse an extended type. |
| /// |
| /// extended-type ::= (dialect-type | type-alias) |
| /// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>` |
| /// dialect-type ::= `!` alias-name pretty-dialect-attribute-body? |
| /// type-alias ::= `!` alias-name |
| /// |
| Type Parser::parseExtendedType() { |
| return parseExtendedSymbol<Type>( |
| *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, |
| [&](StringRef dialectName, StringRef symbolData, |
| llvm::SMLoc loc) -> Type { |
| // If we found a registered dialect, then ask it to parse the type. |
| auto *dialect = state.context->getOrLoadDialect(dialectName); |
| |
| if (dialect) { |
| return parseSymbol<Type>( |
| symbolData, state.context, state.symbols, [&](Parser &parser) { |
| CustomDialectAsmParser customParser(symbolData, parser); |
| return dialect->parseType(customParser); |
| }); |
| } |
| |
| // Otherwise, form a new opaque type. |
| return OpaqueType::getChecked( |
| [&] { return emitError(loc); }, |
| StringAttr::get(state.context, dialectName), symbolData); |
| }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // mlir::parseAttribute/parseType |
| //===----------------------------------------------------------------------===// |
| |
| /// Parses a symbol, of type 'T', and returns it if parsing was successful. If |
| /// parsing failed, nullptr is returned. The number of bytes read from the input |
| /// string is returned in 'numRead'. |
| template <typename T, typename ParserFn> |
| static T parseSymbol(StringRef inputStr, MLIRContext *context, size_t &numRead, |
| ParserFn &&parserFn) { |
| SymbolState aliasState; |
| return parseSymbol<T>( |
| inputStr, context, aliasState, |
| [&](Parser &parser) { |
| SourceMgrDiagnosticHandler handler( |
| const_cast<llvm::SourceMgr &>(parser.getSourceMgr()), |
| parser.getContext()); |
| return parserFn(parser); |
| }, |
| &numRead); |
| } |
| |
| Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context) { |
| size_t numRead = 0; |
| return parseAttribute(attrStr, context, numRead); |
| } |
| Attribute mlir::parseAttribute(StringRef attrStr, Type type) { |
| size_t numRead = 0; |
| return parseAttribute(attrStr, type, numRead); |
| } |
| |
| Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context, |
| size_t &numRead) { |
| return parseSymbol<Attribute>(attrStr, context, numRead, [](Parser &parser) { |
| return parser.parseAttribute(); |
| }); |
| } |
| Attribute mlir::parseAttribute(StringRef attrStr, Type type, size_t &numRead) { |
| return parseSymbol<Attribute>( |
| attrStr, type.getContext(), numRead, |
| [type](Parser &parser) { return parser.parseAttribute(type); }); |
| } |
| |
| Type mlir::parseType(StringRef typeStr, MLIRContext *context) { |
| size_t numRead = 0; |
| return parseType(typeStr, context, numRead); |
| } |
| |
| Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t &numRead) { |
| return parseSymbol<Type>(typeStr, context, numRead, |
| [](Parser &parser) { return parser.parseType(); }); |
| } |