//===- Parser.h - Toy Language Parser -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for the Toy language. It processes the
// tokens provided by the Lexer and returns an AST.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TUTORIAL_TOY_PARSER_H
#define MLIR_TUTORIAL_TOY_PARSER_H
#include "toy/AST.h"
#include "toy/Lexer.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <map>
#include <utility>
#include <vector>
namespace toy {
/// This is a simple recursive descent parser for the Toy language. It produces
/// a well-formed AST from the stream of tokens supplied by the Lexer. No
/// semantic checks or symbol resolution are performed. For example, variables
/// are referenced by name only, so the code may reference an undeclared
/// variable and parsing will still succeed.
class Parser {
public:
/// Create a Parser for the supplied lexer.
Parser(Lexer &lexer) : lexer(lexer) {}
/// Parse a full Module. A module is a list of function definitions.
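/// For example, a module may contain `def multiply_transpose(a, b) { ... }`
/// followed by `def main() { ... }`.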
std::unique_ptr<ModuleAST> parseModule() {
lexer.getNextToken(); // prime the lexer
// Parse functions one at a time and accumulate in this vector.
std::vector<FunctionAST> functions;
while (auto f = parseDefinition()) {
functions.push_back(std::move(*f));
if (lexer.getCurToken() == tok_eof)
break;
}
// If we didn't reach EOF, there was an error during parsing
if (lexer.getCurToken() != tok_eof)
return parseError<ModuleAST>("nothing", "at end of module");
return std::make_unique<ModuleAST>(std::move(functions));
}
private:
Lexer &lexer;
/// Parse a return statement.
/// return ::= return ; | return expr ;
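/// For example, both `return;` and `return a * b;` are accepted; the trailing
/// semicolon itself is consumed by the caller (parseBlock).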
std::unique_ptr<ReturnExprAST> parseReturn() {
auto loc = lexer.getLastLocation();
lexer.consume(tok_return);
// return takes an optional argument
llvm::Optional<std::unique_ptr<ExprAST>> expr;
if (lexer.getCurToken() != ';') {
expr = parseExpression();
if (!expr)
return nullptr;
}
return std::make_unique<ReturnExprAST>(std::move(loc), std::move(expr));
}
/// Parse a literal number.
/// numberexpr ::= number
std::unique_ptr<ExprAST> parseNumberExpr() {
auto loc = lexer.getLastLocation();
auto result =
std::make_unique<NumberExprAST>(std::move(loc), lexer.getValue());
lexer.consume(tok_number);
return std::move(result);
}
/// Parse a literal array expression.
/// tensorLiteral ::= [ literalList ] | number
/// literalList ::= tensorLiteral | tensorLiteral, literalList
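/// For example, `[[1, 2, 3], [4, 5, 6]]` yields a LiteralExprAST with
/// dims <2, 3>: the outer level contributes 2 and the nested dims of the
/// first element contribute 3, with every element checked for the same shape.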
std::unique_ptr<ExprAST> parseTensorLiteralExpr() {
auto loc = lexer.getLastLocation();
lexer.consume(Token('['));
// Hold the list of values at this nesting level.
std::vector<std::unique_ptr<ExprAST>> values;
// Hold the dimensions for all the nesting inside this level.
std::vector<int64_t> dims;
do {
// We can have either another nested array or a number literal.
if (lexer.getCurToken() == '[') {
values.push_back(parseTensorLiteralExpr());
if (!values.back())
return nullptr; // parse error in the nested array.
} else {
if (lexer.getCurToken() != tok_number)
return parseError<ExprAST>("<num> or [", "in literal expression");
values.push_back(parseNumberExpr());
}
// End of this list on ']'
if (lexer.getCurToken() == ']')
break;
// Elements are separated by a comma.
if (lexer.getCurToken() != ',')
return parseError<ExprAST>("] or ,", "in literal expression");
lexer.getNextToken(); // eat ,
} while (true);
if (values.empty())
return parseError<ExprAST>("<something>", "to fill literal expression");
lexer.getNextToken(); // eat ]
// Fill in the dimensions now. First the current nesting level:
dims.push_back(values.size());
// If there is any nested array, process all of them and ensure that
// dimensions are uniform.
if (llvm::any_of(values, [](std::unique_ptr<ExprAST> &expr) {
return llvm::isa<LiteralExprAST>(expr.get());
})) {
auto *firstLiteral = llvm::dyn_cast<LiteralExprAST>(values.front().get());
if (!firstLiteral)
return parseError<ExprAST>("uniform well-nested dimensions",
"inside literal expression");
// Append the nested dimensions to the current level
auto firstDims = firstLiteral->getDims();
dims.insert(dims.end(), firstDims.begin(), firstDims.end());
// Sanity check that shape is uniform across all elements of the list.
for (auto &expr : values) {
auto *exprLiteral = llvm::dyn_cast<LiteralExprAST>(expr.get());
if (!exprLiteral)
return parseError<ExprAST>("uniform well-nested dimensions",
"inside literal expression");
if (exprLiteral->getDims() != firstDims)
return parseError<ExprAST>("uniform well-nested dimensions",
"inside literal expression");
}
}
return std::make_unique<LiteralExprAST>(std::move(loc), std::move(values),
std::move(dims));
}
/// parenexpr ::= '(' expression ')'
std::unique_ptr<ExprAST> parseParenExpr() {
lexer.getNextToken(); // eat (.
auto v = parseExpression();
if (!v)
return nullptr;
if (lexer.getCurToken() != ')')
return parseError<ExprAST>(")", "to close expression with parentheses");
lexer.consume(Token(')'));
return v;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' (expression (',' expression)*)? ')'
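/// For example, `a` yields a VariableExprAST, `print(a)` the builtin
/// PrintExprAST, and `multiply_transpose(a, b)` a CallExprAST.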
std::unique_ptr<ExprAST> parseIdentifierExpr() {
std::string name(lexer.getId());
auto loc = lexer.getLastLocation();
lexer.getNextToken(); // eat identifier.
if (lexer.getCurToken() != '(') // Simple variable ref.
return std::make_unique<VariableExprAST>(std::move(loc), name);
// This is a function call.
lexer.consume(Token('('));
std::vector<std::unique_ptr<ExprAST>> args;
if (lexer.getCurToken() != ')') {
while (true) {
if (auto arg = parseExpression())
args.push_back(std::move(arg));
else
return nullptr;
if (lexer.getCurToken() == ')')
break;
if (lexer.getCurToken() != ',')
return parseError<ExprAST>(", or )", "in argument list");
lexer.getNextToken();
}
}
lexer.consume(Token(')'));
// It can be a builtin call to print
if (name == "print") {
if (args.size() != 1)
return parseError<ExprAST>("<single arg>", "as argument to print()");
return std::make_unique<PrintExprAST>(std::move(loc), std::move(args[0]));
}
// Call to a user-defined function
return std::make_unique<CallExprAST>(std::move(loc), name, std::move(args));
}
/// primary
/// ::= identifierexpr
/// ::= numberexpr
/// ::= parenexpr
/// ::= tensorliteral
std::unique_ptr<ExprAST> parsePrimary() {
switch (lexer.getCurToken()) {
default:
llvm::errs() << "unknown token '" << lexer.getCurToken()
<< "' when expecting an expression\n";
return nullptr;
case tok_identifier:
return parseIdentifierExpr();
case tok_number:
return parseNumberExpr();
case '(':
return parseParenExpr();
case '[':
return parseTensorLiteralExpr();
case ';':
return nullptr;
case '}':
return nullptr;
}
}
/// Recursively parse the right-hand side of a binary expression. The exprPrec
/// argument indicates the minimal operator precedence this function is allowed
/// to consume.
///
/// binoprhs ::= ('+' primary)*
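/// For example, with lhs = `a` and remaining input `+ b * c`, the `+` is
/// consumed and `b` parsed as the rhs primary; since `*` (precedence 40)
/// binds tighter than `+` (precedence 20), this recurses to parse `b * c`
/// before merging into the BinaryExprAST for `a + (b * c)`.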
std::unique_ptr<ExprAST> parseBinOpRHS(int exprPrec,
std::unique_ptr<ExprAST> lhs) {
// If this is a binop, find its precedence.
while (true) {
int tokPrec = getTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
if (tokPrec < exprPrec)
return lhs;
// Okay, we know this is a binop.
int binOp = lexer.getCurToken();
lexer.consume(Token(binOp));
auto loc = lexer.getLastLocation();
// Parse the primary expression after the binary operator.
auto rhs = parsePrimary();
if (!rhs)
return parseError<ExprAST>("expression", "to complete binary operator");
// If BinOp binds less tightly with rhs than the operator after rhs, let
// the pending operator take rhs as its lhs.
int nextPrec = getTokPrecedence();
if (tokPrec < nextPrec) {
rhs = parseBinOpRHS(tokPrec + 1, std::move(rhs));
if (!rhs)
return nullptr;
}
// Merge lhs/rhs.
lhs = std::make_unique<BinaryExprAST>(std::move(loc), binOp,
std::move(lhs), std::move(rhs));
}
}
/// expression ::= primary binoprhs
std::unique_ptr<ExprAST> parseExpression() {
auto lhs = parsePrimary();
if (!lhs)
return nullptr;
return parseBinOpRHS(0, std::move(lhs));
}
/// type ::= < shape_list >
/// shape_list ::= num | num , shape_list
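/// For example, `<2, 3>` yields a VarType with shape {2, 3}, while `<>`
/// yields an empty shape to be inferred later.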
std::unique_ptr<VarType> parseType() {
if (lexer.getCurToken() != '<')
return parseError<VarType>("<", "to begin type");
lexer.getNextToken(); // eat <
auto type = std::make_unique<VarType>();
while (lexer.getCurToken() == tok_number) {
type->shape.push_back(lexer.getValue());
lexer.getNextToken();
if (lexer.getCurToken() == ',')
lexer.getNextToken();
}
if (lexer.getCurToken() != '>')
return parseError<VarType>(">", "to end type");
lexer.getNextToken(); // eat >
return type;
}
/// Parse a variable declaration. It starts with a `var` keyword, followed by
/// an identifier and an optional type (shape specification), before the
/// initializer.
/// decl ::= var identifier [ type ] = expr
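/// For example, `var a<2, 2> = [[1, 2], [3, 4]];` declares `a` with an
/// explicit shape, while `var b = a;` leaves the shape to be inferred.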
std::unique_ptr<VarDeclExprAST> parseDeclaration() {
if (lexer.getCurToken() != tok_var)
return parseError<VarDeclExprAST>("var", "to begin declaration");
auto loc = lexer.getLastLocation();
lexer.getNextToken(); // eat var
if (lexer.getCurToken() != tok_identifier)
return parseError<VarDeclExprAST>("identifier",
"after 'var' declaration");
std::string id(lexer.getId());
lexer.getNextToken(); // eat id
std::unique_ptr<VarType> type; // Type is optional, it can be inferred
if (lexer.getCurToken() == '<') {
type = parseType();
if (!type)
return nullptr;
}
if (!type)
type = std::make_unique<VarType>();
lexer.consume(Token('='));
auto expr = parseExpression();
return std::make_unique<VarDeclExprAST>(std::move(loc), std::move(id),
std::move(*type), std::move(expr));
}
/// Parse a block: a list of expressions separated by semicolons and wrapped in
/// curly braces.
///
/// block ::= { expression_list }
/// expression_list ::= block_expr ; expression_list
/// block_expr ::= decl | "return" | expr
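/// For example: `{ var a = [1, 2, 3]; print(a); return; }`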
std::unique_ptr<ExprASTList> parseBlock() {
if (lexer.getCurToken() != '{')
return parseError<ExprASTList>("{", "to begin block");
lexer.consume(Token('{'));
auto exprList = std::make_unique<ExprASTList>();
// Ignore empty expressions: swallow sequences of semicolons.
while (lexer.getCurToken() == ';')
lexer.consume(Token(';'));
while (lexer.getCurToken() != '}' && lexer.getCurToken() != tok_eof) {
if (lexer.getCurToken() == tok_var) {
// Variable declaration
auto varDecl = parseDeclaration();
if (!varDecl)
return nullptr;
exprList->push_back(std::move(varDecl));
} else if (lexer.getCurToken() == tok_return) {
// Return statement
auto ret = parseReturn();
if (!ret)
return nullptr;
exprList->push_back(std::move(ret));
} else {
// General expression
auto expr = parseExpression();
if (!expr)
return nullptr;
exprList->push_back(std::move(expr));
}
// Ensure that elements are separated by a semicolon.
if (lexer.getCurToken() != ';')
return parseError<ExprASTList>(";", "after expression");
// Ignore empty expressions: swallow sequences of semicolons.
while (lexer.getCurToken() == ';')
lexer.consume(Token(';'));
}
if (lexer.getCurToken() != '}')
return parseError<ExprASTList>("}", "to close block");
lexer.consume(Token('}'));
return exprList;
}
/// prototype ::= def id '(' decl_list? ')'
/// decl_list ::= identifier | identifier, decl_list
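/// For example, `def multiply_transpose(a, b)` or a zero-argument `def main()`.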
std::unique_ptr<PrototypeAST> parsePrototype() {
auto loc = lexer.getLastLocation();
if (lexer.getCurToken() != tok_def)
return parseError<PrototypeAST>("def", "in prototype");
lexer.consume(tok_def);
if (lexer.getCurToken() != tok_identifier)
return parseError<PrototypeAST>("function name", "in prototype");
std::string fnName(lexer.getId());
lexer.consume(tok_identifier);
if (lexer.getCurToken() != '(')
return parseError<PrototypeAST>("(", "in prototype");
lexer.consume(Token('('));
std::vector<std::unique_ptr<VariableExprAST>> args;
if (lexer.getCurToken() != ')') {
do {
std::string name(lexer.getId());
auto loc = lexer.getLastLocation();
lexer.consume(tok_identifier);
auto decl = std::make_unique<VariableExprAST>(std::move(loc), name);
args.push_back(std::move(decl));
if (lexer.getCurToken() != ',')
break;
lexer.consume(Token(','));
if (lexer.getCurToken() != tok_identifier)
return parseError<PrototypeAST>(
"identifier", "after ',' in function parameter list");
} while (true);
}
if (lexer.getCurToken() != ')')
return parseError<PrototypeAST>(")", "to end function prototype");
// success.
lexer.consume(Token(')'));
return std::make_unique<PrototypeAST>(std::move(loc), fnName,
std::move(args));
}
/// Parse a function definition. We expect a prototype introduced by the
/// `def` keyword, followed by a block containing a list of expressions.
///
/// definition ::= prototype block
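/// For example: `def main() { var a = [[1, 2], [3, 4]]; print(a); }`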
std::unique_ptr<FunctionAST> parseDefinition() {
auto proto = parsePrototype();
if (!proto)
return nullptr;
if (auto block = parseBlock())
return std::make_unique<FunctionAST>(std::move(proto), std::move(block));
return nullptr;
}
/// Get the precedence of the pending binary operator token.
int getTokPrecedence() {
if (!isascii(lexer.getCurToken()))
return -1;
// 1 is lowest precedence.
switch (static_cast<char>(lexer.getCurToken())) {
case '-':
return 20;
case '+':
return 20;
case '*':
return 40;
default:
return -1;
}
}
/// Helper function to signal errors while parsing. It takes an argument
/// indicating the expected token and another argument giving more context.
/// The location is retrieved from the lexer to enrich the error message.
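/// For example, a missing semicolon might be reported as:
///   Parse error (3, 2): expected ';' after expression but has Token 125 '}'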
template <typename R, typename T, typename U = const char *>
std::unique_ptr<R> parseError(T &&expected, U &&context = "") {
auto curToken = lexer.getCurToken();
llvm::errs() << "Parse error (" << lexer.getLastLocation().line << ", "
<< lexer.getLastLocation().col << "): expected '" << expected
<< "' " << context << " but has Token " << curToken;
if (isprint(curToken))
llvm::errs() << " '" << (char)curToken << "'";
llvm::errs() << "\n";
return nullptr;
}
};
} // namespace toy
#endif // MLIR_TUTORIAL_TOY_PARSER_H