blob: aa8841abfd73b2b4c14d7f7388b85c0bd9d8ca7b [file] [log] [blame]
//===- EnumsGen.cpp - MLIR enum utility generator -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// EnumsGen generates common utility functions for enums.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using llvm::formatv;
using llvm::isDigit;
using llvm::PrintFatalError;
using llvm::raw_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::StringRef;
using mlir::tblgen::Attribute;
using mlir::tblgen::EnumAttr;
using mlir::tblgen::EnumAttrCase;
using mlir::tblgen::FmtContext;
using mlir::tblgen::tgfmt;
static std::string makeIdentifier(StringRef str) {
if (!str.empty() && isDigit(static_cast<unsigned char>(str.front()))) {
std::string newStr = std::string("_") + str.str();
return newStr;
}
return str.str();
}
static void emitEnumClass(const Record &enumDef, StringRef enumName,
StringRef underlyingType, StringRef description,
const std::vector<EnumAttrCase> &enumerants,
raw_ostream &os) {
os << "// " << description << "\n";
os << "enum class " << enumName;
if (!underlyingType.empty())
os << " : " << underlyingType;
os << " {\n";
for (const auto &enumerant : enumerants) {
auto symbol = makeIdentifier(enumerant.getSymbol());
auto value = enumerant.getValue();
if (value >= 0) {
os << formatv(" {0} = {1},\n", symbol, value);
} else {
os << formatv(" {0},\n", symbol);
}
}
os << "};\n\n";
}
static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
StringRef cppNamespace, raw_ostream &os) {
std::string qualName =
std::string(formatv("{0}::{1}", cppNamespace, enumName));
if (underlyingType.empty())
underlyingType =
std::string(formatv("std::underlying_type<{0}>::type", qualName));
const char *const mapInfo = R"(
namespace llvm {
template<> struct DenseMapInfo<{0}> {{
using StorageInfo = ::llvm::DenseMapInfo<{1}>;
static inline {0} getEmptyKey() {{
return static_cast<{0}>(StorageInfo::getEmptyKey());
}
static inline {0} getTombstoneKey() {{
return static_cast<{0}>(StorageInfo::getTombstoneKey());
}
static unsigned getHashValue(const {0} &val) {{
return StorageInfo::getHashValue(static_cast<{1}>(val));
}
static bool isEqual(const {0} &lhs, const {0} &rhs) {{
return lhs == rhs;
}
};
})";
os << formatv(mapInfo, qualName, underlyingType);
os << "\n\n";
}
static void emitMaxValueFn(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName();
auto enumerants = enumAttr.getAllCases();
unsigned maxEnumVal = 0;
for (const auto &enumerant : enumerants) {
int64_t value = enumerant.getValue();
// Avoid generating the max value function if there is an enumerant without
// explicit value.
if (value < 0)
return;
maxEnumVal = std::max(maxEnumVal, static_cast<unsigned>(value));
}
// Emit the function to return the max enum value
os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName);
os << formatv(" return {0};\n", maxEnumVal);
os << "}\n\n";
}
// Returns the EnumAttrCase whose value is zero if exists; returns llvm::None
// otherwise.
static llvm::Optional<EnumAttrCase>
getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
for (auto attrCase : cases) {
if (attrCase.getValue() == 0)
return attrCase;
}
return llvm::None;
}
// Emits the following inline function for bit enums:
//
// inline <enum-type> operator|(<enum-type> a, <enum-type> b);
// inline <enum-type> operator&(<enum-type> a, <enum-type> b);
// inline <enum-type> bitEnumContains(<enum-type> a, <enum-type> b);
static void emitOperators(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = std::string(enumAttr.getUnderlyingType());
os << formatv("inline {0} operator|({0} lhs, {0} rhs) {{\n", enumName)
<< formatv(" return static_cast<{0}>("
"static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
enumName, underlyingType)
<< "}\n";
os << formatv("inline {0} operator&({0} lhs, {0} rhs) {{\n", enumName)
<< formatv(" return static_cast<{0}>("
"static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n",
enumName, underlyingType)
<< "}\n";
os << formatv(
"inline bool bitEnumContains({0} bits, {0} bit) {{\n"
" return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
enumName, underlyingType)
<< "}\n";
}
static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
auto enumerants = enumAttr.getAllCases();
os << formatv("{2} {1}({0} val) {{\n", enumName, symToStrFnName,
symToStrFnRetType);
os << " switch (val) {\n";
for (const auto &enumerant : enumerants) {
auto symbol = enumerant.getSymbol();
auto str = enumerant.getStr();
os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName,
makeIdentifier(symbol), str);
}
os << " }\n";
os << " return \"\";\n";
os << "}\n\n";
}
static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
StringRef separator = enumDef.getValueAsString("separator");
auto enumerants = enumAttr.getAllCases();
auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
os << formatv("{2} {1}({0} symbol) {{\n", enumName, symToStrFnName,
symToStrFnRetType);
os << formatv(" auto val = static_cast<{0}>(symbol);\n",
enumAttr.getUnderlyingType());
if (allBitsUnsetCase) {
os << " // Special case for all bits unset.\n";
os << formatv(" if (val == 0) return \"{0}\";\n\n",
allBitsUnsetCase->getSymbol());
}
os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n";
for (const auto &enumerant : enumerants) {
// Skip the special enumerant for None.
if (auto val = enumerant.getValue())
os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); "
"val &= ~{0}u; }\n",
val, enumerant.getStr());
}
// If we have unknown bit set, return an empty string to signal errors.
os << "\n if (val) return \"\";\n";
os << formatv(" return ::llvm::join(strs, \"{0}\");\n", separator);
os << "}\n\n";
}
static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
auto enumerants = enumAttr.getAllCases();
os << formatv("::llvm::Optional<{0}> {1}(::llvm::StringRef str) {{\n",
enumName, strToSymFnName);
os << formatv(" return ::llvm::StringSwitch<::llvm::Optional<{0}>>(str)\n",
enumName);
for (const auto &enumerant : enumerants) {
auto symbol = enumerant.getSymbol();
auto str = enumerant.getStr();
os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, str,
makeIdentifier(symbol));
}
os << " .Default(::llvm::None);\n";
os << "}\n";
}
static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = std::string(enumAttr.getUnderlyingType());
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef separator = enumDef.getValueAsString("separator");
auto enumerants = enumAttr.getAllCases();
auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
os << formatv("::llvm::Optional<{0}> {1}(::llvm::StringRef str) {{\n",
enumName, strToSymFnName);
if (allBitsUnsetCase) {
os << " // Special case for all bits unset.\n";
StringRef caseSymbol = allBitsUnsetCase->getSymbol();
os << formatv(" if (str == \"{1}\") return {0}::{2};\n\n", enumName,
caseSymbol, makeIdentifier(caseSymbol));
}
// Split the string to get symbols for all the bits.
os << " ::llvm::SmallVector<::llvm::StringRef, 2> symbols;\n";
os << formatv(" str.split(symbols, \"{0}\");\n\n", separator);
os << formatv(" {0} val = 0;\n", underlyingType);
os << " for (auto symbol : symbols) {\n";
// Convert each symbol to the bit ordinal and set the corresponding bit.
os << formatv(
" auto bit = llvm::StringSwitch<::llvm::Optional<{0}>>(symbol)\n",
underlyingType);
for (const auto &enumerant : enumerants) {
// Skip the special enumerant for None.
if (auto val = enumerant.getValue())
os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getStr(), val);
}
os.indent(6) << ".Default(::llvm::None);\n";
os << " if (bit) { val |= *bit; } else { return ::llvm::None; }\n";
os << " }\n";
os << formatv(" return static_cast<{0}>(val);\n", enumName);
os << "}\n\n";
}
static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = std::string(enumAttr.getUnderlyingType());
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
auto enumerants = enumAttr.getAllCases();
// Avoid generating the underlying value to symbol conversion function if
// there is an enumerant without explicit value.
if (llvm::any_of(enumerants, [](EnumAttrCase enumerant) {
return enumerant.getValue() < 0;
}))
return;
os << formatv("::llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
underlyingToSymFnName,
underlyingType.empty() ? std::string("unsigned")
: underlyingType)
<< " switch (value) {\n";
for (const auto &enumerant : enumerants) {
auto symbol = enumerant.getSymbol();
auto value = enumerant.getValue();
os << formatv(" case {0}: return {1}::{2};\n", value, enumName,
makeIdentifier(symbol));
}
os << " default: return ::llvm::None;\n"
<< " }\n"
<< "}\n\n";
}
static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass();
Attribute baseAttr(baseAttrDef);
// Emit classof method
os << formatv("bool {0}::classof(::mlir::Attribute attr) {{\n",
attrClassName);
mlir::tblgen::Pred baseAttrPred = baseAttr.getPredicate();
if (baseAttrPred.isNull())
PrintFatalError("ERROR: baseAttrClass for EnumAttr has no Predicate\n");
std::string condition = baseAttrPred.getCondition();
FmtContext verifyCtx;
verifyCtx.withSelf("attr");
os << tgfmt(" return $0;\n", /*ctx=*/nullptr, tgfmt(condition, &verifyCtx));
os << "}\n";
// Emit get method
os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
attrClassName, enumName);
if (enumAttr.isSubClassOf("StrEnumAttr")) {
os << formatv(" ::mlir::StringAttr baseAttr = "
"::mlir::StringAttr::get(context, {0}(val));\n",
symToStrFnName);
} else {
StringRef underlyingType = enumAttr.getUnderlyingType();
// Assuming that it is IntegerAttr constraint
int64_t bitwidth = 64;
if (baseAttrDef->getValue("valueType")) {
auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
if (valueTypeDef->getValue("bitwidth"))
bitwidth = valueTypeDef->getValueAsInt("bitwidth");
}
os << formatv(" ::mlir::IntegerType intType = "
"::mlir::IntegerType::get(context, {0});\n",
bitwidth);
os << formatv(" ::mlir::IntegerAttr baseAttr = "
"::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
underlyingType);
}
os << formatv(" return baseAttr.cast<{0}>();\n", attrClassName);
os << "}\n";
// Emit getValue method
os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName);
if (enumAttr.isSubClassOf("StrEnumAttr")) {
os << formatv(" const auto res = {0}(::mlir::StringAttr::getValue());\n",
strToSymFnName);
os << " return res.getValue();\n";
} else {
os << formatv(" return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
enumName);
}
os << "}\n";
}
static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = std::string(enumAttr.getUnderlyingType());
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
auto enumerants = enumAttr.getAllCases();
auto allBitsUnsetCase = getAllBitsUnsetCase(enumerants);
os << formatv("::llvm::Optional<{0}> {1}({2} value) {{\n", enumName,
underlyingToSymFnName, underlyingType);
if (allBitsUnsetCase) {
os << " // Special case for all bits unset.\n";
os << formatv(" if (value == 0) return {0}::{1};\n\n", enumName,
makeIdentifier(allBitsUnsetCase->getSymbol()));
}
llvm::SmallVector<std::string, 8> values;
for (const auto &enumerant : enumerants) {
if (auto val = enumerant.getValue())
values.push_back(std::string(formatv("{0}u", val)));
}
os << formatv(" if (value & ~({0})) return llvm::None;\n",
llvm::join(values, " | "));
os << formatv(" return static_cast<{0}>(value);\n", enumName);
os << "}\n";
}
static void emitEnumDecl(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
std::string underlyingType = std::string(enumAttr.getUnderlyingType());
StringRef description = enumAttr.getSummary();
StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
StringRef symToStrFnRetType = enumAttr.getSymbolToStringFnRetType();
StringRef underlyingToSymFnName = enumAttr.getUnderlyingToSymbolFnName();
auto enumerants = enumAttr.getAllCases();
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces)
os << "namespace " << ns << " {\n";
// Emit the enum class definition
emitEnumClass(enumDef, enumName, underlyingType, description, enumerants, os);
// Emit conversion function declarations
if (llvm::all_of(enumerants, [](EnumAttrCase enumerant) {
return enumerant.getValue() >= 0;
})) {
os << formatv(
"::llvm::Optional<{0}> {1}({2});\n", enumName, underlyingToSymFnName,
underlyingType.empty() ? std::string("unsigned") : underlyingType);
}
os << formatv("{2} {1}({0});\n", enumName, symToStrFnName, symToStrFnRetType);
os << formatv("::llvm::Optional<{0}> {1}(::llvm::StringRef);\n", enumName,
strToSymFnName);
if (enumAttr.isBitEnum()) {
emitOperators(enumDef, os);
} else {
emitMaxValueFn(enumDef, os);
}
// Generate a generic `stringifyEnum` function that forwards to the method
// specified by the user.
const char *const stringifyEnumStr = R"(
inline {0} stringifyEnum({1} enumValue) {{
return {2}(enumValue);
}
)";
os << formatv(stringifyEnumStr, symToStrFnRetType, enumName, symToStrFnName);
// Generate a generic `symbolizeEnum` function that forwards to the method
// specified by the user.
const char *const symbolizeEnumStr = R"(
template <typename EnumType>
::llvm::Optional<EnumType> symbolizeEnum(::llvm::StringRef);
template <>
inline ::llvm::Optional<{0}> symbolizeEnum<{0}>(::llvm::StringRef str) {
return {1}(str);
}
)";
os << formatv(symbolizeEnumStr, enumName, strToSymFnName);
const char *const attrClassDecl = R"(
class {1} : public ::mlir::{2} {
public:
using ValueType = {0};
using ::mlir::{2}::{2};
static bool classof(::mlir::Attribute attr);
static {1} get(::mlir::MLIRContext *context, {0} val);
{0} getValue() const;
};
)";
if (enumAttr.genSpecializedAttr()) {
StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
StringRef baseAttrClassName =
enumAttr.isSubClassOf("StrEnumAttr") ? "StringAttr" : "IntegerAttr";
os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
}
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
// Emit DenseMapInfo for this enum class
emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
}
static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Enum Utility Declarations", os);
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
for (const auto *def : defs)
emitEnumDecl(*def, os);
return false;
}
static void emitEnumDef(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef cppNamespace = enumAttr.getCppNamespace();
llvm::SmallVector<StringRef, 2> namespaces;
llvm::SplitString(cppNamespace, namespaces, "::");
for (auto ns : namespaces)
os << "namespace " << ns << " {\n";
if (enumAttr.isBitEnum()) {
emitSymToStrFnForBitEnum(enumDef, os);
emitStrToSymFnForBitEnum(enumDef, os);
emitUnderlyingToSymFnForBitEnum(enumDef, os);
} else {
emitSymToStrFnForIntEnum(enumDef, os);
emitStrToSymFnForIntEnum(enumDef, os);
emitUnderlyingToSymFnForIntEnum(enumDef, os);
}
if (enumAttr.genSpecializedAttr())
emitSpecializedAttrDef(enumDef, os);
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
os << "\n";
}
static bool emitEnumDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
llvm::emitSourceFileHeader("Enum Utility Definitions", os);
auto defs = recordKeeper.getAllDerivedDefinitions("EnumAttrInfo");
for (const auto *def : defs)
emitEnumDef(*def, os);
return false;
}
// Registers the enum utility generator to mlir-tblgen.
static mlir::GenRegistration
genEnumDecls("gen-enum-decls", "Generate enum utility declarations",
[](const RecordKeeper &records, raw_ostream &os) {
return emitEnumDecls(records, os);
});
// Registers the enum utility generator to mlir-tblgen.
static mlir::GenRegistration
genEnumDefs("gen-enum-defs", "Generate enum utility definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitEnumDefs(records, os);
});