blob: 4155fd69c589754f3c9ea97d0bad45d7ba0efb58 [file] [log] [blame]
//===- NodePrinter.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/AST/Context.h"
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
#include <optional>
using namespace mlir;
using namespace mlir::pdll::ast;
//===----------------------------------------------------------------------===//
// NodePrinter
//===----------------------------------------------------------------------===//
namespace {
class NodePrinter {
public:
NodePrinter(raw_ostream &os) : os(os) {}
/// Print the given type to the stream.
void print(Type type);
/// Print the given node to the stream.
void print(const Node *node);
private:
/// Print a range containing children of a node.
template <typename RangeT,
std::enable_if_t<!std::is_convertible<RangeT, const Node *>::value>
* = nullptr>
void printChildren(RangeT &&range) {
if (range.empty())
return;
// Print the first N-1 elements with a prefix of "|-".
auto it = std::begin(range);
for (unsigned i = 0, e = llvm::size(range) - 1; i < e; ++i, ++it)
print(*it);
// Print the last element.
elementIndentStack.back() = true;
print(*it);
}
template <typename RangeT, typename... OthersT,
std::enable_if_t<std::is_convertible<RangeT, const Node *>::value>
* = nullptr>
void printChildren(RangeT &&range, OthersT &&...others) {
printChildren(ArrayRef<const Node *>({range, others...}));
}
/// Print a range containing children of a node, nesting the children under
/// the given label.
template <typename RangeT>
void printChildren(StringRef label, RangeT &&range) {
if (range.empty())
return;
elementIndentStack.reserve(elementIndentStack.size() + 1);
llvm::SaveAndRestore lastElement(elementIndentStack.back(), true);
printIndent();
os << label << "`\n";
elementIndentStack.push_back(/*isLastElt*/ false);
printChildren(std::forward<RangeT>(range));
elementIndentStack.pop_back();
}
/// Print the given derived node to the stream.
void printImpl(const CompoundStmt *stmt);
void printImpl(const EraseStmt *stmt);
void printImpl(const LetStmt *stmt);
void printImpl(const ReplaceStmt *stmt);
void printImpl(const ReturnStmt *stmt);
void printImpl(const RewriteStmt *stmt);
void printImpl(const AttributeExpr *expr);
void printImpl(const CallExpr *expr);
void printImpl(const DeclRefExpr *expr);
void printImpl(const MemberAccessExpr *expr);
void printImpl(const OperationExpr *expr);
void printImpl(const RangeExpr *expr);
void printImpl(const TupleExpr *expr);
void printImpl(const TypeExpr *expr);
void printImpl(const AttrConstraintDecl *decl);
void printImpl(const OpConstraintDecl *decl);
void printImpl(const TypeConstraintDecl *decl);
void printImpl(const TypeRangeConstraintDecl *decl);
void printImpl(const UserConstraintDecl *decl);
void printImpl(const ValueConstraintDecl *decl);
void printImpl(const ValueRangeConstraintDecl *decl);
void printImpl(const NamedAttributeDecl *decl);
void printImpl(const OpNameDecl *decl);
void printImpl(const PatternDecl *decl);
void printImpl(const UserRewriteDecl *decl);
void printImpl(const VariableDecl *decl);
void printImpl(const Module *module);
/// Print the current indent stack.
void printIndent() {
if (elementIndentStack.empty())
return;
for (bool isLastElt : llvm::ArrayRef(elementIndentStack).drop_back())
os << (isLastElt ? " " : " |");
os << (elementIndentStack.back() ? " `" : " |");
}
/// The raw output stream.
raw_ostream &os;
/// A stack of indents and a flag indicating if the current element being
/// printed at that indent is the last element.
SmallVector<bool> elementIndentStack;
};
} // namespace
/// Write a textual spelling of `type` to the stream. A null type is spelled
/// "Type<NULL>"; a type kind that is not handled below is unreachable.
void NodePrinter::print(Type type) {
  // Handle the null type up front, before dispatching on the kind.
  if (!type) {
    os << "Type<NULL>";
    return;
  }

  TypeSwitch<Type>(type)
      .Case([&](AttributeType) { os << "Attr"; })
      .Case([&](ConstraintType) { os << "Constraint"; })
      .Case([&](OperationType opType) {
        // The operation name is appended in angle brackets when present.
        os << "Op";
        std::optional<StringRef> opName = opType.getName();
        if (opName)
          os << "<" << *opName << ">";
      })
      .Case([&](RangeType rangeType) {
        // A range is spelled as its element type followed by "Range".
        print(rangeType.getElementType());
        os << "Range";
      })
      .Case([&](RewriteType) { os << "Rewrite"; })
      .Case([&](TupleType tupleType) {
        // Elements are comma separated; an element with a non-empty name is
        // prefixed with "name: ".
        os << "Tuple<";
        auto namedElements = llvm::zip(tupleType.getElementNames(),
                                       tupleType.getElementTypes());
        llvm::interleaveComma(namedElements, os, [&](auto element) {
          const auto &elementName = std::get<0>(element);
          if (!elementName.empty())
            os << elementName << ": ";
          this->print(std::get<1>(element));
        });
        os << ">";
      })
      .Case([&](TypeType) { os << "Type"; })
      .Case([&](ValueType) { os << "Value"; })
      .Default([](Type) { llvm_unreachable("unknown AST type"); });
}
void NodePrinter::print(const Node *node) {
printIndent();
os << "-";
elementIndentStack.push_back(/*isLastElt*/ false);
TypeSwitch<const Node *>(node)
.Case<
// Statements.
const CompoundStmt, const EraseStmt, const LetStmt, const ReplaceStmt,
const ReturnStmt, const RewriteStmt,
// Expressions.
const AttributeExpr, const CallExpr, const DeclRefExpr,
const MemberAccessExpr, const OperationExpr, const RangeExpr,
const TupleExpr, const TypeExpr,
// Decls.
const AttrConstraintDecl, const OpConstraintDecl,
const TypeConstraintDecl, const TypeRangeConstraintDecl,
const UserConstraintDecl, const ValueConstraintDecl,
const ValueRangeConstraintDecl, const NamedAttributeDecl,
const OpNameDecl, const PatternDecl, const UserRewriteDecl,
const VariableDecl,
const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
.Default([](const Node *) { llvm_unreachable("unknown AST node"); });
elementIndentStack.pop_back();
}
/// Print a "CompoundStmt" header with the node address, then everything
/// returned by `getChildren()` as children.
void NodePrinter::printImpl(const CompoundStmt *stmt) {
  os << "CompoundStmt ";
  os << stmt;
  os << "\n";

  printChildren(stmt->getChildren());
}
/// Print an "EraseStmt" header with the node address, then the expression
/// returned by `getRootOpExpr()` as the only child.
void NodePrinter::printImpl(const EraseStmt *stmt) {
  os << "EraseStmt ";
  os << stmt;
  os << "\n";

  auto *rootOpExpr = stmt->getRootOpExpr();
  printChildren(rootOpExpr);
}
/// Print a "LetStmt" header with the node address, then the declaration
/// returned by `getVarDecl()` as the only child.
void NodePrinter::printImpl(const LetStmt *stmt) {
  os << "LetStmt ";
  os << stmt;
  os << "\n";

  auto *varDecl = stmt->getVarDecl();
  printChildren(varDecl);
}
/// Print a "ReplaceStmt" header with the node address, then the root operation
/// expression, then the replacement expressions under a "ReplValues" label.
void NodePrinter::printImpl(const ReplaceStmt *stmt) {
  os << "ReplaceStmt ";
  os << stmt;
  os << "\n";

  // The root comes first as a direct child.
  auto *rootOpExpr = stmt->getRootOpExpr();
  printChildren(rootOpExpr);

  // The replacement values are grouped under their own label.
  printChildren("ReplValues", stmt->getReplExprs());
}
/// Print a "ReturnStmt" header with the node address, then the expression
/// returned by `getResultExpr()` as the only child.
void NodePrinter::printImpl(const ReturnStmt *stmt) {
  os << "ReturnStmt ";
  os << stmt;
  os << "\n";

  auto *resultExpr = stmt->getResultExpr();
  printChildren(resultExpr);
}
/// Print a "RewriteStmt" header with the node address, then two children in
/// order: the root operation expression and the rewrite body.
void NodePrinter::printImpl(const RewriteStmt *stmt) {
  os << "RewriteStmt ";
  os << stmt;
  os << "\n";

  auto *rootOpExpr = stmt->getRootOpExpr();
  auto *rewriteBody = stmt->getRewriteBody();
  printChildren(rootOpExpr, rewriteBody);
}
/// Print an "AttributeExpr" line: the node address followed by the result of
/// `getValue()` wrapped as Value<"...">. No children are printed.
void NodePrinter::printImpl(const AttributeExpr *expr) {
  os << "AttributeExpr " << expr;
  os << " Value<\"" << expr->getValue();
  os << "\">\n";
}
/// Print a "CallExpr" header with the node address and the expression type,
/// then the callable expression, then the arguments under an "Arguments"
/// label.
void NodePrinter::printImpl(const CallExpr *expr) {
  // Header line: "CallExpr <address> Type<...>".
  os << "CallExpr ";
  os << expr;
  os << " Type<";
  print(expr->getType());
  os << ">\n";

  // The callee is a direct child; the arguments get their own group.
  auto *callableExpr = expr->getCallableExpr();
  printChildren(callableExpr);
  printChildren("Arguments", expr->getArguments());
}
/// Print a "DeclRefExpr" header with the node address and the expression
/// type, then the declaration returned by `getDecl()` as the only child.
void NodePrinter::printImpl(const DeclRefExpr *expr) {
  // Header line: "DeclRefExpr <address> Type<...>".
  os << "DeclRefExpr ";
  os << expr;
  os << " Type<";
  print(expr->getType());
  os << ">\n";

  auto *referencedDecl = expr->getDecl();
  printChildren(referencedDecl);
}
/// Print a "MemberAccessExpr" header with the node address, the member name,
/// and the expression type, then the parent expression as the only child.
void NodePrinter::printImpl(const MemberAccessExpr *expr) {
  // Header line: "MemberAccessExpr <address> Member<...> Type<...>".
  os << "MemberAccessExpr " << expr;
  os << " Member<" << expr->getMemberName();
  os << "> Type<";
  print(expr->getType());
  os << ">\n";

  auto *parentExpr = expr->getParentExpr();
  printChildren(parentExpr);
}
/// Print an "OperationExpr" header with the node address and the expression
/// type, then the name declaration, followed by three labelled groups:
/// "Operands", "Result Types", and "Attributes".
void NodePrinter::printImpl(const OperationExpr *expr) {
  // Header line: "OperationExpr <address> Type<...>".
  os << "OperationExpr ";
  os << expr;
  os << " Type<";
  print(expr->getType());
  os << ">\n";

  // The name declaration is a direct child.
  auto *nameDecl = expr->getNameDecl();
  printChildren(nameDecl);

  // The remaining children are grouped by role.
  printChildren("Operands", expr->getOperands());
  printChildren("Result Types", expr->getResultTypes());
  printChildren("Attributes", expr->getAttributes());
}
/// Print a "RangeExpr" header with the node address and the expression type,
/// then everything returned by `getElements()` as children.
void NodePrinter::printImpl(const RangeExpr *expr) {
  // Header line: "RangeExpr <address> Type<...>".
  os << "RangeExpr ";
  os << expr;
  os << " Type<";
  print(expr->getType());
  os << ">\n";

  printChildren(expr->getElements());
}
/// Print a "TupleExpr" header with the node address and the expression type,
/// then everything returned by `getElements()` as children.
void NodePrinter::printImpl(const TupleExpr *expr) {
  // Header line: "TupleExpr <address> Type<...>".
  os << "TupleExpr ";
  os << expr;
  os << " Type<";
  print(expr->getType());
  os << ">\n";

  printChildren(expr->getElements());
}
/// Print a "TypeExpr" line: the node address followed by the result of
/// `getValue()` wrapped as Value<"...">. No children are printed.
void NodePrinter::printImpl(const TypeExpr *expr) {
  os << "TypeExpr " << expr;
  os << " Value<\"" << expr->getValue();
  os << "\">\n";
}
/// Print an "AttrConstraintDecl" header with the node address, then the type
/// expression as a child when `getTypeExpr()` returns a non-null one.
void NodePrinter::printImpl(const AttrConstraintDecl *decl) {
  os << "AttrConstraintDecl ";
  os << decl;
  os << "\n";

  // Nothing more to print when there is no type expression.
  const auto *typeExpr = decl->getTypeExpr();
  if (!typeExpr)
    return;
  printChildren(typeExpr);
}
/// Print an "OpConstraintDecl" header with the node address, then the
/// declaration returned by `getNameDecl()` as the only child.
void NodePrinter::printImpl(const OpConstraintDecl *decl) {
  os << "OpConstraintDecl ";
  os << decl;
  os << "\n";

  auto *nameDecl = decl->getNameDecl();
  printChildren(nameDecl);
}
/// Print a "TypeConstraintDecl" line with the node address. No children are
/// printed.
void NodePrinter::printImpl(const TypeConstraintDecl *decl) {
  os << "TypeConstraintDecl ";
  os << decl;
  os << "\n";
}
/// Print a "TypeRangeConstraintDecl" line with the node address. No children
/// are printed.
void NodePrinter::printImpl(const TypeRangeConstraintDecl *decl) {
  os << "TypeRangeConstraintDecl ";
  os << decl;
  os << "\n";
}
void NodePrinter::printImpl(const UserConstraintDecl *decl) {
os << "UserConstraintDecl " << decl << " Name<" << decl->getName().getName()
<< "> ResultType<" << decl->getResultType() << ">";
if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
os << " Code<";
llvm::printEscapedString(*codeBlock, os);
os << ">";
}
os << "\n";
printChildren("Inputs", decl->getInputs());
printChildren("Results", decl->getResults());
if (const CompoundStmt *body = decl->getBody())
printChildren(body);
}
/// Print a "ValueConstraintDecl" header with the node address, then the type
/// expression as a child when `getTypeExpr()` returns a non-null one.
void NodePrinter::printImpl(const ValueConstraintDecl *decl) {
  os << "ValueConstraintDecl ";
  os << decl;
  os << "\n";

  // Nothing more to print when there is no type expression.
  const auto *typeExpr = decl->getTypeExpr();
  if (!typeExpr)
    return;
  printChildren(typeExpr);
}
/// Print a "ValueRangeConstraintDecl" header with the node address, then the
/// type expression as a child when `getTypeExpr()` returns a non-null one.
void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
  os << "ValueRangeConstraintDecl ";
  os << decl;
  os << "\n";

  // Nothing more to print when there is no type expression.
  const auto *typeExpr = decl->getTypeExpr();
  if (!typeExpr)
    return;
  printChildren(typeExpr);
}
/// Print a "NamedAttributeDecl" header with the node address and name, then
/// the expression returned by `getValue()` as the only child.
void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
  os << "NamedAttributeDecl " << decl;
  os << " Name<" << decl->getName().getName();
  os << ">\n";

  auto *valueExpr = decl->getValue();
  printChildren(valueExpr);
}
/// Print an "OpNameDecl" line with the node address, plus the name when
/// `getName()` holds a value. No children are printed.
void NodePrinter::printImpl(const OpNameDecl *decl) {
  os << "OpNameDecl ";
  os << decl;

  // The name is optional.
  std::optional<StringRef> opName = decl->getName();
  if (opName)
    os << " Name<" << *opName << ">";

  os << "\n";
}
void NodePrinter::printImpl(const PatternDecl *decl) {
os << "PatternDecl " << decl;
if (const Name *name = decl->getName())
os << " Name<" << name->getName() << ">";
if (std::optional<uint16_t> benefit = decl->getBenefit())
os << " Benefit<" << *benefit << ">";
if (decl->hasBoundedRewriteRecursion())
os << " Recursion";
os << "\n";
printChildren(decl->getBody());
}
void NodePrinter::printImpl(const UserRewriteDecl *decl) {
os << "UserRewriteDecl " << decl << " Name<" << decl->getName().getName()
<< "> ResultType<" << decl->getResultType() << ">";
if (std::optional<StringRef> codeBlock = decl->getCodeBlock()) {
os << " Code<";
llvm::printEscapedString(*codeBlock, os);
os << ">";
}
os << "\n";
printChildren("Inputs", decl->getInputs());
printChildren("Results", decl->getResults());
if (const CompoundStmt *body = decl->getBody())
printChildren(body);
}
/// Print a "VariableDecl" header with the node address, name, and type. The
/// initializer follows as a child when `getInitExpr()` returns a non-null one,
/// and then the `constraint` member of each entry returned by
/// `getConstraints()` under a "Constraints" label.
void NodePrinter::printImpl(const VariableDecl *decl) {
  // Header line: "VariableDecl <address> Name<...> Type<...>".
  os << "VariableDecl " << decl;
  os << " Name<" << decl->getName().getName();
  os << "> Type<";
  print(decl->getType());
  os << ">\n";

  // The initializer is optional.
  Expr *initExpr = decl->getInitExpr();
  if (initExpr)
    printChildren(initExpr);

  // Only the `constraint` member of each reference is printed.
  auto constraintNodes = llvm::map_range(
      decl->getConstraints(),
      [](const ConstraintRef &constraintRef) { return constraintRef.constraint; });
  printChildren("Constraints", constraintNodes);
}
/// Print a "Module" header with the node address, then everything returned by
/// `getChildren()` as children.
void NodePrinter::printImpl(const Module *module) {
  os << "Module ";
  os << module;
  os << "\n";

  printChildren(module->getChildren());
}
//===----------------------------------------------------------------------===//
// Entry point
//===----------------------------------------------------------------------===//
/// Print this node to `os` using a fresh NodePrinter.
void Node::print(raw_ostream &os) const {
  NodePrinter printer(os);
  printer.print(this);
}
/// Print this type to `os` using a fresh NodePrinter.
void Type::print(raw_ostream &os) const {
  NodePrinter printer(os);
  printer.print(*this);
}