| //===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "AttrOrTypeFormatGen.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "mlir/TableGen/AttrOrTypeDef.h" |
| #include "mlir/TableGen/CodeGenHelpers.h" |
| #include "mlir/TableGen/Format.h" |
| #include "mlir/TableGen/GenInfo.h" |
| #include "mlir/TableGen/Interfaces.h" |
| #include "llvm/ADT/Sequence.h" |
| #include "llvm/ADT/SmallSet.h" |
| #include "llvm/ADT/StringSet.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/TableGen/Error.h" |
| #include "llvm/TableGen/TableGenBackend.h" |
| |
| #define DEBUG_TYPE "mlir-tblgen-attrortypedefgen" |
| |
| using namespace mlir; |
| using namespace mlir::tblgen; |
| |
| //===----------------------------------------------------------------------===// |
| // Utility Functions |
| //===----------------------------------------------------------------------===// |
| |
| std::string mlir::tblgen::getParameterAccessorName(StringRef name) { |
| assert(!name.empty() && "parameter has empty name"); |
| auto ret = "get" + name.str(); |
| ret[3] = llvm::toUpper(ret[3]); // uppercase first letter of the name |
| return ret; |
| } |
| |
| /// Find all the AttrOrTypeDef for the specified dialect. If no dialect |
| /// specified and can only find one dialect's defs, use that. |
| static void collectAllDefs(StringRef selectedDialect, |
| std::vector<llvm::Record *> records, |
| SmallVectorImpl<AttrOrTypeDef> &resultDefs) { |
| auto defs = llvm::map_range( |
| records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); }); |
| if (defs.empty()) |
| return; |
| |
| StringRef dialectName; |
| if (selectedDialect.empty()) { |
| if (defs.empty()) |
| return; |
| |
| Dialect dialect(nullptr); |
| for (const AttrOrTypeDef &typeDef : defs) { |
| if (!dialect) { |
| dialect = typeDef.getDialect(); |
| } else if (dialect != typeDef.getDialect()) { |
| llvm::PrintFatalError("defs belonging to more than one dialect. Must " |
| "select one via '--(attr|type)defs-dialect'"); |
| } |
| } |
| |
| dialectName = dialect.getName(); |
| } else { |
| dialectName = selectedDialect; |
| } |
| |
| for (const AttrOrTypeDef &def : defs) |
| if (def.getDialect().getName().equals(dialectName)) |
| resultDefs.push_back(def); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ParamCommaFormatter |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| /// Pass an instance of this class to llvm::formatv() to emit a comma separated |
| /// list of parameters in the format by 'EmitFormat'. |
| class ParamCommaFormatter : public llvm::detail::format_adapter { |
| public: |
| /// Choose the output format |
| enum EmitFormat { |
| /// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name, |
| /// [...]". |
| TypeNamePairs, |
| |
| /// Emit "parameter1(parameter1), parameter2(parameter2), [...]". |
| TypeNameInitializer, |
| |
| /// Emit "param1Name, param2Name, [...]". |
| JustParams, |
| }; |
| |
| ParamCommaFormatter(EmitFormat emitFormat, |
| ArrayRef<AttrOrTypeParameter> params, |
| bool prependComma = true) |
| : emitFormat(emitFormat), params(params), prependComma(prependComma) {} |
| |
| /// llvm::formatv will call this function when using an instance as a |
| /// replacement value. |
| void format(raw_ostream &os, StringRef options) override { |
| if (!params.empty() && prependComma) |
| os << ", "; |
| |
| switch (emitFormat) { |
| case EmitFormat::TypeNamePairs: |
| interleaveComma(params, os, [&](const AttrOrTypeParameter &p) { |
| emitTypeNamePair(p, os); |
| }); |
| break; |
| case EmitFormat::TypeNameInitializer: |
| interleaveComma(params, os, [&](const AttrOrTypeParameter &p) { |
| emitTypeNameInitializer(p, os); |
| }); |
| break; |
| case EmitFormat::JustParams: |
| interleaveComma(params, os, |
| [&](const AttrOrTypeParameter &p) { os << p.getName(); }); |
| break; |
| } |
| } |
| |
| private: |
| // Emit "paramType paramName". |
| static void emitTypeNamePair(const AttrOrTypeParameter ¶m, |
| raw_ostream &os) { |
| os << param.getCppType() << " " << param.getName(); |
| } |
| // Emit "paramName(paramName)" |
| void emitTypeNameInitializer(const AttrOrTypeParameter ¶m, |
| raw_ostream &os) { |
| os << param.getName() << "(" << param.getName() << ")"; |
| } |
| |
| EmitFormat emitFormat; |
| ArrayRef<AttrOrTypeParameter> params; |
| bool prependComma; |
| }; |
| |
| } // end anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // DefGenerator |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| /// This struct is the base generator used when processing tablegen interfaces. |
| class DefGenerator { |
| public: |
| bool emitDecls(StringRef selectedDialect); |
| bool emitDefs(StringRef selectedDialect); |
| |
| protected: |
| DefGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os) |
| : defRecords(std::move(defs)), os(os), isAttrGenerator(false) {} |
| |
| /// Emit the declaration of a single def. |
| void emitDefDecl(const AttrOrTypeDef &def); |
| /// Emit the list of def type names. |
| void emitTypeDefList(ArrayRef<AttrOrTypeDef> defs); |
| /// Emit the code to dispatch between different defs during parsing/printing. |
| void emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs); |
| /// Emit the definition of a single def. |
| void emitDefDef(const AttrOrTypeDef &def); |
| /// Emit the storage class for the given def. |
| void emitStorageClass(const AttrOrTypeDef &def); |
| /// Emit the parser/printer for the given def. |
| void emitParsePrint(const AttrOrTypeDef &def); |
| |
| /// The set of def records to emit. |
| std::vector<llvm::Record *> defRecords; |
| /// The stream to emit to. |
| raw_ostream &os; |
| /// The prefix of the tablegen def name, e.g. Attr or Type. |
| StringRef defTypePrefix; |
| /// The C++ base value type of the def, e.g. Attribute or Type. |
| StringRef valueType; |
| /// Flag indicating if this generator is for Attributes. False if the |
| /// generator is for types. |
| bool isAttrGenerator; |
| }; |
| |
| /// A specialized generator for AttrDefs. |
| struct AttrDefGenerator : public DefGenerator { |
| AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) |
| : DefGenerator(records.getAllDerivedDefinitions("AttrDef"), os) { |
| isAttrGenerator = true; |
| defTypePrefix = "Attr"; |
| valueType = "Attribute"; |
| } |
| }; |
| /// A specialized generator for TypeDefs. |
| struct TypeDefGenerator : public DefGenerator { |
| TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) |
| : DefGenerator(records.getAllDerivedDefinitions("TypeDef"), os) { |
| defTypePrefix = "Type"; |
| valueType = "Type"; |
| } |
| }; |
| } // end anonymous namespace |
| |
| //===----------------------------------------------------------------------===// |
| // GEN: Declarations |
| //===----------------------------------------------------------------------===// |
| |
| /// Print this above all the other declarations. Contains type declarations used |
| /// later on. |
| static const char *const typeDefDeclHeader = R"( |
| namespace mlir { |
| class AsmParser; |
| class DialectAsmParser; |
| class AsmPrinter; |
| class DialectAsmPrinter; |
| } // namespace mlir |
| )"; |
| |
| /// The code block for the start of a typeDef class declaration -- singleton |
| /// case. |
| /// |
| /// {0}: The name of the def class. |
| /// {1}: The name of the type base class. |
| /// {2}: The name of the base value type, e.g. Attribute or Type. |
| /// {3}: The tablegen record type prefix, e.g. Attr or Type. |
| /// {4}: The traits of the def class. |
| static const char *const defDeclSingletonBeginStr = R"( |
| class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage{4}> {{ |
| public: |
| /// Inherit some necessary constructors from '{3}Base'. |
| using Base::Base; |
| )"; |
| |
| /// The code block for the start of a class declaration -- parametric case. |
| /// |
| /// {0}: The name of the def class. |
| /// {1}: The name of the base class. |
| /// {2}: The def storage class namespace. |
| /// {3}: The storage class name. |
| /// {4}: The name of the base value type, e.g. Attribute or Type. |
| /// {5}: The tablegen record type prefix, e.g. Attr or Type. |
| /// {6}: The traits of the def class. |
| static const char *const defDeclParametricBeginStr = R"( |
| namespace {2} { |
| struct {3}; |
| } // end namespace {2} |
| class {0} : public ::mlir::{4}::{5}Base<{0}, {1}, |
| {2}::{3}{6}> {{ |
| public: |
| /// Inherit some necessary constructors from '{5}Base'. |
| using Base::Base; |
| |
| )"; |
| |
| /// The code snippet for print/parse of an Attribute/Type. |
| /// |
| /// {0}: The name of the base value type, e.g. Attribute or Type. |
| /// {1}: Extra parser parameters. |
| static const char *const defDeclParsePrintStr = R"( |
| static ::mlir::{0} parse(::mlir::AsmParser &parser{1}); |
| void print(::mlir::AsmPrinter &printer) const; |
| )"; |
| |
| /// The code block for the verify method declaration. |
| /// |
| /// {0}: List of parameters, parameters style. |
| static const char *const defDeclVerifyStr = R"( |
| using Base::getChecked; |
| static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError{0}); |
| )"; |
| |
| /// Emit the builders for the given def. |
| static void emitBuilderDecls(const AttrOrTypeDef &def, raw_ostream &os, |
| ParamCommaFormatter ¶mTypes) { |
| StringRef typeClass = def.getCppClassName(); |
| bool genCheckedMethods = def.genVerifyDecl(); |
| if (!def.skipDefaultBuilders()) { |
| os << llvm::formatv( |
| " static {0} get(::mlir::MLIRContext *context{1});\n", typeClass, |
| paramTypes); |
| if (genCheckedMethods) { |
| os << llvm::formatv(" static {0} " |
| "getChecked(llvm::function_ref<::mlir::" |
| "InFlightDiagnostic()> emitError, " |
| "::mlir::MLIRContext *context{1});\n", |
| typeClass, paramTypes); |
| } |
| } |
| |
| // Generate the builders specified by the user. |
| for (const AttrOrTypeBuilder &builder : def.getBuilders()) { |
| std::string paramStr; |
| llvm::raw_string_ostream paramOS(paramStr); |
| llvm::interleaveComma( |
| builder.getParameters(), paramOS, |
| [&](const AttrOrTypeBuilder::Parameter ¶m) { |
| // Note: AttrOrTypeBuilder parameters are guaranteed to have names. |
| paramOS << param.getCppType() << " " << *param.getName(); |
| if (Optional<StringRef> defaultParamValue = param.getDefaultValue()) |
| paramOS << " = " << *defaultParamValue; |
| }); |
| paramOS.flush(); |
| |
| // Generate the `get` variant of the builder. |
| os << " static " << typeClass << " get("; |
| if (!builder.hasInferredContextParameter()) { |
| os << "::mlir::MLIRContext *context"; |
| if (!paramStr.empty()) |
| os << ", "; |
| } |
| os << paramStr << ");\n"; |
| |
| // Generate the `getChecked` variant of the builder. |
| if (genCheckedMethods) { |
| os << " static " << typeClass |
| << " getChecked(llvm::function_ref<mlir::InFlightDiagnostic()> " |
| "emitError"; |
| if (!builder.hasInferredContextParameter()) |
| os << ", ::mlir::MLIRContext *context"; |
| if (!paramStr.empty()) |
| os << ", "; |
| os << paramStr << ");\n"; |
| } |
| } |
| } |
| |
| static void emitInterfaceMethodDecls(const InterfaceTrait *trait, |
| raw_ostream &os) { |
| Interface interface = trait->getInterface(); |
| |
| // Get the set of methods that should always be declared. |
| auto alwaysDeclaredMethodsVec = trait->getAlwaysDeclaredMethods(); |
| llvm::StringSet<> alwaysDeclaredMethods; |
| alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(), |
| alwaysDeclaredMethodsVec.end()); |
| |
| for (const InterfaceMethod &method : interface.getMethods()) { |
| // Don't declare if the method has a body. |
| if (method.getBody()) |
| continue; |
| // Don't declare if the method has a default implementation and the def |
| // didn't request that it always be declared. |
| if (method.getDefaultImplementation() && |
| !alwaysDeclaredMethods.count(method.getName())) |
| continue; |
| |
| // Emit the method declaration. |
| os << " " << (method.isStatic() ? "static " : "") |
| << method.getReturnType() << " " << method.getName() << "("; |
| llvm::interleaveComma(method.getArguments(), os, |
| [&](const InterfaceMethod::Argument &arg) { |
| os << arg.type << " " << arg.name; |
| }); |
| os << ")" << (method.isStatic() ? "" : " const") << ";\n"; |
| } |
| } |
| |
| void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) { |
| SmallVector<AttrOrTypeParameter, 4> params; |
| def.getParameters(params); |
| |
| // Build the trait list for this def. |
| std::vector<std::string> traitList; |
| StringSet<> traitSet; |
| for (const Trait &baseTrait : def.getTraits()) { |
| std::string traitStr; |
| if (const auto *trait = dyn_cast<NativeTrait>(&baseTrait)) |
| traitStr = trait->getFullyQualifiedTraitName(); |
| else if (const auto *trait = dyn_cast<InterfaceTrait>(&baseTrait)) |
| traitStr = trait->getFullyQualifiedTraitName(); |
| else |
| llvm_unreachable("unexpected Attribute/Type trait type"); |
| |
| if (traitSet.insert(traitStr).second) |
| traitList.emplace_back(std::move(traitStr)); |
| } |
| std::string traitStr; |
| if (!traitList.empty()) |
| traitStr = ", " + llvm::join(traitList, ", "); |
| |
| // Emit the beginning string template: either the singleton or parametric |
| // template. |
| if (def.getNumParameters() == 0) { |
| os << formatv(defDeclSingletonBeginStr, def.getCppClassName(), |
| def.getCppBaseClassName(), valueType, defTypePrefix, |
| traitStr); |
| } else { |
| os << formatv(defDeclParametricBeginStr, def.getCppClassName(), |
| def.getCppBaseClassName(), def.getStorageNamespace(), |
| def.getStorageClassName(), valueType, defTypePrefix, |
| traitStr); |
| } |
| |
| // Emit the extra declarations first in case there's a definition in there. |
| if (Optional<StringRef> extraDecl = def.getExtraDecls()) |
| os << *extraDecl << "\n"; |
| |
| ParamCommaFormatter emitTypeNamePairsAfterComma( |
| ParamCommaFormatter::EmitFormat::TypeNamePairs, params); |
| if (!params.empty()) { |
| emitBuilderDecls(def, os, emitTypeNamePairsAfterComma); |
| |
| // Emit the verify invariants declaration. |
| if (def.genVerifyDecl()) |
| os << llvm::formatv(defDeclVerifyStr, emitTypeNamePairsAfterComma); |
| } |
| |
| // Emit the mnenomic, if specified. |
| if (auto mnenomic = def.getMnemonic()) { |
| os << " static constexpr ::llvm::StringLiteral getMnemonic() {\n" |
| << " return ::llvm::StringLiteral(\"" << mnenomic << "\");\n" |
| << " }\n"; |
| |
| // If mnemonic specified, emit print/parse declarations. |
| if (def.getParserCode() || def.getPrinterCode() || |
| def.getAssemblyFormat() || !params.empty()) { |
| os << llvm::formatv(defDeclParsePrintStr, valueType, |
| isAttrGenerator ? ", ::mlir::Type type" : ""); |
| } |
| } |
| |
| if (def.genAccessors()) { |
| SmallVector<AttrOrTypeParameter, 4> parameters; |
| def.getParameters(parameters); |
| |
| for (AttrOrTypeParameter ¶meter : parameters) { |
| os << formatv(" {0} {1}() const;\n", parameter.getCppAccessorType(), |
| getParameterAccessorName(parameter.getName())); |
| } |
| } |
| |
| // Emit any interface method declarations. |
| for (const Trait &trait : def.getTraits()) { |
| if (const auto *traitDef = dyn_cast<InterfaceTrait>(&trait)) { |
| if (traitDef->shouldDeclareMethods()) |
| emitInterfaceMethodDecls(traitDef, os); |
| } |
| } |
| |
| // End the decl. |
| os << " };\n"; |
| } |
| |
| bool DefGenerator::emitDecls(StringRef selectedDialect) { |
| emitSourceFileHeader((defTypePrefix + "Def Declarations").str(), os); |
| IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os); |
| |
| // Output the common "header". |
| os << typeDefDeclHeader; |
| |
| SmallVector<AttrOrTypeDef, 16> defs; |
| collectAllDefs(selectedDialect, defRecords, defs); |
| if (defs.empty()) |
| return false; |
| { |
| NamespaceEmitter nsEmitter(os, defs.front().getDialect()); |
| |
| // Declare all the def classes first (in case they reference each other). |
| for (const AttrOrTypeDef &def : defs) |
| os << " class " << def.getCppClassName() << ";\n"; |
| |
| // Emit the declarations. |
| for (const AttrOrTypeDef &def : defs) |
| emitDefDecl(def); |
| } |
| // Emit the TypeID explicit specializations to have a single definition for |
| // each of these. |
| for (const AttrOrTypeDef &def : defs) |
| if (!def.getDialect().getCppNamespace().empty()) |
| os << "DECLARE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace() |
| << "::" << def.getCppClassName() << ")\n"; |
| |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // GEN: Def List |
| //===----------------------------------------------------------------------===// |
| |
| void DefGenerator::emitTypeDefList(ArrayRef<AttrOrTypeDef> defs) { |
| IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_LIST", os); |
| auto interleaveFn = [&](const AttrOrTypeDef &def) { |
| os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName(); |
| }; |
| llvm::interleave(defs, os, interleaveFn, ",\n"); |
| os << "\n"; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // GEN: Definitions |
| //===----------------------------------------------------------------------===// |
| |
| /// The code block used to start the auto-generated parser function. |
| /// |
| /// {0}: The name of the base value type, e.g. Attribute or Type. |
| /// {1}: Additional parser parameters. |
| static const char *const defParserDispatchStartStr = R"( |
| static ::mlir::OptionalParseResult generated{0}Parser( |
| ::mlir::AsmParser &parser, |
| ::llvm::StringRef mnemonic{1}, |
| ::mlir::{0} &value) {{ |
| )"; |
| |
| /// The code block for default attribute parser/printer dispatch boilerplate. |
| /// {0}: the dialect fully qualified class name. |
| static const char *const dialectDefaultAttrPrinterParserDispatch = R"( |
| /// Parse an attribute registered to this dialect. |
| ::mlir::Attribute {0}::parseAttribute(::mlir::DialectAsmParser &parser, |
| ::mlir::Type type) const {{ |
| ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); |
| ::llvm::StringRef attrTag; |
| if (::mlir::failed(parser.parseKeyword(&attrTag))) |
| return {{}; |
| {{ |
| ::mlir::Attribute attr; |
| auto parseResult = generatedAttributeParser(parser, attrTag, type, attr); |
| if (parseResult.hasValue()) |
| return attr; |
| } |
| parser.emitError(typeLoc) << "unknown attribute `" |
| << attrTag << "` in dialect `" << getNamespace() << "`"; |
| return {{}; |
| } |
| /// Print an attribute registered to this dialect. |
| void {0}::printAttribute(::mlir::Attribute attr, |
| ::mlir::DialectAsmPrinter &printer) const {{ |
| if (::mlir::succeeded(generatedAttributePrinter(attr, printer))) |
| return; |
| } |
| )"; |
| |
| /// The code block for default type parser/printer dispatch boilerplate. |
| /// {0}: the dialect fully qualified class name. |
| static const char *const dialectDefaultTypePrinterParserDispatch = R"( |
| /// Parse a type registered to this dialect. |
| ::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{ |
| ::llvm::SMLoc typeLoc = parser.getCurrentLocation(); |
| ::llvm::StringRef mnemonic; |
| if (parser.parseKeyword(&mnemonic)) |
| return ::mlir::Type(); |
| ::mlir::Type genType; |
| auto parseResult = generatedTypeParser(parser, mnemonic, genType); |
| if (parseResult.hasValue()) |
| return genType; |
| parser.emitError(typeLoc) << "unknown type `" |
| << mnemonic << "` in dialect `" << getNamespace() << "`"; |
| return {{}; |
| } |
| /// Print a type registered to this dialect. |
| void {0}::printType(::mlir::Type type, |
| ::mlir::DialectAsmPrinter &printer) const {{ |
| if (::mlir::succeeded(generatedTypePrinter(type, printer))) |
| return; |
| } |
| )"; |
| |
| /// The code block used to start the auto-generated printer function. |
| /// |
| /// {0}: The name of the base value type, e.g. Attribute or Type. |
| static const char *const defPrinterDispatchStartStr = R"( |
| static ::mlir::LogicalResult generated{0}Printer( |
| ::mlir::{0} def, ::mlir::AsmPrinter &printer) {{ |
| return ::llvm::TypeSwitch<::mlir::{0}, ::mlir::LogicalResult>(def) |
| )"; |
| |
| /// Beginning of storage class. |
| /// {0}: Storage class namespace. |
| /// {1}: Storage class c++ name. |
| /// {2}: Parameters parameters. |
| /// {3}: Parameter initializer string. |
| /// {4}: Parameter types. |
| /// {5}: The name of the base value type, e.g. Attribute or Type. |
| static const char *const defStorageClassBeginStr = R"( |
| namespace {0} {{ |
| struct {1} : public ::mlir::{5}Storage {{ |
| {1} ({2}) |
| : {3} {{ } |
| |
| /// The hash key is a tuple of the parameter types. |
| using KeyTy = std::tuple<{4}>; |
| )"; |
| |
| /// The storage class' constructor template. |
| /// |
| /// {0}: storage class name. |
| /// {1}: The name of the base value type, e.g. Attribute or Type. |
| static const char *const defStorageClassConstructorBeginStr = R"( |
| /// Define a construction method for creating a new instance of this |
| /// storage. |
| static {0} *construct(::mlir::{1}StorageAllocator &allocator, |
| const KeyTy &tblgenKey) {{ |
| )"; |
| |
| /// The storage class' constructor return template. |
| /// |
| /// {0}: storage class name. |
| /// {1}: list of parameters. |
| static const char *const defStorageClassConstructorEndStr = R"( |
| return new (allocator.allocate<{0}>()) |
| {0}({1}); |
| } |
| )"; |
| |
| /// Use tgfmt to emit custom allocation code for each parameter, if necessary. |
| static void emitStorageParameterAllocation(const AttrOrTypeDef &def, |
| raw_ostream &os) { |
| SmallVector<AttrOrTypeParameter> parameters; |
| def.getParameters(parameters); |
| FmtContext fmtCtxt = FmtContext().addSubst("_allocator", "allocator"); |
| for (AttrOrTypeParameter ¶meter : parameters) { |
| if (Optional<StringRef> allocCode = parameter.getAllocator()) { |
| fmtCtxt.withSelf(parameter.getName()); |
| fmtCtxt.addSubst("_dst", parameter.getName()); |
| os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n"; |
| } |
| } |
| } |
| |
| /// Builds a code block that initializes the attribute storage of 'def'. |
| /// Attribute initialization is separated from Type initialization given that |
| /// the Attribute also needs to initialize its self-type, which has multiple |
| /// means of initialization. |
| static std::string buildAttributeStorageParamInitializer( |
| const AttrOrTypeDef &def, ArrayRef<AttrOrTypeParameter> parameters) { |
| std::string paramInitializer; |
| llvm::raw_string_ostream paramOS(paramInitializer); |
| paramOS << "::mlir::AttributeStorage("; |
| |
| // If this is an attribute, we need to check for value type initialization. |
| Optional<size_t> selfParamIndex; |
| for (auto it : llvm::enumerate(parameters)) { |
| const auto *selfParam = dyn_cast<AttributeSelfTypeParameter>(&it.value()); |
| if (!selfParam) |
| continue; |
| if (selfParamIndex) { |
| llvm::PrintFatalError(def.getLoc(), |
| "Only one attribute parameter can be marked as " |
| "AttributeSelfTypeParameter"); |
| } |
| paramOS << selfParam->getName(); |
| selfParamIndex = it.index(); |
| } |
| |
| // If we didn't find a self param, but the def has a type builder we use that |
| // to construct the type. |
| if (!selfParamIndex) { |
| const AttrDef &attrDef = cast<AttrDef>(def); |
| if (Optional<StringRef> typeBuilder = attrDef.getTypeBuilder()) { |
| FmtContext fmtContext; |
| for (const AttrOrTypeParameter ¶m : parameters) |
| fmtContext.addSubst(("_" + param.getName()).str(), param.getName()); |
| paramOS << tgfmt(*typeBuilder, &fmtContext); |
| } |
| } |
| paramOS << ")"; |
| |
| // Append the parameters to the initializer. |
| for (auto it : llvm::enumerate(parameters)) |
| if (it.index() != selfParamIndex) |
| paramOS << llvm::formatv(", {0}({0})", it.value().getName()); |
| |
| return paramOS.str(); |
| } |
| |
| void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) { |
| SmallVector<AttrOrTypeParameter, 4> params; |
| def.getParameters(params); |
| |
| // Collect the parameter types. |
| auto parameterTypes = |
| llvm::map_range(params, [](const AttrOrTypeParameter ¶meter) { |
| return parameter.getCppType(); |
| }); |
| std::string parameterTypeList = llvm::join(parameterTypes, ", "); |
| |
| // Collect the parameter initializer. |
| std::string paramInitializer; |
| if (isAttrGenerator) { |
| paramInitializer = buildAttributeStorageParamInitializer(def, params); |
| |
| } else { |
| llvm::raw_string_ostream initOS(paramInitializer); |
| llvm::interleaveComma(params, initOS, [&](const AttrOrTypeParameter &it) { |
| initOS << llvm::formatv("{0}({0})", it.getName()); |
| }); |
| } |
| |
| // * Emit most of the storage class up until the hashKey body. |
| os << formatv( |
| defStorageClassBeginStr, def.getStorageNamespace(), |
| def.getStorageClassName(), |
| ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs, |
| params, /*prependComma=*/false), |
| paramInitializer, parameterTypeList, valueType); |
| |
| // * Emit the comparison method. |
| os << " bool operator==(const KeyTy &tblgenKey) const {\n"; |
| for (auto it : llvm::enumerate(params)) { |
| os << " if (!("; |
| |
| // Build the comparator context. |
| bool isSelfType = isa<AttributeSelfTypeParameter>(it.value()); |
| FmtContext context; |
| context.addSubst("_lhs", isSelfType ? "getType()" : it.value().getName()) |
| .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(tblgenKey)"); |
| |
| // Use the parameter specified comparator if possible, otherwise default to |
| // operator==. |
| Optional<StringRef> comparator = it.value().getComparator(); |
| os << tgfmt(comparator ? *comparator : "$_lhs == $_rhs", &context); |
| os << "))\n return false;\n"; |
| } |
| os << " return true;\n }\n"; |
| |
| // * Emit the haskKey method. |
| os << " static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {\n"; |
| |
| // Extract each parameter from the key. |
| os << " return ::llvm::hash_combine("; |
| llvm::interleaveComma( |
| llvm::seq<unsigned>(0, params.size()), os, |
| [&](unsigned it) { os << "std::get<" << it << ">(tblgenKey)"; }); |
| os << ");\n }\n"; |
| |
| // * Emit the construct method. |
| |
| // If user wants to build the storage constructor themselves, declare it |
| // here and then they can write the definition elsewhere. |
| if (def.hasStorageCustomConstructor()) { |
| os << llvm::formatv(" static {0} *construct(::mlir::{1}StorageAllocator " |
| "&allocator, const KeyTy &tblgenKey);\n", |
| def.getStorageClassName(), valueType); |
| |
| // Otherwise, generate one. |
| } else { |
| // First, unbox the parameters. |
| os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(), |
| valueType); |
| for (unsigned i = 0, e = params.size(); i < e; ++i) { |
| os << formatv(" auto {0} = std::get<{1}>(tblgenKey);\n", |
| params[i].getName(), i); |
| } |
| |
| // Second, reassign the parameter variables with allocation code, if it's |
| // specified. |
| emitStorageParameterAllocation(def, os); |
| |
| // Last, return an allocated copy. |
| auto parameterNames = llvm::map_range( |
| params, [](const auto ¶m) { return param.getName(); }); |
| os << formatv(defStorageClassConstructorEndStr, def.getStorageClassName(), |
| llvm::join(parameterNames, ", ")); |
| } |
| |
| // * Emit the parameters as storage class members. |
| for (const AttrOrTypeParameter ¶meter : params) { |
| // Attribute value types are not stored as fields in the storage. |
| if (!isa<AttributeSelfTypeParameter>(parameter)) |
| os << " " << parameter.getCppType() << " " << parameter.getName() |
| << ";\n"; |
| } |
| os << " };\n"; |
| |
| os << "} // namespace " << def.getStorageNamespace() << "\n"; |
| } |
| |
| void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) { |
| auto printerCode = def.getPrinterCode(); |
| auto parserCode = def.getParserCode(); |
| auto assemblyFormat = def.getAssemblyFormat(); |
| if (assemblyFormat && (printerCode || parserCode)) { |
| // Custom assembly format cannot be specified at the same time as either |
| // custom printer or parser code. |
| PrintFatalError(def.getLoc(), |
| def.getName() + ": assembly format cannot be specified at " |
| "the same time as printer or parser code"); |
| } |
| |
| // Generate a parser and printer based on the assembly format, if specified. |
| if (assemblyFormat) { |
| // A custom assembly format requires accessors to be generated for the |
| // generated printer. |
| if (!def.genAccessors()) { |
| PrintFatalError(def.getLoc(), |
| def.getName() + |
| ": the generated printer from 'assemblyFormat' " |
| "requires 'genAccessors' to be true"); |
| } |
| return generateAttrOrTypeFormat(def, os); |
| } |
| |
| // Emit the printer code, if specified. |
| if (printerCode) { |
| // Both the mnenomic and printerCode must be defined (for parity with |
| // parserCode). |
| os << "void " << def.getCppClassName() |
| << "::print(::mlir::AsmPrinter &printer) const {\n"; |
| if (printerCode->empty()) { |
| // If no code specified, emit error. |
| PrintFatalError(def.getLoc(), |
| def.getName() + |
| ": printer (if specified) must have non-empty code"); |
| } |
| FmtContext fmtCtxt = FmtContext().addSubst("_printer", "printer"); |
| os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n"; |
| } |
| |
| // Emit the parser code, if specified. |
| if (parserCode) { |
| FmtContext fmtCtxt; |
| fmtCtxt.addSubst("_parser", "parser") |
| .addSubst("_ctxt", "parser.getContext()"); |
| |
| // The mnenomic must be defined so the dispatcher knows how to dispatch. |
| os << llvm::formatv("::mlir::{0} {1}::parse(" |
| "::mlir::AsmParser &parser", |
| valueType, def.getCppClassName()); |
| if (isAttrGenerator) { |
| // Attributes also accept a type parameter instead of a context. |
| os << ", ::mlir::Type type"; |
| fmtCtxt.addSubst("_type", "type"); |
| } |
| os << ") {\n"; |
| |
| if (parserCode->empty()) { |
| PrintFatalError(def.getLoc(), |
| def.getName() + |
| ": parser (if specified) must have non-empty code"); |
| } |
| os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n"; |
| } |
| } |
| |
| /// Replace all instances of 'from' to 'to' in `str` and return the new string. |
| static std::string replaceInStr(std::string str, StringRef from, StringRef to) { |
| size_t pos = 0; |
| while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos) |
| str.replace(pos, from.size(), to.data(), to.size()); |
| return str; |
| } |
| |
| /// Emit the builders for the given def. |
| static void emitBuilderDefs(const AttrOrTypeDef &def, raw_ostream &os, |
| ArrayRef<AttrOrTypeParameter> params) { |
| bool genCheckedMethods = def.genVerifyDecl(); |
| StringRef className = def.getCppClassName(); |
| if (!def.skipDefaultBuilders()) { |
| os << llvm::formatv( |
| "{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n" |
| " return Base::get(context{2});\n}\n", |
| className, |
| ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs, |
| params), |
| ParamCommaFormatter(ParamCommaFormatter::EmitFormat::JustParams, |
| params)); |
| if (genCheckedMethods) { |
| os << llvm::formatv( |
| "{0} {0}::getChecked(" |
| "llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, " |
| "::mlir::MLIRContext *context{1}) {{\n" |
| " return Base::getChecked(emitError, context{2});\n}\n", |
| className, |
| ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs, |
| params), |
| ParamCommaFormatter(ParamCommaFormatter::EmitFormat::JustParams, |
| params)); |
| } |
| } |
| |
| auto builderFmtCtx = |
| FmtContext().addSubst("_ctxt", "context").addSubst("_get", "Base::get"); |
| auto inferredCtxBuilderFmtCtx = FmtContext().addSubst("_get", "Base::get"); |
| auto checkedBuilderFmtCtx = FmtContext().addSubst("_ctxt", "context"); |
| |
| // Generate the builders specified by the user. |
| for (const AttrOrTypeBuilder &builder : def.getBuilders()) { |
| Optional<StringRef> body = builder.getBody(); |
| if (!body) |
| continue; |
| std::string paramStr; |
| llvm::raw_string_ostream paramOS(paramStr); |
| llvm::interleaveComma(builder.getParameters(), paramOS, |
| [&](const AttrOrTypeBuilder::Parameter ¶m) { |
| // Note: AttrOrTypeBuilder parameters are guaranteed |
| // to have names. |
| paramOS << param.getCppType() << " " |
| << *param.getName(); |
| }); |
| paramOS.flush(); |
| |
| // Emit the `get` variant of the builder. |
| os << llvm::formatv("{0} {0}::get(", className); |
| if (!builder.hasInferredContextParameter()) { |
| os << "::mlir::MLIRContext *context"; |
| if (!paramStr.empty()) |
| os << ", "; |
| os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, |
| tgfmt(*body, &builderFmtCtx).str()); |
| } else { |
| os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, |
| tgfmt(*body, &inferredCtxBuilderFmtCtx).str()); |
| } |
| |
| // Emit the `getChecked` variant of the builder. |
| if (genCheckedMethods) { |
| os << llvm::formatv("{0} " |
| "{0}::getChecked(llvm::function_ref<::mlir::" |
| "InFlightDiagnostic()> emitErrorFn", |
| className); |
| std::string checkedBody = |
| replaceInStr(body->str(), "$_get(", "Base::getChecked(emitErrorFn, "); |
| if (!builder.hasInferredContextParameter()) { |
| os << ", ::mlir::MLIRContext *context"; |
| checkedBody = tgfmt(checkedBody, &checkedBuilderFmtCtx).str(); |
| } |
| if (!paramStr.empty()) |
| os << ", "; |
| os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, checkedBody); |
| } |
| } |
| } |
| |
| /// Print all the def-specific definition code. |
| void DefGenerator::emitDefDef(const AttrOrTypeDef &def) { |
| NamespaceEmitter ns(os, def.getDialect()); |
| |
| SmallVector<AttrOrTypeParameter, 4> parameters; |
| def.getParameters(parameters); |
| if (!parameters.empty()) { |
| // Emit the storage class, if requested and necessary. |
| if (def.genStorageClass()) |
| emitStorageClass(def); |
| |
| // Emit the builders for this def. |
| emitBuilderDefs(def, os, parameters); |
| |
| // Generate accessor definitions only if we also generate the storage class. |
| // Otherwise, let the user define the exact accessor definition. |
| if (def.genAccessors() && def.genStorageClass()) { |
| for (const AttrOrTypeParameter ¶m : parameters) { |
| SmallString<32> paramStorageName; |
| if (isa<AttributeSelfTypeParameter>(param)) { |
| Twine("getType().cast<" + param.getCppType() + ">()") |
| .toVector(paramStorageName); |
| } else { |
| paramStorageName = param.getName(); |
| } |
| |
| os << formatv("{0} {3}::{1}() const {{ return getImpl()->{2}; }\n", |
| param.getCppAccessorType(), |
| getParameterAccessorName(param.getName()), |
| paramStorageName, def.getCppClassName()); |
| } |
| } |
| } |
| |
| // If mnemonic is specified maybe print definitions for the parser and printer |
| // code, if they're specified. |
| if (def.getMnemonic()) |
| emitParsePrint(def); |
| } |
| |
| /// Emit the dialect printer/parser dispatcher. User's code should call these |
| /// functions from their dialect's print/parse methods. |
| void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) { |
| if (llvm::none_of(defs, [](const AttrOrTypeDef &def) { |
| return def.getMnemonic().hasValue(); |
| })) { |
| return; |
| } |
| |
| // The parser dispatch is just a list of if-elses, matching on the mnemonic |
| // and calling the def's parse function. |
| os << llvm::formatv(defParserDispatchStartStr, valueType, |
| isAttrGenerator ? ", ::mlir::Type type" : ""); |
| for (const AttrOrTypeDef &def : defs) { |
| if (def.getMnemonic()) { |
| os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) { \n" |
| " value = {0}::{1}::", |
| def.getDialect().getCppNamespace(), def.getCppClassName()); |
| |
| // If the def has no parameters and no parser code, just invoke a normal |
| // `get`. |
| if (def.getNumParameters() == 0 && !def.getParserCode()) { |
| os << "get(parser.getContext());\n"; |
| os << " return ::mlir::success(!!value);\n }\n"; |
| continue; |
| } |
| |
| os << "parse(parser" << (isAttrGenerator ? ", type" : "") |
| << ");\n return ::mlir::success(!!value);\n }\n"; |
| } |
| } |
| os << " return {};\n"; |
| os << "}\n\n"; |
| |
| // The printer dispatch uses llvm::TypeSwitch to find and call the correct |
| // printer. |
| os << llvm::formatv(defPrinterDispatchStartStr, valueType); |
| for (const AttrOrTypeDef &def : defs) { |
| Optional<StringRef> mnemonic = def.getMnemonic(); |
| if (!mnemonic) |
| continue; |
| |
| StringRef cppNamespace = def.getDialect().getCppNamespace(); |
| StringRef cppClassName = def.getCppClassName(); |
| os << formatv(" .Case<{0}::{1}>([&]({0}::{1} t) {{\n ", |
| cppNamespace, cppClassName); |
| |
| os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace, |
| cppClassName); |
| // If the def has no parameters and no printer, just print the mnemonic. |
| if (def.getNumParameters() != 0 || def.getPrinterCode()) |
| os << "t.print(printer);"; |
| os << "\n return ::mlir::success();\n })\n"; |
| } |
| os << llvm::formatv( |
| " .Default([](::mlir::{0}) {{ return ::mlir::failure(); });\n}\n\n", |
| valueType); |
| } |
| |
| bool DefGenerator::emitDefs(StringRef selectedDialect) { |
| emitSourceFileHeader((defTypePrefix + "Def Definitions").str(), os); |
| |
| SmallVector<AttrOrTypeDef, 16> defs; |
| collectAllDefs(selectedDialect, defRecords, defs); |
| if (defs.empty()) |
| return false; |
| emitTypeDefList(defs); |
| |
| IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os); |
| emitParsePrintDispatch(defs); |
| for (const AttrOrTypeDef &def : defs) { |
| emitDefDef(def); |
| // Emit the TypeID explicit specializations to have a single symbol def. |
| if (!def.getDialect().getCppNamespace().empty()) |
| os << "DEFINE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace() |
| << "::" << def.getCppClassName() << ")\n"; |
| } |
| |
| Dialect firstDialect = defs.front().getDialect(); |
| // Emit the default parser/printer for Attributes if the dialect asked for |
| // it. |
| if (valueType == "Attribute" && |
| firstDialect.useDefaultAttributePrinterParser()) { |
| NamespaceEmitter nsEmitter(os, firstDialect); |
| os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, |
| firstDialect.getCppClassName()); |
| } |
| |
| // Emit the default parser/printer for Types if the dialect asked for it. |
| if (valueType == "Type" && firstDialect.useDefaultTypePrinterParser()) { |
| NamespaceEmitter nsEmitter(os, firstDialect); |
| os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, |
| firstDialect.getCppClassName()); |
| } |
| |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // GEN: Registration hooks |
| //===----------------------------------------------------------------------===// |
| |
| //===----------------------------------------------------------------------===// |
| // AttrDef |
| |
| static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*"); |
| static llvm::cl::opt<std::string> |
| attrDialect("attrdefs-dialect", |
| llvm::cl::desc("Generate attributes for this dialect"), |
| llvm::cl::cat(attrdefGenCat), llvm::cl::CommaSeparated); |
| |
| static mlir::GenRegistration |
| genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions", |
| [](const llvm::RecordKeeper &records, raw_ostream &os) { |
| AttrDefGenerator generator(records, os); |
| return generator.emitDefs(attrDialect); |
| }); |
| static mlir::GenRegistration |
| genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations", |
| [](const llvm::RecordKeeper &records, raw_ostream &os) { |
| AttrDefGenerator generator(records, os); |
| return generator.emitDecls(attrDialect); |
| }); |
| |
| //===----------------------------------------------------------------------===// |
| // TypeDef |
| |
| static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*"); |
| static llvm::cl::opt<std::string> |
| typeDialect("typedefs-dialect", |
| llvm::cl::desc("Generate types for this dialect"), |
| llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated); |
| |
| static mlir::GenRegistration |
| genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions", |
| [](const llvm::RecordKeeper &records, raw_ostream &os) { |
| TypeDefGenerator generator(records, os); |
| return generator.emitDefs(typeDialect); |
| }); |
| static mlir::GenRegistration |
| genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations", |
| [](const llvm::RecordKeeper &records, raw_ostream &os) { |
| TypeDefGenerator generator(records, os); |
| return generator.emitDecls(typeDialect); |
| }); |