blob: 294c15be3589e36954c616fab961547320bd0afc [file] [log] [blame]
//===- AST.h - Node definition for the Toy AST ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the AST for the Toy language. It is optimized for
// simplicity, not efficiency. The AST forms a tree structure where each node
// references its children using std::unique_ptr<>.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TUTORIAL_TOY_AST_H_
#define MLIR_TUTORIAL_TOY_AST_H_
#include "toy/Lexer.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
namespace toy {
/// A variable type carrying shape information.
///
/// `shape` lists the declared dimension sizes; presumably it stays empty when
/// no shape was written in the source (confirm against the parser).
struct VarType {
  std::vector<int64_t> shape{};
};
/// Base class for all expression nodes.
///
/// Every concrete subclass passes its own ExprASTKind to this constructor, and
/// its static classof() compares getKind() against that value, which is what
/// LLVM-style RTTI (llvm::isa / llvm::cast / llvm::dyn_cast) relies on.
class ExprAST {
public:
  /// Discriminator identifying the concrete subclass of a node.
  enum ExprASTKind {
    Expr_VarDecl,
    Expr_Return,
    Expr_Num,
    Expr_Literal,
    Expr_Var,
    Expr_BinOp,
    Expr_Call,
    Expr_Print,
  };

  ExprAST(ExprASTKind kind, Location location)
      : kind(kind), location(std::move(location)) {}
  virtual ~ExprAST() = default;

  /// Returns the discriminator consulted by the subclasses' classof().
  ExprASTKind getKind() const { return kind; }

  /// Returns the source location this node was constructed with. This is a
  /// const member so the location is reachable through `const ExprAST &`.
  const Location &loc() const { return location; }

private:
  const ExprASTKind kind;
  Location location;
};

/// A block-list of expressions.
using ExprASTList = std::vector<std::unique_ptr<ExprAST>>;
/// Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
  double val;

public:
  NumberExprAST(Location loc, double val)
      : ExprAST(Expr_Num, std::move(loc)), val(val) {}

  /// Returns the literal's numeric value.
  double getValue() const { return val; }

  /// LLVM style RTTI
  static bool classof(const ExprAST *c) { return c->getKind() == Expr_Num; }
};
/// Expression class for a literal value.
///
/// Owns its element expressions. `dims` presumably holds the size of each
/// dimension of the literal (confirm against the parser).
class LiteralExprAST : public ExprAST {
  std::vector<std::unique_ptr<ExprAST>> values;
  std::vector<int64_t> dims;

public:
  LiteralExprAST(Location loc, std::vector<std::unique_ptr<ExprAST>> values,
                 std::vector<int64_t> dims)
      : ExprAST(Expr_Literal, std::move(loc)), values(std::move(values)),
        dims(std::move(dims)) {}

  /// Non-owning view of the element expressions, in stored order.
  llvm::ArrayRef<std::unique_ptr<ExprAST>> getValues() const { return values; }

  /// Non-owning view of the stored dimension list.
  llvm::ArrayRef<int64_t> getDims() const { return dims; }

  /// LLVM style RTTI
  static bool classof(const ExprAST *c) { return c->getKind() == Expr_Literal; }
};
/// Expression class for referencing a variable, like "a".
class VariableExprAST : public ExprAST {
  std::string name;

public:
  /// Copies `name` into owned storage, so the caller's buffer need not outlive
  /// the node. The explicit `.str()` avoids relying on an implicit
  /// StringRef -> std::string conversion.
  VariableExprAST(Location loc, llvm::StringRef name)
      : ExprAST(Expr_Var, std::move(loc)), name(name.str()) {}

  /// Returns the referenced variable's name; the view points into this node.
  llvm::StringRef getName() const { return name; }

  /// LLVM style RTTI
  static bool classof(const ExprAST *c) { return c->getKind() == Expr_Var; }
};
/// Expression class for defining a variable.
///
/// Owns the name, the declared type and the initializer expression.
class VarDeclExprAST : public ExprAST {
  std::string name;
  VarType type;
  std::unique_ptr<ExprAST> initVal;

public:
  VarDeclExprAST(Location loc, llvm::StringRef name, VarType type,
                 std::unique_ptr<ExprAST> initVal)
      : ExprAST(Expr_VarDecl, std::move(loc)), name(name.str()),
        type(std::move(type)), initVal(std::move(initVal)) {}

  /// Returns the declared variable's name; the view points into this node.
  llvm::StringRef getName() const { return name; }

  /// Returns a non-owning pointer to the initializer expression.
  ExprAST *getInitVal() const { return initVal.get(); }

  /// Returns the declared type.
  const VarType &getType() const { return type; }

  /// LLVM style RTTI
  static bool classof(const ExprAST *c) { return c->getKind() == Expr_VarDecl; }
};
/// Expression class for a return operator.
///
/// The returned operand is optional; when present, this node owns it.
class ReturnExprAST : public ExprAST {
  llvm::Optional<std::unique_ptr<ExprAST>> expr;

public:
  ReturnExprAST(Location loc, llvm::Optional<std::unique_ptr<ExprAST>> expr)
      : ExprAST(Expr_Return, std::move(loc)), expr(std::move(expr)) {}

  /// Returns a non-owning pointer to the returned operand, or llvm::None when
  /// no operand was stored.
  llvm::Optional<ExprAST *> getExpr() const {
    // `explicit operator bool` is used instead of hasValue(), which later
    // LLVM releases deprecate.
    if (expr)
      return expr->get();
    return llvm::None;
  }

  /// LLVM style RTTI
  static bool classof(const ExprAST *c) { return c->getKind() == Expr_Return; }
};
/// Expression class for a binary operator.
///
/// Stores the operator as a single character and owns both operands.
class BinaryExprAST : public ExprAST {
  char op;
  std::unique_ptr<ExprAST> lhs, rhs;

public:
  BinaryExprAST(Location loc, char op, std::unique_ptr<ExprAST> lhs,
                std::unique_ptr<ExprAST> rhs)
      : ExprAST(Expr_BinOp, std::move(loc)), op(op), lhs(std::move(lhs)),
        rhs(std::move(rhs)) {}

  /// Returns the operator character.
  char getOp() const { return op; }

  /// Returns a non-owning pointer to the left-hand operand.
  ExprAST *getLHS() const { return lhs.get(); }

  /// Returns a non-owning pointer to the right-hand operand.
  ExprAST *getRHS() const { return rhs.get(); }

  /// LLVM style RTTI
  static bool classof(const ExprAST *c) { return c->getKind() == Expr_BinOp; }
};
/// Expression class for function calls.
///
/// Owns the callee name and the argument expressions.
class CallExprAST : public ExprAST {
  std::string callee;
  std::vector<std::unique_ptr<ExprAST>> args;

public:
  /// `callee` is taken by value and moved in, so callers passing an rvalue
  /// string avoid a copy. Lvalue callers behave as before.
  CallExprAST(Location loc, std::string callee,
              std::vector<std::unique_ptr<ExprAST>> args)
      : ExprAST(Expr_Call, std::move(loc)), callee(std::move(callee)),
        args(std::move(args)) {}

  /// Returns the called function's name; the view points into this node.
  llvm::StringRef getCallee() const { return callee; }

  /// Non-owning view of the argument expressions, in call order.
  llvm::ArrayRef<std::unique_ptr<ExprAST>> getArgs() const { return args; }

  /// LLVM style RTTI
  static bool classof(const ExprAST *c) { return c->getKind() == Expr_Call; }
};
/// Expression class for builtin print calls.
///
/// Owns its single argument expression.
class PrintExprAST : public ExprAST {
  std::unique_ptr<ExprAST> arg;

public:
  PrintExprAST(Location loc, std::unique_ptr<ExprAST> arg)
      : ExprAST(Expr_Print, std::move(loc)), arg(std::move(arg)) {}

  /// Returns a non-owning pointer to the argument expression.
  ExprAST *getArg() const { return arg.get(); }

  /// LLVM style RTTI
  static bool classof(const ExprAST *c) { return c->getKind() == Expr_Print; }
};
/// This class represents the "prototype" for a function. It captures the
/// function's name and its argument names, and therefore implicitly the
/// number of arguments the function takes.
class PrototypeAST {
  Location location;
  std::string name;
  std::vector<std::unique_ptr<VariableExprAST>> args;

public:
  /// `name` is taken by value and moved in. Lvalue callers behave as before.
  PrototypeAST(Location location, std::string name,
               std::vector<std::unique_ptr<VariableExprAST>> args)
      : location(std::move(location)), name(std::move(name)),
        args(std::move(args)) {}

  /// Returns the source location this prototype was constructed with.
  const Location &loc() const { return location; }

  /// Returns the function name; the view points into this object.
  llvm::StringRef getName() const { return name; }

  /// Non-owning view of the argument variables, in declaration order.
  llvm::ArrayRef<std::unique_ptr<VariableExprAST>> getArgs() const {
    return args;
  }
};
/// This class represents a function definition itself.
///
/// Owns the prototype and the body expression list.
class FunctionAST {
  std::unique_ptr<PrototypeAST> proto;
  std::unique_ptr<ExprASTList> body;

public:
  FunctionAST(std::unique_ptr<PrototypeAST> proto,
              std::unique_ptr<ExprASTList> body)
      : proto(std::move(proto)), body(std::move(body)) {}

  /// Returns a non-owning pointer to the prototype.
  PrototypeAST *getProto() const { return proto.get(); }

  /// Returns a non-owning pointer to the body expression list.
  ExprASTList *getBody() const { return body.get(); }
};
/// This class represents a list of functions to be processed together
class ModuleAST {
std::vector<FunctionAST> functions;
public:
ModuleAST(std::vector<FunctionAST> functions)
: functions(std::move(functions)) {}
auto begin() -> decltype(functions.begin()) { return functions.begin(); }
auto end() -> decltype(functions.end()) { return functions.end(); }
};
void dump(ModuleAST &);
} // namespace toy
#endif // MLIR_TUTORIAL_TOY_AST_H_