blob: abd1fbdaf8c6494604e2691349e3f49bd6fcc56d [file] [log] [blame]
//===- AttrOrTypeFormatGen.cpp - MLIR attribute and type format generator -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "AttrOrTypeFormatGen.h"
#include "FormatGen.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/AttrOrTypeDef.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace mlir;
using namespace mlir::tblgen;
using llvm::formatv;
//===----------------------------------------------------------------------===//
// Element
//===----------------------------------------------------------------------===//
namespace {
/// This class represents an instance of a variable element. A variable refers
/// to an attribute or type parameter.
class ParameterElement
    : public VariableElementBase<VariableElement::Parameter> {
public:
  ParameterElement(AttrOrTypeParameter param) : param(param) {}

  /// Get the parameter in the element.
  const AttrOrTypeParameter &getParam() const { return param; }

  /// Indicate if this variable is printed "qualified" (that is it is
  /// prefixed with the `#dialect.mnemonic`). This is a pure getter, so it is
  /// `const` like the other accessors of this class.
  bool shouldBeQualified() const { return shouldBeQualifiedFlag; }
  /// Set whether this variable is printed "qualified".
  void setShouldBeQualified(bool qualified = true) {
    shouldBeQualifiedFlag = qualified;
  }

  /// Returns true if the element contains an optional parameter.
  bool isOptional() const { return param.isOptional(); }

  /// Returns the name of the parameter.
  StringRef getName() const { return param.getName(); }

  /// Return the parameter's comparator code, instantiated with `$_lhs` bound
  /// to `self` and `$_rhs` bound to the parameter's (formatted) default value.
  /// Only valid for optional parameters.
  ///
  /// Note: this registers the `_lhs` and `_rhs` substitutions on `ctx`, i.e.
  /// the context is modified as a side effect.
  auto genIsPresent(FmtContext &ctx, const Twine &self) const {
    assert(isOptional() && "cannot guard on a mandatory parameter");
    std::string valueStr = tgfmt(*param.getDefaultValue(), &ctx).str();
    ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
    return tgfmt(getParam().getComparator(), &ctx);
  }

  /// Generate the code to check whether the parameter should be printed: the
  /// negation of the comparator applied to the parameter's accessor call and
  /// its default value. Returns `os` to allow chaining.
  MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
    assert(isOptional() && "cannot guard on a mandatory parameter");
    std::string self = param.getAccessorName() + "()";
    return os << "!(" << genIsPresent(ctx, self) << ")";
  }

private:
  /// Whether the parameter is printed with its dialect and mnemonic prefix.
  bool shouldBeQualifiedFlag = false;
  /// The attribute or type parameter this element refers to.
  AttrOrTypeParameter param;
};
/// Predicates over parameter elements, convenient as arguments to range-based
/// algorithms such as `llvm::any_of` or `llvm::count_if`.
static bool paramIsOptional(ParameterElement *el) {
  const bool optional = el->isOptional();
  return optional;
}
static bool paramNotOptional(ParameterElement *el) {
  if (el->isOptional())
    return false;
  return true;
}
/// Common base for directives that capture a list of parameter elements.
template <DirectiveElement::Kind DirectiveKind>
class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
public:
  using Base = ParamsDirectiveBase<DirectiveKind>;

  ParamsDirectiveBase(std::vector<ParameterElement *> &&params)
      : params(std::move(params)) {}

  /// View of the parameters held by this directive.
  ArrayRef<ParameterElement *> getParams() const {
    return ArrayRef<ParameterElement *>(params);
  }

  /// How many parameters this directive holds.
  unsigned getNumParams() const {
    return static_cast<unsigned>(params.size());
  }

  /// Move the parameter list out of this directive and hand it to the caller.
  std::vector<ParameterElement *> takeParams() {
    std::vector<ParameterElement *> taken = std::move(params);
    return taken;
  }

  /// Whether at least one of the held parameters is optional.
  bool hasOptionalParams() const {
    for (ParameterElement *element : params)
      if (element->isOptional())
        return true;
    return false;
  }

private:
  /// The parameters captured by this directive.
  std::vector<ParameterElement *> params;
};
/// A `params` directive stands for every parameter of an attribute or type.
/// As a top-level directive it produces a format of the form:
///
///   (param-value (`,` param-value)*)?
///
/// As an argument to another directive that accepts variables, `params` is a
/// shorthand for listing all parameters of the attribute or type by hand.
class ParamsDirective : public ParamsDirectiveBase<DirectiveElement::Params> {
public:
  // Inherit the constructors of the directive base.
  using ParamsDirectiveBase<DirectiveElement::Params>::ParamsDirectiveBase;
};
/// A `struct` directive produces a key-value format of the form:
///
///   `{` param-name `=` param-value (`,` param-name `=` param-value)* `}`
///
class StructDirective : public ParamsDirectiveBase<DirectiveElement::Struct> {
public:
  // Inherit the constructors of the directive base.
  using ParamsDirectiveBase<DirectiveElement::Struct>::ParamsDirectiveBase;
};
} // namespace
//===----------------------------------------------------------------------===//
// Format Strings
//===----------------------------------------------------------------------===//
/// Default parser for attribute or type parameters.
///
/// $0: The C++ storage type of the parameter.
static const char *const defaultParameterParser =
"::mlir::FieldParser<$0>::parse($_parser)";
/// Default printer for attribute or type parameters.
static const char *const defaultParameterPrinter =
"$_printer.printStrippedAttrOrType($_self)";
/// Qualified printer for attribute or type parameters: it does not elide
/// dialect and mnemonic.
static const char *const qualifiedParameterPrinter = "$_printer << $_self";
/// Print an error when failing to parse an element.
///
/// This is only the opening of an `emitError` call and takes no positional
/// arguments: the code that uses it appends the message and the closing
/// parenthesis.
static const char *const parserErrorStr =
"$_parser.emitError($_parser.getCurrentLocation(), ";
/// Code format to parse a variable. Separate by lines because variable parsers
/// may be generated inside other directives, which requires indentation.
///
/// {0}: The parameter name.
/// {1}: The parse code for the parameter.
/// {2}: Code template for printing an error.
/// {3}: Name of the attribute or type.
/// {4}: C++ class of the parameter.
static const char *const variableParser = R"(
// Parse variable '{0}'
_result_{0} = {1};
if (::mlir::failed(_result_{0})) {{
{2}"failed to parse {3} parameter '{0}' which is to be a `{4}`");
return {{};
}
)";
//===----------------------------------------------------------------------===//
// DefFormat
//===----------------------------------------------------------------------===//
namespace {
/// Holds the parsed assembly format of one attribute or type definition and
/// emits the C++ bodies of its parser and printer into a `MethodBody`.
class DefFormat {
public:
DefFormat(const AttrOrTypeDef &def, std::vector<FormatElement *> &&elements)
: def(def), elements(std::move(elements)) {}
/// Generate the attribute or type parser.
void genParser(MethodBody &os);
/// Generate the attribute or type printer.
void genPrinter(MethodBody &os);
private:
/// Generate the parser code for a specific format element.
void genElementParser(FormatElement *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a literal. When `isOptional` is set, the
/// emitted `if (` condition is left unclosed for the caller to complete.
void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os,
bool isOptional = false);
/// Generate the parser code for a variable.
void genVariableParser(ParameterElement *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a `params` directive.
void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a `struct` directive.
void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the parser code for a `custom` directive.
void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os,
bool isOptional = false);
/// Generate the parser code for an optional group.
void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
MethodBody &os);
/// Generate the printer code for a specific format element.
void genElementPrinter(FormatElement *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a literal.
void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a variable. When `skipGuard` is set, no
/// presence check is emitted around the printer of an optional parameter.
void genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os,
bool skipGuard = false);
/// Generate a printer for comma-separated parameters. `extra` is invoked for
/// each parameter right before that parameter's printer is generated.
void genCommaSeparatedPrinter(ArrayRef<ParameterElement *> params,
FmtContext &ctx, MethodBody &os,
function_ref<void(ParameterElement *)> extra);
/// Generate the printer code for a `params` directive.
void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `struct` directive.
void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for a `custom` directive.
void genCustomPrinter(CustomDirective *el, FmtContext &ctx, MethodBody &os);
/// Generate the printer code for an optional group.
void genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
MethodBody &os);
/// Generate a printer (or space eraser) for a whitespace element.
void genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
MethodBody &os);
/// The ODS definition of the attribute or type whose format is being used to
/// generate a parser and printer.
const AttrOrTypeDef &def;
/// The list of top-level format elements returned by the assembly format
/// parser.
std::vector<FormatElement *> elements;
/// Flags for printing spaces. They are updated by the printer generators as
/// elements are emitted and decide whether a separating space is printed
/// before the next element.
bool shouldEmitSpace = false;
bool lastWasPunctuation = false;
};
} // namespace
//===----------------------------------------------------------------------===//
// ParserGen
//===----------------------------------------------------------------------===//
/// Emit the special-case "parser" for an attribute's self type parameter.
/// Nothing is actually parsed here: the value is taken from the optional
/// trailing colon type that follows the attribute, after checking its kind.
static void genAttrSelfTypeParser(MethodBody &os, const FmtContext &ctx,
                                  const AttributeSelfTypeParameter &param) {
  // Checks the optionally-parsed trailing colon type and stores it as the
  // parameter value when it has the expected C++ type.
  //
  // $0: The C++ storage class of the type parameter.
  // $1: The self type parameter name.
  const char *const checkTrailingType = R"(
if ($_type) {
if (auto reqType = ::llvm::dyn_cast<$0>($_type)) {
_result_$1 = reqType;
} else {
$_parser.emitError($_loc, "invalid kind of type specified");
return {};
}
})";

  // Continuation emitted only for a required self type parameter: reports an
  // error when no trailing type was parsed.
  const char *const missingTrailingType = R"( else {
$_parser.emitError($_loc, "expected a trailing type");
return {};
})";

  const bool typeIsRequired = !param.isOptional();
  os << tgfmt(checkTrailingType, &ctx, param.getCppStorageType(),
              param.getName());
  if (typeIsRequired)
    os << tgfmt(missingTrailingType, &ctx);
  os << "\n";
}
/// Emit the body of the generated parser: declare one `FailureOr` result per
/// parameter, emit the parser of every format element in order, assert that
/// all mandatory parameters were parsed, and finally emit the call that builds
/// the attribute or type from the parsed values.
void DefFormat::genParser(MethodBody &os) {
FmtContext ctx;
ctx.addSubst("_parser", "odsParser");
ctx.addSubst("_ctxt", "odsParser.getContext()");
ctx.withBuilder("odsBuilder");
// `$_type` is only made available when generating for an attribute.
if (isa<AttrDef>(def))
ctx.addSubst("_type", "odsType");
os.indent();
os << "::mlir::Builder odsBuilder(odsParser.getContext());\n";
// Store the initial location of the parser.
ctx.addSubst("_loc", "odsLoc");
os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n"
"(void) $_loc;\n",
&ctx);
// Declare variables to store all of the parameters. Allocated parameters
// such as `ArrayRef` and `StringRef` must provide a `storageType`. Store
// FailureOr<T> to defer type construction for parameters that are parsed in
// a loop (parsers return FailureOr anyways).
ArrayRef<AttrOrTypeParameter> params = def.getParameters();
for (const AttrOrTypeParameter &param : params) {
os << formatv("::mlir::FailureOr<{0}> _result_{1};\n",
param.getCppStorageType(), param.getName());
// A self type parameter is filled in right after its declaration, from the
// trailing type rather than from a format element.
if (auto *selfTypeParam = dyn_cast<AttributeSelfTypeParameter>(&param))
genAttrSelfTypeParser(os, ctx, *selfTypeParam);
}
// Generate call to each parameter parser.
for (FormatElement *el : elements)
genElementParser(el, ctx, os);
// Emit an assert for each mandatory parameter. Triggering an assert means
// the generated parser is incorrect (i.e. there is a bug in this code).
for (const AttrOrTypeParameter &param : params) {
if (param.isOptional())
continue;
os << formatv("assert(::mlir::succeeded(_result_{0}));\n", param.getName());
}
// Generate call to the attribute or type builder. Use the checked getter
// if one was generated.
if (def.genVerifyDecl()) {
os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
&ctx, def.getCppClassName());
} else {
os << tgfmt("return $0::get($_parser.getContext()", &ctx,
def.getCppClassName());
}
// Emit one builder argument per parameter, converted from its storage type.
for (const AttrOrTypeParameter &param : params) {
os << ",\n ";
std::string paramSelfStr;
llvm::raw_string_ostream selfOs(paramSelfStr);
// Parameters with a default value fall back to it when nothing was parsed;
// all others are unwrapped directly.
if (std::optional<StringRef> defaultValue = param.getDefaultValue()) {
selfOs << formatv("(_result_{0}.value_or(", param.getName())
<< tgfmt(*defaultValue, &ctx) << "))";
} else {
selfOs << formatv("(*_result_{0})", param.getName());
}
// Register a substitution named after the parameter that expands to the
// value expression built above; it stays in `ctx` for the parameters that
// follow.
ctx.addSubst(param.getName(), selfOs.str());
os << param.getCppType() << "("
<< tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str()))
<< ")";
}
os << ");";
}
/// Dispatch parser generation on the concrete kind of the format element.
/// Whitespace elements produce no parser code; any other unhandled kind is a
/// programming error.
void DefFormat::genElementParser(FormatElement *el, FmtContext &ctx,
                                 MethodBody &os) {
  if (auto *literal = dyn_cast<LiteralElement>(el)) {
    genLiteralParser(literal->getSpelling(), ctx, os);
  } else if (auto *var = dyn_cast<ParameterElement>(el)) {
    genVariableParser(var, ctx, os);
  } else if (auto *params = dyn_cast<ParamsDirective>(el)) {
    genParamsParser(params, ctx, os);
  } else if (auto *strct = dyn_cast<StructDirective>(el)) {
    genStructParser(strct, ctx, os);
  } else if (auto *custom = dyn_cast<CustomDirective>(el)) {
    genCustomParser(custom, ctx, os);
  } else if (auto *optional = dyn_cast<OptionalElement>(el)) {
    genOptionalGroupParser(optional, ctx, os);
  } else if (!isa<WhitespaceElement>(el)) {
    llvm_unreachable("unknown format element");
  }
}
/// Emit the parser call for a literal. Literals starting with a letter or an
/// underscore are parsed as keywords, everything else through the parser hook
/// for that punctuation. For an optional literal the `if (` condition is left
/// open so the caller can attach its own guard body.
void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx,
                                 MethodBody &os, bool isOptional) {
  os << "// Parse literal '" << value << "'\n";
  os << tgfmt("if ($_parser.parse", &ctx);
  if (isOptional)
    os << "Optional";

  const char leading = value.front();
  const bool looksLikeKeyword = leading == '_' || isalpha(leading);
  if (!looksLikeKeyword) {
    // Map the punctuation to the suffix of the matching parser hook.
    os << StringSwitch<StringRef>(value)
              .Case("->", "Arrow")
              .Case(":", "Colon")
              .Case(",", "Comma")
              .Case("=", "Equal")
              .Case("<", "Less")
              .Case(">", "Greater")
              .Case("{", "LBrace")
              .Case("}", "RBrace")
              .Case("(", "LParen")
              .Case(")", "RParen")
              .Case("[", "LSquare")
              .Case("]", "RSquare")
              .Case("?", "Question")
              .Case("+", "Plus")
              .Case("*", "Star")
              .Case("...", "Ellipsis")
       << "()";
  } else {
    os << "Keyword(\"" << value << "\")";
  }

  // A mandatory literal closes the condition and bails out on failure; the
  // parser itself reports the error. An optional one stays unclosed.
  if (!isOptional)
    os << ") return {};\n";
}
/// Emit the code that parses a single parameter into its `_result_<name>`
/// variable and reports an error on failure.
void DefFormat::genVariableParser(ParameterElement *el, FmtContext &ctx,
                                  MethodBody &os) {
  const AttrOrTypeParameter &param = el->getParam();

  // Start from the default parser and let a custom one, if any, override it.
  StringRef parseCode = defaultParameterParser;
  if (auto custom = param.getParser())
    parseCode = *custom;

  os << formatv(variableParser, param.getName(),
                tgfmt(parseCode, &ctx, param.getCppStorageType()),
                tgfmt(parserErrorStr, &ctx), def.getName(), param.getCppType());
}
/// Emit the parser for a `params` directive: each parameter is parsed in
/// order, with comma parsers emitted between consecutive parameters.
void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
MethodBody &os) {
os << "// Parse parameter list\n";
// If there are optional parameters, we need to switch to `parseOptionalComma`
// if there are no more required parameters after a certain point.
bool hasOptional = el->hasOptionalParams();
if (hasOptional) {
// Wrap everything in a do-while so that we can `break`.
os << "do {\n";
os.indent();
}
ArrayRef<ParameterElement *> params = el->getParams();
using IteratorT = ParameterElement *const *;
IteratorT it = params.begin();
// Find the last required parameter. Commas become optional afterwards.
// NOTE(review): the original remark here said "IteratorT's copy assignment is
// deleted"; IteratorT is a raw pointer in this code, so that remark looks
// stale - confirm before relying on it.
ParameterElement *lastReq = nullptr;
for (ParameterElement *param : params)
if (!param->isOptional())
lastReq = param;
// Without any required parameter this is `params.begin()`, so every comma
// emitted below is an optional one.
IteratorT lastReqIt = lastReq ? llvm::find(params, lastReq) : params.begin();
auto eachFn = [&](ParameterElement *el) { genVariableParser(el, ctx, os); };
// Emits the comma parser that precedes the parameter `it` points to.
auto betweenFn = [&](IteratorT it) {
ParameterElement *el = *std::prev(it);
// Parse a comma if the last optional parameter had a value.
if (el->isOptional()) {
os << formatv("if (::mlir::succeeded(_result_{0}) && !({1})) {{\n",
el->getName(),
el->genIsPresent(ctx, "(*_result_" + el->getName() + ")"));
os.indent();
}
// Up to and including the last required parameter the comma is mandatory;
// past it, a missing comma ends the parameter list.
if (it <= lastReqIt) {
genLiteralParser(",", ctx, os);
} else {
genLiteralParser(",", ctx, os, /*isOptional=*/true);
os << ") break;\n";
}
if (el->isOptional())
os.unindent() << "}\n";
};
// Hand-written equivalent of llvm::interleave, because `betweenFn` needs the
// iterator and not just the element.
if (it != params.end()) {
eachFn(*it++);
for (IteratorT e = params.end(); it != e; ++it) {
betweenFn(it);
eachFn(*it);
}
}
if (hasOptional)
os.unindent() << "} while(false);\n";
}
/// Emit the parser for a `struct` directive. The generated code declares a
/// "seen" flag per parameter, wraps the parsing of one `key = value` pair in a
/// `_loop_body` lambda, and calls it from a loop whose shape depends on how
/// many of the parameters are optional.
void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
MethodBody &os) {
// Loop declaration for struct parser with only required parameters.
//
// $0: Number of expected parameters.
const char *const loopHeader = R"(
for (unsigned odsStructIndex = 0; odsStructIndex < $0; ++odsStructIndex) {
)";
// Loop body start for struct parser.
const char *const loopStart = R"(
::llvm::StringRef _paramKey;
if ($_parser.parseKeyword(&_paramKey)) {
$_parser.emitError($_parser.getCurrentLocation(),
"expected a parameter name in struct");
return {};
}
if (!_loop_body(_paramKey)) return {};
)";
// Struct parser loop end. Check for duplicate or unknown struct parameters.
//
// {0}: Code template for printing an error.
const char *const loopEnd = R"({{
{0}"duplicate or unknown struct parameter name: ") << _paramKey;
return {{};
}
)";
// Struct parser loop terminator. Parse a comma except on the last element.
//
// {0}: Number of elements in the struct.
const char *const loopTerminator = R"(
if ((odsStructIndex != {0} - 1) && odsParser.parseComma())
return {{};
}
)";
// Check that a mandatory parameter was parsed.
//
// {0}: Name of the parameter.
// {1}: Code template for printing an error.
const char *const checkParam = R"(
if (!_seen_{0}) {
{1}"struct is missing required parameter: ") << "{0}";
return {{};
}
)";
// First iteration of the loop parsing an optional struct.
const char *const optionalStructFirst = R"(
::llvm::StringRef _paramKey;
if (!$_parser.parseOptionalKeyword(&_paramKey)) {
if (!_loop_body(_paramKey)) return {};
while (!$_parser.parseOptionalComma()) {
)";
os << "// Parse parameter struct\n";
// Declare a "seen" variable for each key.
for (ParameterElement *param : el->getParams())
os << formatv("bool _seen_{0} = false;\n", param->getName());
// Generate the body of the parsing loop inside a lambda.
os << "{\n";
os.indent()
<< "const auto _loop_body = [&](::llvm::StringRef _paramKey) -> bool {\n";
genLiteralParser("=", ctx, os.indent());
// Emit one `if (...) { ... } else ` link per parameter; the chain is closed
// by the `loopEnd` block printed after the loop.
for (ParameterElement *param : el->getParams()) {
os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n"
" _seen_{0} = true;\n",
param->getName());
genVariableParser(param, ctx, os.indent());
os.unindent() << "} else ";
}
// Print the check for duplicate or unknown parameter.
os.getStream().printReindented(strfmt(loopEnd, tgfmt(parserErrorStr, &ctx)));
os << "return true;\n";
os.unindent() << "};\n";
// Generate the parsing loop. If optional parameters are present, then the
// parse loop is guarded by commas.
unsigned numOptional = llvm::count_if(el->getParams(), paramIsOptional);
if (numOptional) {
// If the struct itself is optional, pull out the first iteration.
if (numOptional == el->getNumParams()) {
os.getStream().printReindented(tgfmt(optionalStructFirst, &ctx).str());
os.indent();
} else {
os << "do {\n";
}
} else {
os.getStream().printReindented(
tgfmt(loopHeader, &ctx, el->getNumParams()).str());
}
os.indent();
os.getStream().printReindented(tgfmt(loopStart, &ctx).str());
os.unindent();
// Print the loop terminator. For optional parameters, we have to check that
// all mandatory parameters have been parsed.
// The whole struct is optional if all its parameters are optional.
if (numOptional) {
if (numOptional == el->getNumParams()) {
// Close the `while` and the enclosing `if` opened by
// `optionalStructFirst`.
os << "}\n";
os.unindent() << "}\n";
} else {
os << tgfmt("} while(!$_parser.parseOptionalComma());\n", &ctx);
for (ParameterElement *param : el->getParams()) {
if (param->isOptional())
continue;
os.getStream().printReindented(
strfmt(checkParam, param->getName(), tgfmt(parserErrorStr, &ctx)));
}
}
} else {
// Because the loop loops N times and each non-failing iteration sets 1 of
// N flags, successfully exiting the loop means that all parameters have
// been seen. `parseOptionalComma` would cause issues with any formats that
// use "struct(...) `,`" because structs aren't surrounded by braces.
os.getStream().printReindented(strfmt(loopTerminator, el->getNumParams()));
}
// Close the scope opened before the `_loop_body` lambda.
os.unindent() << "}\n";
}
/// Emit the call to a user-provided `parse<Name>` function together with the
/// checks on its result. With `isOptional` set, the emitted code distinguishes
/// "nothing parsed" (`return {}`) from a hard failure
/// (`return ::mlir::failure()`).
void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
                                MethodBody &os, bool isOptional) {
  // Scope the helper variables of the generated code.
  os << "{\n";
  os.indent();

  // Bound variables are handed to the custom parser as `FailureOr<T> &`,
  // referenced variables as `T`. The custom parser fails if it returns failure
  // or if any required parameter was left in a failed state.
  os << tgfmt("auto odsCustomLoc = $_parser.getCurrentLocation();\n", &ctx);
  os << "(void)odsCustomLoc;\n";
  os << tgfmt("auto odsCustomResult = parse$0($_parser", &ctx, el->getName());
  os.indent();
  for (FormatElement *arg : el->getArguments()) {
    os << ",\n";
    if (auto *bound = dyn_cast<ParameterElement>(arg)) {
      os << "::mlir::detail::unwrapForCustomParse(_result_" << bound->getName()
         << ")";
      continue;
    }
    if (auto *ref = dyn_cast<RefDirective>(arg)) {
      os << "*_result_" << cast<ParameterElement>(ref->getArg())->getName();
      continue;
    }
    os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
  }
  os.unindent() << ");\n";

  // Check the result of the custom parser itself.
  if (!isOptional) {
    os << "if (::mlir::failed(odsCustomResult)) return {};\n";
  } else {
    os << "if (!odsCustomResult.has_value()) return {};\n";
    os << "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n";
  }

  // Check that every required bound parameter received a value.
  const char *const failureValue = isOptional ? "::mlir::failure()" : "{}";
  for (FormatElement *arg : el->getArguments()) {
    auto *bound = dyn_cast<ParameterElement>(arg);
    if (!bound || bound->isOptional())
      continue;
    os << formatv("if (::mlir::failed(_result_{0})) {{\n", bound->getName());
    os.indent();
    os << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx)
       << "\"custom parser failed to parse parameter '" << bound->getName()
       << "'\");\n";
    os << "return " << failureValue << ";\n";
    os.unindent() << "}\n";
  }

  os.unindent() << "}\n";
}
/// Emit the parser for an optional group. The first parseable "then" element
/// decides which branch is taken: its parser is emitted first, followed by an
/// `if (...) {` whose body holds the parsers of the "else" elements, and an
/// `else` branch that holds the parsers of the remaining "then" elements.
void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
                                       MethodBody &os) {
  ArrayRef<FormatElement *> thenElements =
      el->getThenElements(/*parseable=*/true);
  FormatElement *first = thenElements.front();

  // Open an `if` that is taken when none of the given parameters was parsed
  // successfully to a value that converts to `true`.
  const auto guardOn = [&](auto params) {
    os << "if (!(";
    llvm::interleave(
        params, os,
        [&](ParameterElement *el) {
          os << formatv("(::mlir::succeeded(_result_{0}) && *_result_{0})",
                        el->getName());
        },
        " || ");
    os << ")) {\n";
  };

  if (auto *literal = dyn_cast<LiteralElement>(first)) {
    // The optional literal parser leaves its `if (` condition unclosed.
    genLiteralParser(literal->getSpelling(), ctx, os, /*isOptional=*/true);
    os << ") {\n";
  } else if (auto *param = dyn_cast<ParameterElement>(first)) {
    genVariableParser(param, ctx, os);
    guardOn(llvm::ArrayRef(param));
  } else if (auto *params = dyn_cast<ParamsDirective>(first)) {
    genParamsParser(params, ctx, os);
    guardOn(params->getParams());
  } else if (auto *custom = dyn_cast<CustomDirective>(first)) {
    // Run the custom parser inside a lambda so that it can report "absent",
    // "failed" and "parsed" through an `OptionalParseResult`.
    os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
    os.indent();
    genCustomParser(custom, ctx, os, /*isOptional=*/true);
    os << "return ::mlir::success();\n";
    os.unindent();
    os << "}(); result.has_value() && ::mlir::failed(*result)) {\n";
    os.indent();
    os << "return {};\n";
    os.unindent();
    // NOTE(review): unlike the other cases, which open a branch for an
    // *absent* first element, this opens the branch for a custom directive
    // that *was* parsed, yet the else-elements are emitted into it below.
    // Confirm whether this condition should be `!result.has_value()`.
    os << "} else if (result.has_value()) {\n";
  } else {
    auto *strct = cast<StructDirective>(first);
    genStructParser(strct, ctx, os);
    // Guard on the struct's own parameters. The `params` variable declared in
    // the `else if` above is still in scope here but is always null in this
    // branch, so it must not be dereferenced.
    guardOn(strct->getParams());
  }
  os.indent();

  // Body of the branch opened above: the parsers for the else elements.
  for (FormatElement *element : el->getElseElements(/*parseable=*/true))
    genElementParser(element, ctx, os);
  os.unindent() << "} else {\n";
  os.indent();

  // Otherwise, generate the parsers for the rest of the thenElements.
  for (FormatElement *element : thenElements.drop_front())
    genElementParser(element, ctx, os);
  os.unindent() << "}\n";
}
//===----------------------------------------------------------------------===//
// PrinterGen
//===----------------------------------------------------------------------===//
/// Emit the body of the generated printer by emitting the printer of every
/// top-level format element in order.
void DefFormat::genPrinter(MethodBody &os) {
  FmtContext ctx;
  ctx.addSubst("_printer", "odsPrinter").addSubst("_ctxt", "getContext()");
  ctx.withBuilder("odsBuilder");

  os.indent();
  os << "::mlir::Builder odsBuilder(getContext());\n";

  // Reset the spacing state before generating the element printers.
  lastWasPunctuation = false;
  shouldEmitSpace = true;
  for (FormatElement *element : elements)
    genElementPrinter(element, ctx, os);
}
/// Dispatch printer generation on the concrete kind of the format element.
/// An element of any other kind is reported as a fatal error.
void DefFormat::genElementPrinter(FormatElement *el, FmtContext &ctx,
                                  MethodBody &os) {
  if (auto *literal = dyn_cast<LiteralElement>(el)) {
    genLiteralPrinter(literal->getSpelling(), ctx, os);
  } else if (auto *params = dyn_cast<ParamsDirective>(el)) {
    genParamsPrinter(params, ctx, os);
  } else if (auto *strct = dyn_cast<StructDirective>(el)) {
    genStructPrinter(strct, ctx, os);
  } else if (auto *custom = dyn_cast<CustomDirective>(el)) {
    genCustomPrinter(custom, ctx, os);
  } else if (auto *var = dyn_cast<ParameterElement>(el)) {
    genVariablePrinter(var, ctx, os);
  } else if (auto *optional = dyn_cast<OptionalElement>(el)) {
    genOptionalGroupPrinter(optional, ctx, os);
  } else if (auto *whitespace = dyn_cast<WhitespaceElement>(el)) {
    genWhitespacePrinter(whitespace, ctx, os);
  } else {
    llvm::PrintFatalError("unsupported format element");
  }
}
/// Emit the printer for a literal and update the spacing state for whatever
/// element is printed next.
void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx,
                                  MethodBody &os) {
  // Certain punctuation is not preceded by a space.
  const char *leadingSpace = "";
  if (shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation))
    leadingSpace = " << ' '";
  os << tgfmt("$_printer$0 << \"$1\";\n", &ctx, leadingSpace, value);

  // No space is wanted right after a single opening bracket character.
  const char first = value.front();
  const bool isOpeningBracket =
      value.size() == 1 && StringRef("<({[").contains(first);
  shouldEmitSpace = !isOpeningBracket;
  // Anything that does not start like a keyword counts as punctuation.
  lastWasPunctuation = !(first == '_' || isalpha(first));
}
/// Emit the printer for a single parameter. Unless `skipGuard` is set, the
/// printer of an optional parameter is wrapped in a check generated by the
/// parameter's print guard.
void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
                                   MethodBody &os, bool skipGuard) {
  const AttrOrTypeParameter &param = el->getParam();
  ctx.withSelf(param.getAccessorName() + "()");

  // Optional parameters are only printed when their guard holds.
  const bool guarded = el->isOptional() && !skipGuard;
  if (guarded) {
    os << "if (";
    el->genPrintGuard(ctx, os) << ") {\n";
    os.indent();
  }

  // Separate from the previous element with a space, if necessary.
  if (shouldEmitSpace || !lastWasPunctuation)
    os << tgfmt("$_printer << ' ';\n", &ctx);
  shouldEmitSpace = true;
  lastWasPunctuation = false;

  // Pick the printer code: the qualified printer wins over a custom printer,
  // which in turn wins over the default one.
  StringRef printerCode = defaultParameterPrinter;
  if (el->shouldBeQualified())
    printerCode = qualifiedParameterPrinter;
  else if (auto custom = param.getPrinter())
    printerCode = *custom;
  os << tgfmt(printerCode, &ctx) << ";\n";

  if (guarded)
    os.unindent() << "}\n";
}
/// Open an `if` whose condition is the disjunction of the print guards of the
/// given parameters, negated as a whole when `inverted` is set. The output is
/// left indented inside the opened block; closing it is up to the caller.
template <typename ParameterRange>
static void guardOnAny(FmtContext &ctx, MethodBody &os, ParameterRange &&params,
                       bool inverted = false) {
  os << (inverted ? "if (!(" : "if (");
  const auto emitGuard = [&](ParameterElement *param) {
    param->genPrintGuard(ctx, os);
  };
  llvm::interleave(params, os, emitGuard, " || ");
  os << (inverted ? ")) {\n" : ") {\n");
  os.indent();
}
/// Emit a printer for a comma-separated parameter list. `extra` is called for
/// each parameter after the separator code and before the parameter's own
/// printer is generated.
void DefFormat::genCommaSeparatedPrinter(
    ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
    function_ref<void(ParameterElement *)> extra) {
  // Emit a leading space if necessary. When every parameter is optional, the
  // space itself is guarded so that it is only printed if something follows.
  if (shouldEmitSpace || !lastWasPunctuation) {
    if (llvm::all_of(params, paramIsOptional)) {
      guardOnAny(ctx, os, params);
      os << tgfmt("$_printer << ' ';\n", &ctx);
      os.unindent() << "}\n";
    } else {
      os << tgfmt("$_printer << ' ';\n", &ctx);
    }
  }

  // `_firstPrinted` tracks whether a comma is needed before the next element.
  os << "{\n";
  os.indent();
  os << "bool _firstPrinted = true;\n";
  for (ParameterElement *element : params) {
    const bool guarded = element->isOptional();
    if (guarded) {
      os << "if (";
      element->genPrintGuard(ctx, os) << ") {\n";
      os.indent();
    }
    os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
    os << "_firstPrinted = false;\n";
    extra(element);
    // Separators were just emitted: suppress the space before the value.
    shouldEmitSpace = false;
    lastWasPunctuation = true;
    genVariablePrinter(element, ctx, os);
    if (guarded)
      os.unindent() << "}\n";
  }
  os.unindent() << "}\n";
}
/// Emit the printer for a `params` directive: the parameter values separated
/// by commas, with nothing extra printed per parameter.
void DefFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx,
                                 MethodBody &os) {
  // `getParams()` already yields an `ArrayRef`, which is what the callee
  // takes; there is no need to copy the list into a temporary vector first.
  genCommaSeparatedPrinter(el->getParams(), ctx, os,
                           [](ParameterElement *) {});
}
/// Emit the printer for a `struct` directive: comma-separated
/// `name = value` pairs.
void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
                                 MethodBody &os) {
  // `getParams()` already yields an `ArrayRef`, which is what the callee
  // takes; there is no need to copy the list into a temporary vector first.
  genCommaSeparatedPrinter(
      el->getParams(), ctx, os, [&](ParameterElement *param) {
        // Print the key in front of each value.
        os << tgfmt("$_printer << \"$0 = \";\n", &ctx, param->getName());
      });
}
/// Emit the call to a user-provided `print<Name>` function. Bound and
/// referenced parameters are both passed through their accessor; string
/// arguments are passed through after substitution.
void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx,
                                 MethodBody &os) {
  // Separate from the previous element with a space, if necessary.
  if (shouldEmitSpace || !lastWasPunctuation)
    os << tgfmt("$_printer << ' ';\n", &ctx);
  lastWasPunctuation = false;
  shouldEmitSpace = true;

  os << tgfmt("print$0($_printer", &ctx, el->getName());
  os.indent();
  for (FormatElement *arg : el->getArguments()) {
    os << ",\n";
    // Resolve the argument to a parameter, either directly or through `ref`.
    ParameterElement *target = dyn_cast<ParameterElement>(arg);
    if (!target)
      if (auto *ref = dyn_cast<RefDirective>(arg))
        target = cast<ParameterElement>(ref->getArg());

    if (target)
      os << target->getParam().getAccessorName() << "()";
    else
      os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
  }
  os.unindent() << ");\n";
}
/// Emit the printer for an optional group: an `if` guarded on the parameters
/// of the group's anchor that prints the "then" elements, with the "else"
/// elements printed in the `else` branch.
void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
                                        MethodBody &os) {
  FormatElement *anchor = el->getAnchor();
  const bool inverted = el->isInverted();

  // Open the guard on whichever parameters the anchor refers to.
  if (auto *strct = dyn_cast<StructDirective>(anchor)) {
    guardOnAny(ctx, os, strct->getParams(), inverted);
  } else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
    guardOnAny(ctx, os, params->getParams(), inverted);
  } else if (auto *param = dyn_cast<ParameterElement>(anchor)) {
    guardOnAny(ctx, os, llvm::ArrayRef(param), inverted);
  } else {
    // For a custom directive, guard on those of its arguments that are
    // parameter elements.
    auto *custom = cast<CustomDirective>(anchor);
    const auto asParameter = [](FormatElement *arg) {
      return dyn_cast<ParameterElement>(arg);
    };
    const auto isParameter = [](ParameterElement *candidate) {
      return candidate != nullptr;
    };
    guardOnAny(ctx, os,
               llvm::make_filter_range(
                   llvm::map_range(custom->getArguments(), asParameter),
                   isParameter),
               inverted);
  }

  const auto printAll = [&](ArrayRef<FormatElement *> group) {
    for (FormatElement *member : group)
      genElementPrinter(member, ctx, os);
  };

  // Print the "then" elements; the spacing state is restored afterwards so
  // that the "else" elements start from the same state.
  {
    llvm::SaveAndRestore savedEmitSpace(shouldEmitSpace);
    llvm::SaveAndRestore savedPunctuation(lastWasPunctuation);
    printAll(el->getThenElements());
  }
  os.unindent() << "} else {\n";
  os.indent();
  printAll(el->getElseElements());
  os.unindent() << "}\n";
}
/// Emit the printer for a whitespace element. An empty element prints nothing
/// and instead marks the previous element as punctuation; in every case the
/// pending space is dropped.
void DefFormat::genWhitespacePrinter(WhitespaceElement *el, FmtContext &ctx,
                                     MethodBody &os) {
  const auto &spacing = el->getValue();
  if (spacing.empty()) {
    lastWasPunctuation = true;
  } else if (spacing == "\\n") {
    // FIXME: The newline should be `printer.printNewLine()`, i.e., handled by
    // the printer.
    os << tgfmt("$_printer << '\\n';\n", &ctx);
  } else {
    os << tgfmt("$_printer << \"$0\";\n", &ctx, spacing);
  }
  shouldEmitSpace = false;
}
//===----------------------------------------------------------------------===//
// DefFormatParser
//===----------------------------------------------------------------------===//
namespace {
/// Parser for the assembly format of an attribute or type definition. It
/// specializes the generic `FormatParser` hooks for attribute and type
/// parameters and tracks which parameters the format has referenced.
class DefFormatParser : public FormatParser {
public:
/// Create a parser for the format of `def`, reporting diagnostics relative to
/// the first location of the definition. `seenParams` gets one bit per
/// parameter of the definition.
DefFormatParser(llvm::SourceMgr &mgr, const AttrOrTypeDef &def)
: FormatParser(mgr, def.getLoc()[0]), def(def),
seenParams(def.getNumParameters()) {}
/// Parse the attribute or type format and create the format elements.
FailureOr<DefFormat> parse();
protected:
/// Verify the parsed elements.
LogicalResult verify(SMLoc loc, ArrayRef<FormatElement *> elements) override;
/// Verify the elements of a custom directive.
LogicalResult
verifyCustomDirectiveArguments(SMLoc loc,
ArrayRef<FormatElement *> arguments) override;
/// Verify the elements of an optional group.
LogicalResult verifyOptionalGroupElements(SMLoc loc,
ArrayRef<FormatElement *> elements,
FormatElement *anchor) override;
/// Hook for marking `element` as qualified; the behavior is defined by the
/// out-of-line definition.
LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
/// Parse an attribute or type variable.
FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
Context ctx) override;
/// Parse an attribute or type format directive.
FailureOr<FormatElement *>
parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind, Context ctx) override;
private:
/// Parse a `params` directive.
FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
/// Parse a `struct` directive.
FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
/// Attribute or type tablegen def.
const AttrOrTypeDef &def;
/// Seen attribute or type parameters, indexed by parameter position.
BitVector seenParams;
};
} // namespace
/// Check the fully parsed format: every non-optional parameter other than a
/// self-type parameter must have been bound, a self-type parameter must not
/// have been bound, and a `struct` directive holding optional parameters may
/// not be immediately followed by a `,` literal.
LogicalResult DefFormatParser::verify(SMLoc loc,
                                      ArrayRef<FormatElement *> elements) {
  for (const auto &indexedParam : llvm::enumerate(def.getParameters())) {
    const AttrOrTypeParameter &param = indexedParam.value();
    // Optional parameters are allowed to be absent from the format.
    if (param.isOptional())
      continue;

    const bool isSelfType = isa<AttributeSelfTypeParameter>(param);
    const bool wasBound = seenParams.test(indexedParam.index());
    if (!wasBound && !isSelfType) {
      return emitError(loc, "format is missing reference to parameter: " +
                                param.getName());
    }
    if (wasBound && isSelfType) {
      return emitError(loc,
                       "unexpected self type parameter in assembly format");
    }
  }

  // Walk adjacent element pairs. A comma right after a `struct` directive with
  // optional parameters would be ambiguous. An empty element list yields no
  // pairs, so it needs no separate guard here.
  for (size_t next = 1, numElements = elements.size(); next < numElements;
       ++next) {
    auto *structEl = dyn_cast<StructDirective>(elements[next - 1]);
    auto *literalEl = dyn_cast<LiteralElement>(elements[next]);
    if (!structEl || !literalEl)
      continue;
    if (literalEl->getSpelling() == "," && structEl->hasOptionalParams()) {
      return emitError(loc, "`struct` directive with optional parameters "
                            "cannot be followed by a comma literal");
    }
  }
  return success();
}
/// Custom directive arguments need no checks of their own here; this hook
/// always succeeds.
LogicalResult DefFormatParser::verifyCustomDirectiveArguments(
    SMLoc loc, ArrayRef<FormatElement *> arguments) {
  // The parser context has already validated each argument while parsing.
  return LogicalResult::success();
}
/// Verify the contents and the anchor of an optional group. Every parameter
/// reachable from the group (directly, or through a `params`, `struct`, or
/// `custom` directive) must be optional. A non-null anchor must be a parameter
/// or one of those directives, and a `custom` anchor must have at least one
/// parameter among its arguments.
LogicalResult
DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
                                             ArrayRef<FormatElement *> elements,
                                             FormatElement *anchor) {
  for (FormatElement *element : elements) {
    // A plain parameter must itself be optional.
    if (auto *param = dyn_cast<ParameterElement>(element)) {
      if (param->isOptional())
        continue;
      return emitError(loc,
                       "parameters in an optional group must be optional");
    }

    // `params` and `struct` are fine only when all they capture is optional.
    if (auto *params = dyn_cast<ParamsDirective>(element)) {
      if (llvm::any_of(params->getParams(), paramNotOptional)) {
        return emitError(loc, "`params` directive allowed in optional group "
                              "only if all parameters are optional");
      }
      continue;
    }
    if (auto *strct = dyn_cast<StructDirective>(element)) {
      if (llvm::any_of(strct->getParams(), paramNotOptional)) {
        return emitError(loc, "`struct` is only allowed in an optional group "
                              "if all captured parameters are optional");
      }
      continue;
    }

    // For `custom`, only arguments that are parameters are constrained.
    auto *custom = dyn_cast<CustomDirective>(element);
    if (!custom)
      continue;
    for (FormatElement *argument : custom->getArguments()) {
      auto *param = dyn_cast<ParameterElement>(argument);
      if (param && !param->isOptional())
        return emitError(loc,
                         "`custom` is only allowed in an optional group if "
                         "all captured parameters are optional");
    }
  }

  // Without an anchor there is nothing further to validate.
  if (!anchor)
    return success();

  if (!isa<ParameterElement, ParamsDirective, StructDirective,
           CustomDirective>(anchor)) {
    return emitError(
        loc, "optional group anchor must be a parameter or directive");
  }

  // A `custom` anchor needs at least one parameter among its arguments.
  auto *customAnchor = dyn_cast<CustomDirective>(anchor);
  if (customAnchor &&
      llvm::none_of(customAnchor->getArguments(), [](FormatElement *argument) {
        return isa<ParameterElement>(argument);
      })) {
    return emitError(loc, "`custom` directive with no bound parameters "
                          "cannot be used as optional group anchor");
  }
  return success();
}
/// Flag `element` as qualified. Only parameter elements can be qualified; any
/// other element kind produces an error at `loc`.
LogicalResult DefFormatParser::markQualified(SMLoc loc,
                                             FormatElement *element) {
  auto *param = dyn_cast<ParameterElement>(element);
  if (!param)
    return emitError(loc, "`qualified` argument list expected a variable");
  param->setShouldBeQualified();
  return success();
}
/// Run the generic format parser and, on success, wrap the resulting element
/// list together with the definition into a `DefFormat`.
FailureOr<DefFormat> DefFormatParser::parse() {
  FailureOr<std::vector<FormatElement *>> parsedElements =
      FormatParser::parse();
  if (succeeded(parsedElements))
    return DefFormat(def, std::move(*parsedElements));
  return failure();
}
/// Resolve `name` to a parameter of the definition and create a parameter
/// element for it. Outside of a `ref` context the parameter is bound and may
/// be bound only once. Inside a `ref` context it must already be bound, unless
/// it is a self-type parameter.
FailureOr<FormatElement *>
DefFormatParser::parseVariableImpl(SMLoc loc, StringRef name, Context ctx) {
  for (const auto &candidate : llvm::enumerate(def.getParameters())) {
    const AttrOrTypeParameter &param = candidate.value();
    if (param.getName() != name)
      continue;

    // Found the first parameter with this name; apply the binding rules.
    const bool alreadyBound = seenParams.test(candidate.index());
    if (ctx == RefDirectiveContext) {
      // To be referenced, a variable must have been bound.
      if (!alreadyBound && !isa<AttributeSelfTypeParameter>(param)) {
        return emitError(loc, "parameter '" + name +
                                  "' must be bound before it is referenced");
      }
    } else {
      // A variable can be bound at most once.
      if (alreadyBound)
        return emitError(loc, "duplicate parameter '" + name + "'");
      seenParams.set(candidate.index());
    }
    return create<ParameterElement>(param);
  }

  // No parameter of the definition carries the requested name.
  return emitError(loc,
                   def.getName() + " has no parameter named '" + name + "'");
}
/// Dispatch on the directive keyword. `qualified`, `params`, and `struct` are
/// the only directives handled; any other kind is reported as an error.
FailureOr<FormatElement *>
DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
                                    Context ctx) {
  if (kind == FormatToken::kw_qualified)
    return parseQualifiedDirective(loc, ctx);
  if (kind == FormatToken::kw_params)
    return parseParamsDirective(loc, ctx);
  if (kind == FormatToken::kw_struct)
    return parseStructDirective(loc, ctx);
  return emitError(loc, "unsupported directive kind");
}
/// Parse a `params` directive, which binds every parameter of the definition
/// except self-type parameters. It is accepted only at the top level or as
/// the argument of `struct`, and fails if any parameter was already bound.
FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
                                                                 Context ctx) {
  // Parameters are the only things that can be bound, so capturing all of
  // them inside e.g. a custom directive would not make sense.
  const bool contextAllowed =
      ctx == TopLevelContext || ctx == StructDirectiveContext;
  if (!contextAllowed) {
    return emitError(loc, "`params` can only be used at the top-level context "
                          "or within a `struct` directive");
  }

  std::vector<ParameterElement *> captured;
  for (auto [index, param] : llvm::enumerate(def.getParameters())) {
    // Nothing captured here may have been bound earlier in the format.
    if (seenParams.test(index)) {
      return emitError(loc, "`params` captures duplicate parameter: " +
                                param.getName());
    }
    // Self-type parameters are handled separately and are not captured.
    if (isa<AttributeSelfTypeParameter>(param))
      continue;
    seenParams.set(index);
    captured.push_back(create<ParameterElement>(param));
  }
  return create<ParamsDirective>(std::move(captured));
}
/// Parse a `struct` directive: `struct(` followed by either a single `params`
/// directive or a comma-separated list of variables, then `)`. Accepted only
/// at the top level.
FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
                                                                 Context ctx) {
  if (ctx != TopLevelContext)
    return emitError(loc, "`struct` can only be used at the top-level context");

  if (failed(parseToken(FormatToken::l_paren,
                        "expected '(' before `struct` argument list")))
    return failure();

  // The first argument decides the form: a variable or a `params` directive.
  FailureOr<FormatElement *> first = parseElement(StructDirectiveContext);
  if (failed(first) || !isa<VariableElement, ParamsDirective>(*first)) {
    return emitError(loc,
                     "`struct` argument list expected a variable or directive");
  }

  std::vector<ParameterElement *> captured;
  if (auto *allParams = dyn_cast<ParamsDirective>(*first)) {
    // `struct(params)` captures all parameters in the attribute or type.
    captured = allParams->takeParams();
  } else {
    // Otherwise collect the first variable and every one after a comma.
    captured.push_back(cast<ParameterElement>(*first));
    while (peekToken().is(FormatToken::comma)) {
      consumeToken();
      FailureOr<FormatElement *> next = parseElement(StructDirectiveContext);
      if (failed(next) || !isa<VariableElement>(*next))
        return emitError(loc, "expected a variable in `struct` argument list");
      captured.push_back(cast<ParameterElement>(*next));
    }
  }

  if (failed(parseToken(FormatToken::r_paren,
                        "expected ')' at the end of an argument list")))
    return failure();

  return create<StructDirective>(std::move(captured));
}
//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//
void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def,
MethodBody &parser,
MethodBody &printer) {
llvm::SourceMgr mgr;
mgr.AddNewSourceBuffer(
llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), SMLoc());
// Parse the custom assembly format>
DefFormatParser fmtParser(mgr, def);
FailureOr<DefFormat> format = fmtParser.parse();
if (failed(format)) {
if (formatErrorIsFatal)
PrintFatalError(def.getLoc(), "failed to parse assembly format");
return;
}
// Generate the parser and printer.
format->genParser(parser);
format->genPrinter(printer);
}