blob: e33b74b93e1faa1c06c7bb686785acf9e169cd45 [file] [log] [blame]
//===- StructsGen.cpp - MLIR struct utility generator ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// StructsGen generates common utility functions for grouping attributes into a
// set of structured data.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using llvm::raw_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::StringRef;
using mlir::tblgen::FmtContext;
using mlir::tblgen::StructAttr;
/// Emits the C++ class declaration for one struct attribute: a subclass of
/// ::mlir::DictionaryAttr that inherits its constructors and declares a static
/// `classof`, a static `get` factory (one parameter per field, followed by the
/// MLIRContext), and one const accessor per field.
///
/// \param structDef   The struct's TableGen record (not referenced here).
/// \param structName  Name of the generated C++ class.
/// \param fields      The struct's fields, emitted in the given order.
/// \param description Text placed in the `// ...` comment above the class.
/// \param os          Stream that receives the generated declaration.
static void
emitStructClass(const Record &structDef, StringRef structName,
                llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
                StringRef description, raw_ostream &os) {
  // Leading comment plus the class head deriving from DictionaryAttr.
  const char *classHead = R"(
// {0}
class {1} : public ::mlir::DictionaryAttr)";
  // Opening brace, inherited constructors and the classof hook.
  const char *classPrologue = R"( {
public:
using ::mlir::DictionaryAttr::DictionaryAttr;
static bool classof(::mlir::Attribute attr);
)";
  os << llvm::formatv(classHead, description, structName);
  os << classPrologue;

  // Factory declaration:
  //   static Name get(Storage0 field0, ..., ::mlir::MLIRContext* context);
  os << llvm::formatv(" static {0} get(\n", structName);
  for (const auto &field : fields)
    os << llvm::formatv(" {0} {1},\n", field.getType().getStorageType(),
                        field.getName());
  os << " ::mlir::MLIRContext* context);\n\n";

  // One accessor declaration per field: Storage fieldN() const;
  const char *accessorDecl = R"( {0} {1}() const;
)";
  for (const auto &field : fields)
    os << llvm::formatv(accessorDecl, field.getType().getStorageType(),
                        field.getName());

  os << "};\n\n";
}
static void emitStructDecl(const Record &structDef, raw_ostream &os) {
StructAttr structAttr(&structDef);
StringRef structName = structAttr.getStructClassName();
StringRef cppNamespace = structAttr.getCppNamespace();
StringRef description = structAttr.getSummary();
auto fields = structAttr.getAllFields();
// Wrap in the appropriate namespace.
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces)
os << "namespace " << ns << " {\n";
// Emit the struct class definition
emitStructClass(structDef, structName, fields, description, os);
// Close the declared namespace.
for (auto ns : namespaces)
os << "} // namespace " << ns << "\n";
}
static bool emitStructDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Struct Utility Declarations", os);
auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr");
for (const auto *def : defs) {
emitStructDecl(*def, os);
}
return false;
}
/// Emits the out-of-line definition of `<structName>::get(...)`. The generated
/// factory collects the field attributes into NamedAttributes keyed by the
/// field name. It builds a DictionaryAttr from them and returns
/// `dict.dyn_cast<structName>()`.
///
/// Required fields are `assert`ed non-null. Optional and default-valued fields
/// are added only when non-null.
///
/// \param structName Name of the generated C++ class.
/// \param fields     The struct's fields; each becomes one factory parameter.
/// \param os         Stream that receives the generated definition.
static void emitFactoryDef(llvm::StringRef structName,
                           llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
                           raw_ostream &os) {
  // Signature: Name Name::get(Storage0 field0, ..., MLIRContext* context) {
  os << llvm::formatv("{0} {0}::get(\n", structName);
  for (const auto &field : fields)
    os << llvm::formatv(" {0} {1},\n", field.getType().getStorageType(),
                        field.getName());
  os << " ::mlir::MLIRContext* context) {";

  // Storage for the dictionary entries, sized for the full field count.
  const char *entriesDecl = R"(
::llvm::SmallVector<::mlir::NamedAttribute, {0}> fields;
)";
  os << llvm::formatv(entriesDecl, fields.size());

  // Per-field snippets; {0} is the field name in both.
  const char *requiredEntry = R"(
assert({0});
auto {0}_id = ::mlir::StringAttr::get(context, "{0}");
fields.emplace_back({0}_id, {0});
)";
  const char *nullableEntry = R"(
if ({0}) {
auto {0}_id = ::mlir::StringAttr::get(context, "{0}");
fields.emplace_back({0}_id, {0});
}
)";
  for (const auto &field : fields) {
    auto fieldType = field.getType();
    const bool mayBeNull =
        fieldType.isOptional() || fieldType.hasDefaultValue();
    os << llvm::formatv(mayBeNull ? nullableEntry : requiredEntry,
                        field.getName());
  }

  // Build the dictionary and narrow it to the struct type.
  const char *factoryTail = R"(
::mlir::Attribute dict = ::mlir::DictionaryAttr::get(context, fields);
return dict.dyn_cast<{0}>();
}
)";
  os << llvm::formatv(factoryTail, structName);
}
/// Emits the out-of-line definition of `<structName>::classof(Attribute)`.
///
/// The generated predicate accepts `attr` only if:
/// - it is a non-null DictionaryAttr;
/// - every required field is present and satisfies its type's condition
///   template (with `$_self` bound to the field's local variable);
/// - every present optional or default-valued field satisfies its condition;
/// - the dictionary contains no entries beyond the declared fields. This is
///   checked as present + absent-optional == field count.
///
/// \param structName Name of the generated C++ class.
/// \param fields     The struct's fields.
/// \param os         Stream that receives the generated definition.
static void emitClassofDef(llvm::StringRef structName,
                           llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
                           raw_ostream &os) {
  const char *classofInfo = R"(
bool {0}::classof(::mlir::Attribute attr))";
  const char *classofInfoHeader = R"(
if (!attr)
return false;
auto derived = attr.dyn_cast<::mlir::DictionaryAttr>();
if (!derived)
return false;
int num_absent_attrs = 0;
)";
  os << llvm::formatv(classofInfo, structName) << " {";
  // The header has no placeholders, so stream it directly. Running it through
  // formatv with no arguments would misparse any literal '{' added later as a
  // replacement field.
  os << classofInfoHeader;

  // Required field: must be present and satisfy its condition.
  const char *classofArgInfo = R"(
auto {0} = derived.get("{0}");
if (!{0} || !({1}))
return false;
)";
  // Optional/default-valued field: absence is counted; presence must satisfy
  // the condition.
  const char *classofArgInfoOptional = R"(
auto {0} = derived.get("{0}");
if (!{0})
++num_absent_attrs;
else if (!({1}))
return false;
)";
  FmtContext fctx;
  for (const auto &field : fields) {
    auto name = field.getName();
    auto type = field.getType();
    // Instantiate the type's condition template with $_self := <field name>.
    std::string condition =
        std::string(tgfmt(type.getConditionTemplate(), &fctx.withSelf(name)));
    if (type.isOptional() || type.hasDefaultValue())
      os << llvm::formatv(classofArgInfoOptional, name, condition);
    else
      os << llvm::formatv(classofArgInfo, name, condition);
  }

  // Reject dictionaries that carry entries not declared as fields.
  const char *classofEndInfo = R"(
return derived.size() + num_absent_attrs == {0};
}
)";
  os << llvm::formatv(classofEndInfo, fields.size());
}
/// Emits the out-of-line definitions of the per-field accessors
/// `Storage <structName>::<field>() const`. Each one looks the field up in the
/// underlying DictionaryAttr by name, then:
/// - required field: asserts it is present and has the storage type;
/// - optional field: returns nullptr when absent;
/// - default-valued field: when absent, builds the default with the type's
///   const-builder template, using a local `::mlir::Builder builder`.
///
/// \param structName Name of the generated C++ class.
/// \param fields     The struct's fields.
/// \param os         Stream that receives the generated definitions.
static void
emitAccessorDef(llvm::StringRef structName,
                llvm::ArrayRef<mlir::tblgen::StructFieldAttr> fields,
                raw_ostream &os) {
  // Placeholders: {0} storage type, {1} field name, {2} struct name,
  // {3} default-value expression (default-valued variant only).
  const char *requiredAccessor = R"(
{0} {2}::{1}() const {
auto derived = this->cast<::mlir::DictionaryAttr>();
auto {1} = derived.get("{1}");
assert({1} && "attribute not found.");
assert({1}.isa<{0}>() && "incorrect Attribute type found.");
return {1}.cast<{0}>();
}
)";
  const char *optionalAccessor = R"(
{0} {2}::{1}() const {
auto derived = this->cast<::mlir::DictionaryAttr>();
auto {1} = derived.get("{1}");
if (!{1})
return nullptr;
assert({1}.isa<{0}>() && "incorrect Attribute type found.");
return {1}.cast<{0}>();
}
)";
  const char *defaultedAccessor = R"(
{0} {2}::{1}() const {
auto derived = this->cast<::mlir::DictionaryAttr>();
auto {1} = derived.get("{1}");
if (!{1}) {
::mlir::Builder builder(getContext());
return {3};
}
assert({1}.isa<{0}>() && "incorrect Attribute type found.");
return {1}.cast<{0}>();
}
)";

  // `$_builder` in const-builder templates refers to the generated local.
  FmtContext builderCtx;
  builderCtx.withBuilder("builder");

  for (const auto &field : fields) {
    auto fieldName = field.getName();
    auto fieldType = field.getType();
    auto storageType = fieldType.getStorageType();

    if (fieldType.isOptional()) {
      os << llvm::formatv(optionalAccessor, storageType, fieldName, structName);
      continue;
    }
    if (!fieldType.hasDefaultValue()) {
      os << llvm::formatv(requiredAccessor, storageType, fieldName, structName);
      continue;
    }
    // Default-valued: substitute the default value into the const builder.
    std::string defaultExpr =
        tgfmt(fieldType.getConstBuilderTemplate(), &builderCtx,
              fieldType.getDefaultValue());
    os << llvm::formatv(defaultedAccessor, storageType, fieldName, structName,
                        defaultExpr);
  }
}
static void emitStructDef(const Record &structDef, raw_ostream &os) {
StructAttr structAttr(&structDef);
StringRef cppNamespace = structAttr.getCppNamespace();
StringRef structName = structAttr.getStructClassName();
mlir::tblgen::FmtContext ctx;
auto fields = structAttr.getAllFields();
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces)
os << "namespace " << ns << " {\n";
emitFactoryDef(structName, fields, os);
emitClassofDef(structName, fields, os);
emitAccessorDef(structName, fields, os);
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
}
static bool emitStructDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Struct Utility Definitions", os);
auto defs = recordKeeper.getAllDerivedDefinitions("StructAttr");
for (const auto *def : defs)
emitStructDef(*def, os);
return false;
}
// Registers the struct-attribute declaration generator with mlir-tblgen under
// the option name "gen-struct-attr-decls".
static mlir::GenRegistration
    genStructDecls("gen-struct-attr-decls",
                   "Generate struct utility declarations", emitStructDecls);

// Registers the struct-attribute definition generator with mlir-tblgen under
// the option name "gen-struct-attr-defs".
static mlir::GenRegistration
    genStructDefs("gen-struct-attr-defs", "Generate struct utility definitions",
                  emitStructDefs);