blob: 2a7406f42f34b5b0b1efd7968847e3858e2b1d54 [file] [log] [blame]
//===- OpInterfacesGen.cpp - MLIR op interface utility generator ----------===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// OpInterfacesGen generates definitions for operation interfaces.
#include "DocGenUtilities.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Interfaces.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace mlir;
using mlir::tblgen::Interface;
using mlir::tblgen::InterfaceMethod;
using mlir::tblgen::OpInterface;
/// Emit a string corresponding to a C++ type, followed by a space if necessary.
///
/// The type is trimmed of surrounding whitespace before it is written to `os`.
/// A separating space is appended unless the trimmed type ends in '&' or '*',
/// so that an identifier can be streamed directly after the type. An empty (or
/// all-whitespace) type emits nothing; it is handled explicitly because
/// calling `back()` on an empty StringRef is invalid.
///
/// \param type Textual C++ type to emit.
/// \param os   Stream the type is written to.
/// \returns `os`, so callers can continue streaming.
static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
  type = type.trim();
  // Nothing to print, and `type.back()` below requires a non-empty string.
  if (type.empty())
    return os;
  os << type;
  if (type.back() != '&' && type.back() != '*')
    os << " ";
  return os;
}
/// Emit the method name and argument list for the given method. If 'addThisArg'
/// is true, then an argument is added to the beginning of the argument list for
/// the concrete value.
///
/// \param method     Interface method whose name and arguments are written.
/// \param os         Stream receiving the emitted text.
/// \param valueType  C++ type printed (via emitCPPType) in front of the
///                   `tablegen_opaque_val` argument.
/// \param addThisArg Requests the leading `const Concept *impl` and
///                   `tablegen_opaque_val` arguments.
/// \param addConst   Requests a `const ` prefix inside the `addThisArg` branch
///                   and a trailing ` const` after the closing parenthesis.
///
/// NOTE(review): in this listing the argument lambda ends in a dangling `<<`
/// and the closing braces are not visible; presumably the argument name
/// follows the type -- confirm against the complete file.
static void emitMethodNameAndArgs(const InterfaceMethod &method,
raw_ostream &os, StringRef valueType,
bool addThisArg, bool addConst) {
// The signature fragment always starts with the name and an opening paren.
os << method.getName() << '(';
if (addThisArg) {
if (addConst)
os << "const ";
os << "const Concept *impl, ";
// Only add a separating comma when the method has arguments of its own.
emitCPPType(valueType, os)
<< "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
// Print the declared arguments as a comma-separated list.
llvm::interleaveComma(method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) {
os << arg.type << " " <<;
os << ')';
if (addConst)
os << " const";
/// Get an array of all interface definitions derived from `<name>Interface`
/// (e.g. "OpInterface" for name "Op"), but exclude those subclassing
/// `Declare<name>InterfaceMethods`.
///
/// \param recordKeeper Source of the TableGen records.
/// \param name         Interface kind prefix used to build both class names.
/// \returns The filtered list of records.
///
/// NOTE(review): the right-hand side of the buffer comparison and the end of
/// the `erase_if` call are not visible in this listing; presumably the
/// comparison is against the main input file -- confirm.
static std::vector<llvm::Record *>
getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper,
StringRef name) {
std::vector<llvm::Record *> defs =
recordKeeper.getAllDerivedDefinitions((name + "Interface").str());
// Name of the helper class whose subclasses must be skipped.
std::string declareName = ("Declare" + name + "InterfaceMethods").str();
llvm::erase_if(defs, [&](const llvm::Record *def) {
// Ignore any "declare methods" interfaces.
if (def->isSubClassOf(declareName))
return true;
// Ignore interfaces defined outside of the top-level file.
return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
return defs;
namespace {
/// This class is the base generator used when processing tablegen interfaces.
///
/// It stores the interface records to process and the output stream. The
/// remaining members (value type, base type, template name, substitution
/// variable and format contexts) are assigned by the specialized generators
/// for attributes, operations and types.
class InterfaceGenerator {
/// Emit the interface method definitions for every stored record.
bool emitInterfaceDefs();
/// Emit the interface declarations for the stored records.
bool emitInterfaceDecls();
/// Emit markdown documentation for the stored records.
bool emitInterfaceDocs();
/// Take ownership of the records to process and bind the output stream.
InterfaceGenerator(std::vector<llvm::Record *> &&defs, raw_ostream &os)
: defs(std::move(defs)), os(os) {}
/// Emit the `Concept` struct of an interface.
void emitConceptDecl(const Interface &interface);
/// Emit the `Model`, `FallbackModel` and `ExternalModel` class templates.
void emitModelDecl(const Interface &interface);
/// Emit the out-of-line definitions of the model methods.
void emitModelMethodsDef(const Interface &interface);
/// Emit the `<interfaceName>Trait` struct template.
void emitTraitDecl(const Interface &interface, StringRef interfaceName,
StringRef interfaceTraitsName);
/// Emit the complete declaration of a single interface.
void emitInterfaceDecl(const Interface &interface);
/// The set of interface records to emit.
std::vector<llvm::Record *> defs;
// The stream to emit to.
raw_ostream &os;
/// The C++ value type of the interface, e.g. Operation*.
StringRef valueType;
/// The C++ base interface type.
StringRef interfaceBaseType;
/// The name of the typename for the value template.
StringRef valueTemplate;
/// The name of the substitution variable for the value.
StringRef substVar;
/// The format context to use for methods.
tblgen::FmtContext nonStaticMethodFmt;
/// The format context to use for default implementations emitted in the trait.
tblgen::FmtContext traitMethodFmt;
/// The format context to use for extra class declarations.
tblgen::FmtContext extraDeclsFmt;
/// A specialized generator for attribute interfaces.
///
/// Collects the "Attr" interface definitions and configures the value type,
/// base interface type, template parameter name and the `_attr` substitution
/// used by the format contexts.
struct AttrInterfaceGenerator : public InterfaceGenerator {
AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
valueType = "::mlir::Attribute";
interfaceBaseType = "AttributeInterface";
valueTemplate = "ConcreteAttr";
substVar = "_attr";
// In non-static model methods the substitution variable and the self
// placeholder both expand to the opaque value cast to the concrete attribute.
StringRef castCode = "(::llvm::cast<ConcreteAttr>(tablegen_opaque_val))";
nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
// NOTE(review): the next line is the tail of a call whose beginning is not
// visible in this listing (presumably a traitMethodFmt substitution, as in
// OpInterfaceGenerator) -- confirm.
"(*static_cast<const ConcreteAttr *>(this))");
// Extra class declarations refer to the interface object itself.
extraDeclsFmt.addSubst(substVar, "(*this)");
/// A specialized generator for operation interfaces.
///
/// Collects the "Op" interface definitions and configures the value type,
/// base interface type, template parameter name and the `_op` substitution
/// used by the format contexts.
struct OpInterfaceGenerator : public InterfaceGenerator {
OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
valueType = "::mlir::Operation *";
interfaceBaseType = "OpInterface";
valueTemplate = "ConcreteOp";
substVar = "_op";
// Non-static model methods see the opaque operation cast to the concrete op.
StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
// `_this` maps to the `impl` concept pointer argument.
// NOTE(review): this call chain is not terminated in this listing; a
// trailing call (presumably `.withSelf(castCode);`) seems to be missing.
nonStaticMethodFmt.addSubst("_this", "impl")
.addSubst(substVar, castCode)
// Default implementations in the trait access the op through `this`.
traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))");
// Extra class declarations refer to the interface object itself.
extraDeclsFmt.addSubst(substVar, "(*this)");
/// A specialized generator for type interfaces.
///
/// Collects the "Type" interface definitions and configures the value type,
/// base interface type, template parameter name and the `_type` substitution
/// used by the format contexts.
struct TypeInterfaceGenerator : public InterfaceGenerator {
TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
: InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
valueType = "::mlir::Type";
interfaceBaseType = "TypeInterface";
valueTemplate = "ConcreteType";
substVar = "_type";
// In non-static model methods the substitution variable and the self
// placeholder both expand to the opaque value cast to the concrete type.
StringRef castCode = "(::llvm::cast<ConcreteType>(tablegen_opaque_val))";
nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
// NOTE(review): the next line is the tail of a call whose beginning is not
// visible in this listing (presumably a traitMethodFmt substitution, as in
// OpInterfaceGenerator) -- confirm.
"(*static_cast<const ConcreteType *>(this))");
// Extra class declarations refer to the interface object itself.
extraDeclsFmt.addSubst(substVar, "(*this)");
} // namespace
// GEN: Interface definitions
/// Emit the description of `method`, if it has one, as a comment.
///
/// \param method Interface method whose description is emitted.
/// \param os     Stream to write to.
/// \param prefix Prefix forwarded to `tblgen::emitDescriptionComment`.
static void emitInterfaceMethodDoc(const InterfaceMethod &method,
raw_ostream &os, StringRef prefix = "") {
// Methods without a description produce no output.
if (std::optional<StringRef> description = method.getDescription())
tblgen::emitDescriptionComment(*description, os, prefix);
/// Emit the definitions of the interface class methods of `interface`.
///
/// For each method the documentation comment, the return type and the
/// `interfaceQualName::`-qualified signature are printed, followed by a body
/// that forwards the call through `implValue`.
///
/// \param interfaceQualName Qualified name used to qualify every method.
/// \param interface         Interface whose methods are emitted.
/// \param valueType         Value type forwarded to emitMethodNameAndArgs.
/// \param implValue         Expression text the generated body dereferences
///                          with `->` to reach the method.
/// \param os                Stream to write to.
/// \param isOpInterface     Chooses between `getOperation()` and `*this` for
///                          the value argument of the forwarded call.
///
/// NOTE(review): several calls in this listing are cut short (the
/// emitMethodNameAndArgs argument list, the call that prints the forwarded
/// arguments, and the lambda body ending in `os <<;`) -- confirm against the
/// complete file.
static void emitInterfaceDefMethods(StringRef interfaceQualName,
const Interface &interface,
StringRef valueType, const Twine &implValue,
raw_ostream &os, bool isOpInterface) {
for (auto &method : interface.getMethods()) {
emitInterfaceMethodDoc(method, os);
emitCPPType(method.getReturnType(), os);
os << interfaceQualName << "::";
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
// Forward to the method on the concrete operation type.
os << " {\n return " << implValue << "->" << method.getName() << '(';
if (!method.isStatic()) {
// Pass the concept pointer followed by the wrapped value.
os << implValue << ", ";
os << (isOpInterface ? "getOperation()" : "*this");
os << (method.arg_empty() ? "" : ", ");
method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) { os <<; });
os << ");\n }\n";
/// Emit the method definitions for `interface`, including the methods it
/// inherits from its base interfaces.
///
/// \param interface Interface to emit.
/// \param valueType Value type forwarded to emitInterfaceDefMethods.
/// \param os        Stream to write to.
///
/// NOTE(review): the final emitInterfaceDefMethods call is truncated in this
/// listing.
static void emitInterfaceDef(const Interface &interface, StringRef valueType,
raw_ostream &os) {
// Keep the std::string alive; the StringRef below only refers to it.
std::string interfaceQualNameStr = interface.getFullyQualifiedName();
StringRef interfaceQualName = interfaceQualNameStr;
// Insert the method definitions.
bool isOpInterface = isa<OpInterface>(interface);
emitInterfaceDefMethods(interfaceQualName, interface, valueType, "getImpl()",
os, isOpInterface);
// Insert the method definitions for base classes.
for (auto &base : interface.getBaseInterfaces()) {
// Base methods are reached through the `impl<BaseName>` concept field.
emitInterfaceDefMethods(interfaceQualName, base, valueType,
"getImpl()->impl" + base.getName(), os,
/// Emit the source file header followed by the method definitions of every
/// stored interface record.
///
/// \returns Always false.
bool InterfaceGenerator::emitInterfaceDefs() {
llvm::emitSourceFileHeader("Interface Definitions", os);
for (const auto *def : defs)
emitInterfaceDef(Interface(def), valueType, os);
return false;
// GEN: Interface declarations
/// Emit the `Concept` struct of `interface`.
///
/// The struct holds one function pointer per interface method, one concept
/// pointer field (`impl<BaseName>`) per base interface, and an
/// `initializeInterfaceConcept` method that looks the base concepts up in an
/// interface map and asserts that each one was found.
void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
os << " struct Concept {\n";
// Insert a function pointer for each of the concept methods.
os << " /// The methods defined by the interface.\n";
for (auto &method : interface.getMethods()) {
os << " ";
emitCPPType(method.getReturnType(), os);
os << "(*" << method.getName() << ")(";
if (!method.isStatic()) {
// Non-static methods receive the concept and the opaque value first.
os << "const Concept *impl, ";
emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
// NOTE(review): the head of the call that prints the argument types is not
// visible in this listing.
method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) { os << arg.type; });
os << ");\n";
// Insert a field containing a concept for each of the base interfaces.
auto baseInterfaces = interface.getBaseInterfaces();
if (!baseInterfaces.empty()) {
os << " /// The base classes of this interface.\n";
for (const auto &base : interface.getBaseInterfaces()) {
os << " const " << base.getFullyQualifiedName() << "::Concept *impl"
<< base.getName() << " = nullptr;\n";
// Define an "initialize" method that allows for the initialization of the
// base class concepts.
os << "\n void initializeInterfaceConcept(::mlir::detail::InterfaceMap "
"&interfaceMap) {\n";
std::string interfaceQualName = interface.getFullyQualifiedName();
for (const auto &base : interface.getBaseInterfaces()) {
StringRef baseName = base.getName();
std::string baseQualName = base.getFullyQualifiedName();
// Look the base concept up and assert that it was registered.
os << " impl" << baseName << " = interfaceMap.lookup<"
<< baseQualName << ">();\n"
<< " assert(impl" << baseName << " && \"`" << interfaceQualName
<< "` expected its base interface `" << baseQualName
<< "` to be registered\");\n";
os << " }\n";
os << " };\n";
/// Emit the declarations of the `Model`, `FallbackModel` and `ExternalModel`
/// class templates of `interface`.
///
/// `Model` and `FallbackModel` derive from `Concept`, expose an `Interface`
/// alias and initialize the concept from the interface method names.
/// `ExternalModel` derives from `FallbackModel<ConcreteModel>` and declares
/// interface methods with the opaque value as first argument of non-static
/// ones.
///
/// NOTE(review): several statements are cut short in this listing (the head
/// of the call listing the method names, the emitMethodNameAndArgs argument
/// list, the statement guarded by the default-implementation check, and the
/// lambda ending in `os <<;`) -- confirm against the complete file.
void InterfaceGenerator::emitModelDecl(const Interface &interface) {
// Emit the basic model and the fallback model.
for (const char *modelClass : {"Model", "FallbackModel"}) {
os << " template<typename " << valueTemplate << ">\n";
os << " class " << modelClass << " : public Concept {\n public:\n";
os << " using Interface = " << interface.getFullyQualifiedName()
<< ";\n";
// The constructor initializes the concept with the method names.
os << " " << modelClass << "() : Concept{";
interface.getMethods(), os,
[&](const InterfaceMethod &method) { os << method.getName(); });
os << "} {}\n\n";
// Insert a static declaration for each of the interface methods.
for (auto &method : interface.getMethods()) {
emitCPPType(method.getReturnType(), os << " static inline ");
emitMethodNameAndArgs(method, os, valueType,
os << ";\n";
os << " };\n";
// Emit the template for the external model.
os << " template<typename ConcreteModel, typename " << valueTemplate
<< ">\n";
os << " class ExternalModel : public FallbackModel<ConcreteModel> {\n";
os << " public:\n";
os << " using ConcreteEntity = " << valueTemplate << ";\n";
// Emit declarations for methods that have default implementations. Other
// methods are expected to be implemented by the concrete derived model.
for (auto &method : interface.getMethods()) {
if (!method.getDefaultImplementation())
os << " ";
if (method.isStatic())
os << "static ";
emitCPPType(method.getReturnType(), os);
os << method.getName() << "(";
if (!method.isStatic()) {
// Non-static methods take the opaque value as their first argument.
emitCPPType(valueType, os);
os << "tablegen_opaque_val";
if (!method.arg_empty())
os << ", ";
llvm::interleaveComma(method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) {
emitCPPType(arg.type, os);
os <<;
os << ")";
// Non-static external model methods are const-qualified.
if (!method.isStatic())
os << " const";
os << ";\n";
os << " };\n";
/// Emit the out-of-line definitions of the model methods of `interface`.
///
/// The output is wrapped in the interface's C++ namespaces (split on "::")
/// and contains three groups of templated definitions qualified with
/// `detail::<Name>InterfaceTraits::`: the `Model` methods, the
/// `FallbackModel` methods and the default implementations of the
/// `ExternalModel` methods.
///
/// NOTE(review): this listing has lost lines (closing braces, the
/// emitMethodNameAndArgs argument lists, the alternatives of the
/// `isStatic()` checks, the statement guarded by the default-implementation
/// check, and the operands after `os <<`) -- confirm against the complete
/// file.
void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
// Open the namespaces the interface lives in.
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
for (StringRef ns : namespaces)
os << "namespace " << ns << " {\n";
// Emit the definitions of the `Model` methods.
for (auto &method : interface.getMethods()) {
os << "template<typename " << valueTemplate << ">\n";
emitCPPType(method.getReturnType(), os);
os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
<< valueTemplate << ">::";
emitMethodNameAndArgs(method, os, valueType,
os << " {\n ";
// Check for a provided body to the function.
if (std::optional<StringRef> body = method.getBody()) {
if (method.isStatic())
os << body->trim();
os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt);
os << "\n}\n";
// Forward to the method on the concrete operation type.
if (method.isStatic())
os << "return " << valueTemplate << "::";
os << tblgen::tgfmt("return $_self.", &nonStaticMethodFmt);
// Add the arguments to the call.
os << method.getName() << '(';
method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) { os <<; });
os << ");\n}\n";
// Emit the definitions of the `FallbackModel` methods.
for (auto &method : interface.getMethods()) {
os << "template<typename " << valueTemplate << ">\n";
emitCPPType(method.getReturnType(), os);
os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
<< valueTemplate << ">::";
emitMethodNameAndArgs(method, os, valueType,
os << " {\n ";
// Forward to the method on the concrete Model implementation.
if (method.isStatic())
os << "return " << valueTemplate << "::";
os << "return static_cast<const " << valueTemplate << " *>(impl)->";
// Add the arguments to the call.
os << method.getName() << '(';
if (!method.isStatic())
os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) { os <<; });
os << ");\n}\n";
// Emit default implementations for the external model.
for (auto &method : interface.getMethods()) {
if (!method.getDefaultImplementation())
os << "template<typename ConcreteModel, typename " << valueTemplate
<< ">\n";
emitCPPType(method.getReturnType(), os);
os << "detail::" << interface.getName()
<< "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
<< ">::";
os << method.getName() << "(";
if (!method.isStatic()) {
// Non-static methods take the opaque value as their first argument.
emitCPPType(valueType, os);
os << "tablegen_opaque_val";
if (!method.arg_empty())
os << ", ";
llvm::interleaveComma(method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) {
emitCPPType(arg.type, os);
os <<;
os << ")";
if (!method.isStatic())
os << " const";
os << " {\n";
// Use the empty context for static methods.
tblgen::FmtContext ctx;
os << tblgen::tgfmt(method.getDefaultImplementation()->trim(),
method.isStatic() ? &ctx : &nonStaticMethodFmt);
os << "\n}\n";
// Close the namespaces in the reverse order they were opened.
for (StringRef ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
/// Emit the `<interfaceName>Trait` struct template of `interface`.
///
/// The struct contains the default implementations of the interface methods,
/// an optional static verification hook built from the interface's `verify`
/// code, and any extra trait / shared class declarations.
///
/// \param interface           Interface to emit the trait for.
/// \param interfaceName       Name used as prefix of the trait struct.
/// \param interfaceTraitsName Name of the traits struct in `detail`.
///
/// NOTE(review): this listing has lost lines (the last `formatv` argument,
/// the call that wraps the 'verifyTrait' diagnostic, the statement guarded by
/// `if (!defaultImpl)` and the closing braces) -- confirm against the
/// complete file.
void InterfaceGenerator::emitTraitDecl(const Interface &interface,
StringRef interfaceName,
StringRef interfaceTraitsName) {
os << llvm::formatv(" template <typename {3}>\n"
" struct {0}Trait : public ::mlir::{2}<{0},"
" detail::{1}>::Trait<{3}> {{\n",
interfaceName, interfaceTraitsName, interfaceBaseType,
// Insert the default implementation for any methods.
bool isOpInterface = isa<OpInterface>(interface);
for (auto &method : interface.getMethods()) {
// Flag interface methods named verifyTrait.
if (method.getName() == "verifyTrait")
formatv("'verifyTrait' method cannot be specified as interface "
"method for '{0}'; use the 'verify' field instead",
auto defaultImpl = method.getDefaultImplementation();
if (!defaultImpl)
emitInterfaceMethodDoc(method, os, " ");
os << " " << (method.isStatic() ? "static " : "");
emitCPPType(method.getReturnType(), os);
// Only non-static methods of non-op interfaces are const-qualified.
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
/*addConst=*/!isOpInterface && !method.isStatic());
os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
<< "\n }\n";
// Emit the verification hook when the interface provides verify code.
if (auto verify = interface.getVerify()) {
assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
tblgen::FmtContext verifyCtx;
verifyCtx.addSubst("_op", "op");
// The hook name depends on whether verification runs with regions.
os << llvm::formatv(
" static ::mlir::LogicalResult {0}(::mlir::Operation *op) ",
(interface.verifyWithRegions() ? "verifyRegionTrait"
: "verifyTrait"))
<< "{\n " << tblgen::tgfmt(verify->trim(), &verifyCtx)
<< "\n }\n";
// Append the extra trait declarations, then the shared ones.
if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
if (auto extraTraitDecls = interface.getExtraSharedClassDeclaration())
os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";
os << " };\n";
/// Emit the declarations of the methods of `interface` inside the interface
/// class, followed by its extra declarations.
///
/// \param interface     Interface whose methods are declared.
/// \param os            Stream to write to.
/// \param valueType     Value type forwarded to emitMethodNameAndArgs.
/// \param isOpInterface Whether the interface is an operation interface.
/// \param extraDeclsFmt Format context applied to the second group of extra
///                      declarations.
///
/// NOTE(review): the emitMethodNameAndArgs argument list and the
/// initializers of both `extraDecls` variables are not visible in this
/// listing -- confirm against the complete file.
static void emitInterfaceDeclMethods(const Interface &interface,
raw_ostream &os, StringRef valueType,
bool isOpInterface,
tblgen::FmtContext &extraDeclsFmt) {
for (auto &method : interface.getMethods()) {
emitInterfaceMethodDoc(method, os, " ");
emitCPPType(method.getReturnType(), os << " ");
emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
os << ";\n";
// Emit any extra declarations.
// The first group is written verbatim, without placeholder substitution.
if (std::optional<StringRef> extraDecls =
os << extraDecls->rtrim() << "\n";
// The second group is run through tgfmt with `extraDeclsFmt`.
if (std::optional<StringRef> extraDecls =
os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n";
/// Emit the complete declaration of `interface`.
///
/// In order, the output contains: the opening namespaces, a forward
/// declaration of the interface class, the traits struct and trait forward
/// declaration inside `namespace detail`, the interface class itself (with
/// its `Trait` wrapper, method declarations, conversions to base interfaces
/// and an optional `classof`), the trait struct, and the closing namespaces.
///
/// NOTE(review): this listing has lost lines (the calls that fill the traits
/// struct, the last `formatv` argument, the tail of the
/// emitInterfaceDeclMethods call, the end of the "Inherited from" banner and
/// the closing braces) -- confirm against the complete file.
void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
// Open the namespaces the interface lives in.
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
for (StringRef ns : namespaces)
os << "namespace " << ns << " {\n";
StringRef interfaceName = interface.getName();
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
// Emit a forward declaration of the interface class so that it becomes usable
// in the signature of its methods.
os << "class " << interfaceName << ";\n";
// Emit the traits struct containing the concept and model declarations.
os << "namespace detail {\n"
<< "struct " << interfaceTraitsName << " {\n";
os << "};\n";
// Emit the derived trait for the interface.
os << "template <typename " << valueTemplate << ">\n";
os << "struct " << interface.getName() << "Trait;\n";
os << "\n} // namespace detail\n";
// Emit the main interface class declaration.
os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n"
" using ::mlir::{3}<{1}, detail::{2}>::{3};\n",
interfaceName, interfaceName, interfaceTraitsName,
// Emit a utility wrapper trait class.
os << llvm::formatv(" template <typename {1}>\n"
" struct Trait : public detail::{0}Trait<{1}> {{};\n",
interfaceName, valueTemplate);
// Insert the method declarations.
bool isOpInterface = isa<OpInterface>(interface);
emitInterfaceDeclMethods(interface, os, valueType, isOpInterface,
// Insert the method declarations for base classes.
for (auto &base : interface.getBaseInterfaces()) {
std::string baseQualName = base.getFullyQualifiedName();
os << " //"
<< " // Inherited from " << baseQualName << "\n"
<< " //"
// Allow implicit conversion to the base interface.
// A null interface converts to a null base interface.
os << " operator " << baseQualName << " () const {\n"
<< " if (!*this) return nullptr;\n"
<< " return " << baseQualName << "(*this, getImpl()->impl"
<< base.getName() << ");\n"
<< " }\n\n";
// Inherit the base interface's methods.
emitInterfaceDeclMethods(base, os, valueType, isOpInterface, extraDeclsFmt);
// Emit classof code if necessary.
if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
// The user code refers to the candidate through the substitution variable.
auto extraClassOfFmt = tblgen::FmtContext();
extraClassOfFmt.addSubst(substVar, "odsInterfaceInstance");
os << " static bool classof(" << valueType << " base) {\n"
<< " auto* interface = getInterfaceFor(base);\n"
<< " if (!interface)\n"
" return false;\n"
" " << interfaceName << " odsInterfaceInstance(base, interface);\n"
<< " " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
<< "\n }\n";
os << "};\n";
// The trait struct is emitted after the interface class.
os << "namespace detail {\n";
emitTraitDecl(interface, interfaceName, interfaceTraitsName);
os << "}// namespace detail\n";
// Close the namespaces in the reverse order they were opened.
for (StringRef ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
/// Emit the source file header and the declarations of the stored interface
/// records, processed in order of increasing record ID.
///
/// \returns Always false.
///
/// NOTE(review): the end of the `llvm::sort` call and the bodies of both
/// loops over `sortedDefs` are not visible in this listing -- confirm against
/// the complete file.
bool InterfaceGenerator::emitInterfaceDecls() {
llvm::emitSourceFileHeader("Interface Declarations", os);
// Sort according to ID, so defs are emitted in the order in which they appear
// in the Tablegen file.
// Work on a copy so that `defs` keeps its original order.
std::vector<llvm::Record *> sortedDefs(defs);
llvm::sort(sortedDefs, [](llvm::Record *lhs, llvm::Record *rhs) {
return lhs->getID() < rhs->getID();
for (const llvm::Record *def : sortedDefs)
for (const llvm::Record *def : sortedDefs)
return false;
// GEN: Interface documentation
/// Emit the markdown documentation of a single interface record.
///
/// The output consists of a level-2 heading with the interface name and the
/// record name, the interface description (if any), and a "Methods" section
/// with one entry per method: a heading, the signature in a `c++` code fence,
/// the method description (if any) and a note for methods without a body.
///
/// \param interfaceDef Record describing the interface.
/// \param os           Stream to write to.
///
/// NOTE(review): the argument lambda ends in a dangling `<<` in this listing;
/// presumably the argument name follows the type -- confirm.
static void emitInterfaceDoc(const llvm::Record &interfaceDef,
raw_ostream &os) {
Interface interface(&interfaceDef);
// Emit the interface name followed by the description.
os << "## " << interface.getName() << " (`" << interfaceDef.getName()
<< "`)\n\n";
if (auto description = interface.getDescription())
mlir::tblgen::emitDescription(*description, os);
// Emit the methods required by the interface.
os << "\n### Methods:\n";
for (const auto &method : interface.getMethods()) {
// Emit the method name.
os << "#### `" << method.getName() << "`\n\n```c++\n";
// Emit the method signature.
if (method.isStatic())
os << "static ";
emitCPPType(method.getReturnType(), os) << method.getName() << '(';
llvm::interleaveComma(method.getArguments(), os,
[&](const InterfaceMethod::Argument &arg) {
emitCPPType(arg.type, os) <<;
os << ");\n```\n";
// Emit the description.
if (auto description = method.getDescription())
mlir::tblgen::emitDescription(*description, os);
// If the body is not provided, this method must be provided by the user.
if (!method.getBody())
os << "\nNOTE: This method *must* be implemented by the user.";
os << "\n\n";
/// Emit the markdown documentation for every stored interface record,
/// preceded by an "autogenerated" marker and a title built from the base
/// interface type.
///
/// \returns Always false.
bool InterfaceGenerator::emitInterfaceDocs() {
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
os << "# " << interfaceBaseType << " definitions\n";
for (const auto *def : defs)
emitInterfaceDoc(*def, os);
return false;
// GEN: Interface registration hooks
namespace {
/// Registers the three generators (declarations, definitions, documentation)
/// of one interface kind.
///
/// `GeneratorT` is constructed from the record keeper and the output stream
/// each time a generator callback runs. The argument names follow the pattern
/// `gen-<genArg>-interface-{decls,defs,docs}` and the descriptions are built
/// from `genDesc`.
///
/// NOTE(review): the closing `})` of the first two callbacks and the end of
/// the struct are not visible in this listing.
template <typename GeneratorT>
struct InterfaceGenRegistration {
InterfaceGenRegistration(StringRef genArg, StringRef genDesc)
: genDeclArg(("gen-" + genArg + "-interface-decls").str()),
genDefArg(("gen-" + genArg + "-interface-defs").str()),
genDocArg(("gen-" + genArg + "-interface-docs").str()),
genDeclDesc(("Generate " + genDesc + " interface declarations").str()),
genDefDesc(("Generate " + genDesc + " interface definitions").str()),
genDocDesc(("Generate " + genDesc + " interface documentation").str()),
genDecls(genDeclArg, genDeclDesc,
[](const llvm::RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDecls();
genDefs(genDefArg, genDefDesc,
[](const llvm::RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDefs();
genDocs(genDocArg, genDocDesc,
[](const llvm::RecordKeeper &records, raw_ostream &os) {
return GeneratorT(records, os).emitInterfaceDocs();
}) {}
// The strings are declared before the registrations that are constructed
// from them in the initializer list above.
std::string genDeclArg, genDefArg, genDocArg;
std::string genDeclDesc, genDefDesc, genDocDesc;
mlir::GenRegistration genDecls, genDefs, genDocs;
} // namespace
// Static registration objects for the attribute, operation and type
// interface generators.
// NOTE(review): the attribute registration below is truncated in this
// listing; its second argument (the description) is not visible.
static InterfaceGenRegistration<AttrInterfaceGenerator> attrGen("attr",
static InterfaceGenRegistration<OpInterfaceGenerator> opGen("op", "op");
static InterfaceGenRegistration<TypeInterfaceGenerator> typeGen("type", "type");