blob: 2c7acec3b1b8538a1abf039c2be3c3623f34594d [file] [log] [blame]
//===- LLVMIRConversionGen.cpp - MLIR LLVM IR builder generator -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file uses tablegen definitions of the LLVM IR Dialect operations to
// generate the code building the LLVM IR from it.
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
static LogicalResult emitError(const Record &record, const Twine &message) {
PrintError(&record, message);
return failure();
}
namespace {
// Helper structure to return a position of the substring in a string.
struct StringLoc {
size_t pos;
size_t length;
// Take a substring identified by this location in the given string.
StringRef in(StringRef str) const { return str.substr(pos, length); }
// A location is invalid if its position is outside the string.
explicit operator bool() { return pos != std::string::npos; }
};
} // namespace
// Find the next TableGen variable in the given pattern. These variables start
// with a `$` character and can contain alphanumeric characters or underscores.
// Return the position of the variable in the pattern and its length, including
// the `$` character. The escape syntax `$$` is also detected and returned.
static StringLoc findNextVariable(StringRef str) {
size_t startPos = str.find('$');
if (startPos == std::string::npos)
return {startPos, 0};
// If we see "$$", return immediately.
if (startPos != str.size() - 1 && str[startPos + 1] == '$')
return {startPos, 2};
// Otherwise, the symbol spans until the first character that is not
// alphanumeric or '_'.
size_t endPos = str.find_if_not([](char c) { return isAlnum(c) || c == '_'; },
startPos + 1);
if (endPos == std::string::npos)
endPos = str.size();
return {startPos, endPos - startPos};
}
// Check if `name` is a variadic operand of `op`. Seach all operands since the
// MLIR and LLVM IR operand order may differ and only for the latter the
// variadic operand is guaranteed to be at the end of the operands list.
static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
for (int i = 0, e = op.getNumOperands(); i < e; ++i)
if (op.getOperand(i).name == name)
return op.getOperand(i).isVariadic();
return false;
}
// Check if `result` is a known name of a result of `op`.
static bool isResultName(const tblgen::Operator &op, StringRef name) {
for (int i = 0, e = op.getNumResults(); i < e; ++i)
if (op.getResultName(i) == name)
return true;
return false;
}
// Check if `name` is a known name of an attribute of `op`.
static bool isAttributeName(const tblgen::Operator &op, StringRef name) {
return llvm::any_of(
op.getAttributes(),
[name](const tblgen::NamedAttribute &attr) { return attr.name == name; });
}
// Check if `name` is a known name of an operand of `op`.
static bool isOperandName(const tblgen::Operator &op, StringRef name) {
for (int i = 0, e = op.getNumOperands(); i < e; ++i)
if (op.getOperand(i).name == name)
return true;
return false;
}
// Return the `op` argument index of the argument with the given `name`.
static FailureOr<int> getArgumentIndex(const tblgen::Operator &op,
StringRef name) {
for (int i = 0, e = op.getNumArgs(); i != e; ++i)
if (op.getArgName(i) == name)
return i;
return failure();
}
// Emit to `os` the operator-name driven check and the call to LLVM IRBuilder
// for one definition of an LLVM IR Dialect operation.
static LogicalResult emitOneBuilder(const Record &record, raw_ostream &os) {
auto op = tblgen::Operator(record);
if (!record.getValue("llvmBuilder"))
return emitError(record, "expected 'llvmBuilder' field");
// Return early if there is no builder specified.
StringRef builderStrRef = record.getValueAsString("llvmBuilder");
if (builderStrRef.empty())
return success();
// Progressively create the builder string by replacing $-variables with
// value lookups. Keep only the not-yet-traversed part of the builder pattern
// to avoid re-traversing the string multiple times.
std::string builder;
llvm::raw_string_ostream bs(builder);
while (StringLoc loc = findNextVariable(builderStrRef)) {
auto name = loc.in(builderStrRef).drop_front();
auto getterName = op.getGetterName(name);
// First, insert the non-matched part as is.
bs << builderStrRef.substr(0, loc.pos);
// Then, rewrite the name based on its kind.
bool isVariadicOperand = isVariadicOperandName(op, name);
if (isOperandName(op, name)) {
auto result =
isVariadicOperand
? formatv("moduleTranslation.lookupValues(op.{0}())", getterName)
: formatv("moduleTranslation.lookupValue(op.{0}())", getterName);
bs << result;
} else if (isAttributeName(op, name)) {
bs << formatv("op.{0}()", getterName);
} else if (isResultName(op, name)) {
bs << formatv("moduleTranslation.mapValue(op.{0}())", getterName);
} else if (name == "_resultType") {
bs << "moduleTranslation.convertType(op.getResult().getType())";
} else if (name == "_hasResult") {
bs << "opInst.getNumResults() == 1";
} else if (name == "_location") {
bs << "opInst.getLoc()";
} else if (name == "_numOperands") {
bs << "opInst.getNumOperands()";
} else if (name == "$") {
bs << '$';
} else {
return emitError(
record, "expected keyword, argument, or result, but got " + name);
}
// Finally, only keep the untraversed part of the string.
builderStrRef = builderStrRef.substr(loc.pos + loc.length);
}
// Output the check and the rewritten builder string.
os << "if (auto op = dyn_cast<" << op.getQualCppClassName()
<< ">(opInst)) {\n";
os << bs.str() << builderStrRef << "\n";
os << " return success();\n";
os << "}\n";
return success();
}
// Emit all builders. Returns false on success because of the generator
// registration requirements.
static bool emitBuilders(const RecordKeeper &recordKeeper, raw_ostream &os) {
for (const Record *def :
recordKeeper.getAllDerivedDefinitions("LLVM_OpBase")) {
if (failed(emitOneBuilder(*def, os)))
return true;
}
return false;
}
using ConditionFn = mlir::function_ref<llvm::Twine(const Record &record)>;
// Emit a conditional call to the MLIR builder of the LLVM dialect operation to
// build for the given LLVM IR instruction. A condition function `conditionFn`
// emits a check to verify the opcode or intrinsic identifier of the LLVM IR
// instruction matches the LLVM dialect operation to build.
static LogicalResult emitOneMLIRBuilder(const Record &record, raw_ostream &os,
ConditionFn conditionFn) {
auto op = tblgen::Operator(record);
if (!record.getValue("mlirBuilder"))
return emitError(record, "expected 'mlirBuilder' field");
// Return early if there is no builder specified.
StringRef builderStrRef = record.getValueAsString("mlirBuilder");
if (builderStrRef.empty())
return success();
// Access the argument index array that maps argument indices to LLVM IR
// operand indices. If the operation defines no custom mapping, set the array
// to the identity permutation.
std::vector<int64_t> llvmArgIndices =
record.getValueAsListOfInts("llvmArgIndices");
if (llvmArgIndices.empty())
append_range(llvmArgIndices, seq<int64_t>(0, op.getNumArgs()));
if (llvmArgIndices.size() != static_cast<size_t>(op.getNumArgs())) {
return emitError(
record,
"expected 'llvmArgIndices' size to match the number of arguments");
}
// Progressively create the builder string by replacing $-variables. Keep only
// the not-yet-traversed part of the builder pattern to avoid re-traversing
// the string multiple times. Additionally, emit an argument string
// immediately before the builder string. This argument string converts all
// operands used by the builder to MLIR values and returns failure if one of
// the conversions fails.
std::string arguments, builder;
llvm::raw_string_ostream as(arguments), bs(builder);
while (StringLoc loc = findNextVariable(builderStrRef)) {
auto name = loc.in(builderStrRef).drop_front();
// First, insert the non-matched part as is.
bs << builderStrRef.substr(0, loc.pos);
// Then, rewrite the name based on its kind.
FailureOr<int> argIndex = getArgumentIndex(op, name);
if (succeeded(argIndex)) {
// Access the LLVM IR operand that maps to the given argument index using
// the provided argument indices mapping.
int64_t idx = llvmArgIndices[*argIndex];
if (idx < 0) {
return emitError(
record, "expected non-negative operand index for argument " + name);
}
if (isAttributeName(op, name)) {
bs << formatv("llvmOperands[{0}]", idx);
} else {
if (isVariadicOperandName(op, name)) {
as << formatv(
"FailureOr<SmallVector<Value>> _llvmir_gen_operand_{0} = "
"moduleImport.convertValues(llvmOperands.drop_front({1}));\n",
name, idx);
} else {
as << formatv("FailureOr<Value> _llvmir_gen_operand_{0} = "
"moduleImport.convertValue(llvmOperands[{1}]);\n",
name, idx);
}
as << formatv("if (failed(_llvmir_gen_operand_{0}))\n"
" return failure();\n",
name);
bs << formatv("*_llvmir_gen_operand_{0}", name);
}
} else if (isResultName(op, name)) {
if (op.getNumResults() != 1)
return emitError(record, "expected op to have one result");
bs << "moduleImport.mapValue(inst)";
} else if (name == "_op") {
bs << "moduleImport.mapNoResultOp(inst)";
} else if (name == "_int_attr") {
bs << "moduleImport.matchIntegerAttr";
} else if (name == "_float_attr") {
bs << "moduleImport.matchFloatAttr";
} else if (name == "_var_attr") {
bs << "moduleImport.matchLocalVariableAttr";
} else if (name == "_label_attr") {
bs << "moduleImport.matchLabelAttr";
} else if (name == "_fpExceptionBehavior_attr") {
bs << "moduleImport.matchFPExceptionBehaviorAttr";
} else if (name == "_roundingMode_attr") {
bs << "moduleImport.matchRoundingModeAttr";
} else if (name == "_resultType") {
bs << "moduleImport.convertType(inst->getType())";
} else if (name == "_location") {
bs << "moduleImport.translateLoc(inst->getDebugLoc())";
} else if (name == "_builder") {
bs << "odsBuilder";
} else if (name == "_qualCppClassName") {
bs << op.getQualCppClassName();
} else if (name == "$") {
bs << '$';
} else {
return emitError(
record, "expected keyword, argument, or result, but got " + name);
}
// Finally, only keep the untraversed part of the string.
builderStrRef = builderStrRef.substr(loc.pos + loc.length);
}
// Output the check, the argument conversion, and the builder string.
os << "if (" << conditionFn(record) << ") {\n";
os << as.str() << "\n";
os << bs.str() << builderStrRef << "\n";
os << " return success();\n";
os << "}\n";
return success();
}
// Emit all intrinsic MLIR builders. Returns false on success because of the
// generator registration requirements.
static bool emitIntrMLIRBuilders(const RecordKeeper &recordKeeper,
raw_ostream &os) {
// Emit condition to check if "llvmEnumName" matches the intrinsic id.
auto emitIntrCond = [](const Record &record) {
return "intrinsicID == llvm::Intrinsic::" +
record.getValueAsString("llvmEnumName");
};
for (const Record *def :
recordKeeper.getAllDerivedDefinitions("LLVM_IntrOpBase")) {
if (failed(emitOneMLIRBuilder(*def, os, emitIntrCond)))
return true;
}
return false;
}
// Emit all op builders. Returns false on success because of the
// generator registration requirements.
static bool emitOpMLIRBuilders(const RecordKeeper &recordKeeper,
raw_ostream &os) {
// Emit condition to check if "llvmInstName" matches the instruction opcode.
auto emitOpcodeCond = [](const Record &record) {
return "inst->getOpcode() == llvm::Instruction::" +
record.getValueAsString("llvmInstName");
};
for (const Record *def :
recordKeeper.getAllDerivedDefinitions("LLVM_OpBase")) {
if (failed(emitOneMLIRBuilder(*def, os, emitOpcodeCond)))
return true;
}
return false;
}
namespace {
// Wrapper class around a Tablegen definition of an LLVM enum attribute case.
class LLVMEnumAttrCase : public tblgen::EnumAttrCase {
public:
using tblgen::EnumAttrCase::EnumAttrCase;
// Constructs a case from a non LLVM-specific enum attribute case.
explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase &other)
: tblgen::EnumAttrCase(&other.getDef()) {}
// Returns the C++ enumerant for the LLVM API.
StringRef getLLVMEnumerant() const {
return def->getValueAsString("llvmEnumerant");
}
};
// Wraper class around a Tablegen definition of an LLVM enum attribute.
class LLVMEnumAttr : public tblgen::EnumAttr {
public:
using tblgen::EnumAttr::EnumAttr;
// Returns the C++ enum name for the LLVM API.
StringRef getLLVMClassName() const {
return def->getValueAsString("llvmClassName");
}
// Returns all associated cases viewed as LLVM-specific enum cases.
std::vector<LLVMEnumAttrCase> getAllCases() const {
std::vector<LLVMEnumAttrCase> cases;
for (auto &c : tblgen::EnumAttr::getAllCases())
cases.emplace_back(c);
return cases;
}
std::vector<LLVMEnumAttrCase> getAllUnsupportedCases() const {
const auto *inits = def->getValueAsListInit("unsupported");
std::vector<LLVMEnumAttrCase> cases;
cases.reserve(inits->size());
for (const llvm::Init *init : *inits)
cases.emplace_back(cast<llvm::DefInit>(init));
return cases;
}
};
// Wraper class around a Tablegen definition of a C-style LLVM enum attribute.
class LLVMCEnumAttr : public tblgen::EnumAttr {
public:
using tblgen::EnumAttr::EnumAttr;
// Returns the C++ enum name for the LLVM API.
StringRef getLLVMClassName() const {
return def->getValueAsString("llvmClassName");
}
// Returns all associated cases viewed as LLVM-specific enum cases.
std::vector<LLVMEnumAttrCase> getAllCases() const {
std::vector<LLVMEnumAttrCase> cases;
for (auto &c : tblgen::EnumAttr::getAllCases())
cases.emplace_back(c);
return cases;
}
};
} // namespace
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
// (Enum) to the corresponding LLVM API enumerant
static void emitOneEnumToConversion(const llvm::Record *record,
raw_ostream &os) {
LLVMEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
// Emit the function converting the enum attribute to its LLVM counterpart.
os << formatv(
"static LLVM_ATTRIBUTE_UNUSED {0} convert{1}ToLLVM({2}::{1} value) {{\n",
llvmClass, cppClassName, cppNamespace);
os << " switch (value) {\n";
for (const auto &enumerant : enumAttr.getAllCases()) {
StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
StringRef cppEnumerant = enumerant.getSymbol();
os << formatv(" case {0}::{1}::{2}:\n", cppNamespace, cppClassName,
cppEnumerant);
os << formatv(" return {0}::{1};\n", llvmClass, llvmEnumerant);
}
os << " }\n";
os << formatv(" llvm_unreachable(\"unknown {0} type\");\n",
enumAttr.getEnumClassName());
os << "}\n\n";
}
// Emits conversion function "LLVMClass convertEnumToLLVM(Enum)" and containing
// switch-based logic to convert from the MLIR LLVM dialect enum attribute case
// (Enum) to the corresponding LLVM API C-style enumerant
static void emitOneCEnumToConversion(const llvm::Record *record,
raw_ostream &os) {
LLVMCEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
// Emit the function converting the enum attribute to its LLVM counterpart.
os << formatv("static LLVM_ATTRIBUTE_UNUSED int64_t "
"convert{0}ToLLVM({1}::{0} value) {{\n",
cppClassName, cppNamespace);
os << " switch (value) {\n";
for (const auto &enumerant : enumAttr.getAllCases()) {
StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
StringRef cppEnumerant = enumerant.getSymbol();
os << formatv(" case {0}::{1}::{2}:\n", cppNamespace, cppClassName,
cppEnumerant);
os << formatv(" return static_cast<int64_t>({0}::{1});\n", llvmClass,
llvmEnumerant);
}
os << " }\n";
os << formatv(" llvm_unreachable(\"unknown {0} type\");\n",
enumAttr.getEnumClassName());
os << "}\n\n";
}
// Emits conversion function "Enum convertEnumFromLLVM(LLVMClass)" and
// containing switch-based logic to convert from the LLVM API enumerant to MLIR
// LLVM dialect enum attribute (Enum).
static void emitOneEnumFromConversion(const llvm::Record *record,
raw_ostream &os) {
LLVMEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
// Emit the function converting the enum attribute from its LLVM counterpart.
os << formatv("inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM({2} "
"value) {{\n",
cppNamespace, cppClassName, llvmClass);
os << " switch (value) {\n";
for (const auto &enumerant : enumAttr.getAllCases()) {
StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
StringRef cppEnumerant = enumerant.getSymbol();
os << formatv(" case {0}::{1}:\n", llvmClass, llvmEnumerant);
os << formatv(" return {0}::{1}::{2};\n", cppNamespace, cppClassName,
cppEnumerant);
}
for (const auto &enumerant : enumAttr.getAllUnsupportedCases()) {
StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
os << formatv(" case {0}::{1}:\n", llvmClass, llvmEnumerant);
os << formatv(" llvm_unreachable(\"unsupported case {0}::{1}\");\n",
enumAttr.getLLVMClassName(), llvmEnumerant);
}
os << " }\n";
os << formatv(" llvm_unreachable(\"unknown {0} type\");",
enumAttr.getLLVMClassName());
os << "}\n\n";
}
// Emits conversion function "Enum convertEnumFromLLVM(LLVMEnum)" and
// containing switch-based logic to convert from the LLVM API C-style enumerant
// to MLIR LLVM dialect enum attribute (Enum).
static void emitOneCEnumFromConversion(const llvm::Record *record,
raw_ostream &os) {
LLVMCEnumAttr enumAttr(record);
StringRef llvmClass = enumAttr.getLLVMClassName();
StringRef cppClassName = enumAttr.getEnumClassName();
StringRef cppNamespace = enumAttr.getCppNamespace();
// Emit the function converting the enum attribute from its LLVM counterpart.
os << formatv(
"inline LLVM_ATTRIBUTE_UNUSED {0}::{1} convert{1}FromLLVM(int64_t "
"value) {{\n",
cppNamespace, cppClassName, llvmClass);
os << " switch (value) {\n";
for (const auto &enumerant : enumAttr.getAllCases()) {
StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
StringRef cppEnumerant = enumerant.getSymbol();
os << formatv(" case static_cast<int64_t>({0}::{1}):\n", llvmClass,
llvmEnumerant);
os << formatv(" return {0}::{1}::{2};\n", cppNamespace, cppClassName,
cppEnumerant);
}
os << " }\n";
os << formatv(" llvm_unreachable(\"unknown {0} type\");",
enumAttr.getLLVMClassName());
os << "}\n\n";
}
// Emits conversion functions between MLIR enum attribute case and corresponding
// LLVM API enumerants for all registered LLVM dialect enum attributes.
template <bool ConvertTo>
static bool emitEnumConversionDefs(const RecordKeeper &recordKeeper,
raw_ostream &os) {
for (const Record *def :
recordKeeper.getAllDerivedDefinitions("LLVM_EnumAttr"))
if (ConvertTo)
emitOneEnumToConversion(def, os);
else
emitOneEnumFromConversion(def, os);
for (const Record *def :
recordKeeper.getAllDerivedDefinitions("LLVM_CEnumAttr"))
if (ConvertTo)
emitOneCEnumToConversion(def, os);
else
emitOneCEnumFromConversion(def, os);
return false;
}
static void emitOneIntrinsic(const Record &record, raw_ostream &os) {
auto op = tblgen::Operator(record);
os << "llvm::Intrinsic::" << record.getValueAsString("llvmEnumName") << ",\n";
}
// Emit the list of LLVM IR intrinsics identifiers that are convertible to a
// matching MLIR LLVM dialect intrinsic operation.
static bool emitConvertibleIntrinsics(const RecordKeeper &recordKeeper,
raw_ostream &os) {
for (const Record *def :
recordKeeper.getAllDerivedDefinitions("LLVM_IntrOpBase"))
emitOneIntrinsic(*def, os);
return false;
}
static mlir::GenRegistration
genLLVMIRConversions("gen-llvmir-conversions",
"Generate LLVM IR conversions", emitBuilders);
static mlir::GenRegistration genOpFromLLVMIRConversions(
"gen-op-from-llvmir-conversions",
"Generate conversions of operations from LLVM IR", emitOpMLIRBuilders);
static mlir::GenRegistration genIntrFromLLVMIRConversions(
"gen-intr-from-llvmir-conversions",
"Generate conversions of intrinsics from LLVM IR", emitIntrMLIRBuilders);
static mlir::GenRegistration
genEnumToLLVMConversion("gen-enum-to-llvmir-conversions",
"Generate conversions of EnumAttrs to LLVM IR",
emitEnumConversionDefs</*ConvertTo=*/true>);
static mlir::GenRegistration
genEnumFromLLVMConversion("gen-enum-from-llvmir-conversions",
"Generate conversions of EnumAttrs from LLVM IR",
emitEnumConversionDefs</*ConvertTo=*/false>);
static mlir::GenRegistration genConvertibleLLVMIRIntrinsics(
"gen-convertible-llvmir-intrinsics",
"Generate list of convertible LLVM IR intrinsics",
emitConvertibleIntrinsics);