blob: 814008c25451143039cd4cd0aaf97a8b735f3ae0 [file] [log] [blame]
//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
#include "OpGenHelpers.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
/// File header and Python imports shared by every generated module.
/// emitAllOps streams this string verbatim (`os << fileHeader`) without
/// passing it through formatv, so it takes no placeholders.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.
from ._ods_common import _cext as _ods_cext
from ._ods_common import (
equally_sized_accessor as _ods_equally_sized_accessor,
get_default_loc_context as _ods_get_default_loc_context,
get_op_result_or_op_results as _get_op_result_or_op_results,
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
segmented_accessor as _ods_segmented_accessor,
_ods_ir =
import builtins
from typing import Sequence as _Sequence, Union as _Union
/// Template for the dialect class:
/// {0} is the dialect namespace (emitAllOps formats it with clDialectName).
///
/// dialectExtensionTemplate (below) is emitted by emitAllOps when
/// clDialectExtensionName is non-empty. It imports `_Dialect` from the module
/// generated for dialect {0} (again clDialectName) instead of defining it.
constexpr const char *dialectClassTemplate = R"Py(
class _Dialect(_ods_ir.Dialect):
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
/// Template for operation class:
/// {0} is the Python class name (emitOpBindings passes the op's C++ class
/// name);
/// {1} is the operation name.
/// Every generated op class derives from `_ods_ir.OpView`.
constexpr const char *opClassTemplate = R"Py(
class {0}(_ods_ir.OpView):
/// Template for class level declarations of operand and result
/// segment specs.
/// {0} is either "OPERAND" or "RESULT"
/// {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
/// 1 = single element (expect non sequence operand/result)
/// 0 = optional element (expect a value or None)
/// -1 = operand/result is a sequence corresponding to a variadic group.
/// Emitted by emitSegmentSpec, only for ops that carry the attribute-sized
/// segments trait for that kind (see attrSizedTraitForKind / emitOpBindings).
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
_ODS_{0}_SEGMENTS = {1}
/// Template for class level declarations of the _ODS_REGIONS spec:
/// {0} is the minimum number of regions
/// {1} is the Python bool literal for hasNoVariadicRegions
/// Emitted by emitRegionAttributes for every op; the OpView base class
/// defaults this spec to (0, True).
constexpr const char *opClassRegionSpecTemplate = R"Py(
_ODS_REGIONS = ({0}, {1})
/// Template for single-element accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the position in the element list.
/// Used when no variable-length group precedes the element, so its index in
/// the operation's operand/result list equals its declared position.
constexpr const char *opSingleTemplate = R"Py(
def {0}(self):
return self.operation.{1}s[{2}]
/// Template for single-element accessor after a variable-length group:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
/// The variable-length group's length is the actual element count minus the
/// {2} - 1 other groups, each holding exactly one element (this template is
/// only used when the op has at most one variable-length group).
constexpr const char *opSingleAfterVariableTemplate = R"Py(
def {0}(self):
_ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
/// Template for an optional element accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
/// The generated accessor returns None when the element is absent.
constexpr const char *opOneOptionalTemplate = R"Py(
def {0}(self):
return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
/// Template for the variadic group accessor in the single variadic group case:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of element groups;
/// {3} is the position of the current group in the group list.
/// The group's length is the actual element count minus the {2} - 1
/// single-element groups; the accessor returns a slice of that length
/// starting at position {3}.
constexpr const char *opOneVariadicTemplate = R"Py(
def {0}(self):
_ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
/// First part of the template for equally-sized variadic group accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the total number of variadic groups;
/// {3} is the number of non-variadic groups preceding the current group;
/// {4} is the number of variadic groups preceding the current group.
/// The generated code binds `start` and `pg` for the second part of the
/// template (opVariadicEqualSimpleTemplate / opVariadicEqualVariadicTemplate).
/// The element list must be reached through `self.operation`: the accessor is
/// an instance method, and a bare `operation` name is not defined in its scope.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
def {0}(self):
start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
/// Second part of the template for equally-sized case, accessing a single
/// element:
/// {0} is either 'operand' or 'result'.
/// Appended after opVariadicEqualPrefixTemplate, which binds `start`.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
return self.operation.{0}s[start]
/// Second part of the template for equally-sized case, accessing a variadic
/// group:
/// {0} is either 'operand' or 'result'.
/// Appended after opVariadicEqualPrefixTemplate, which binds `start` and `pg`
/// (the slice below uses `pg` as the length of the group).
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
return self.operation.{0}s[start:start + pg]
/// Template for an attribute-sized group accessor:
/// {0} is the name of the accessor;
/// {1} is either 'operand' or 'result';
/// {2} is the position of the group in the group list;
/// {3} is a return suffix (expected [0] for single-element, empty for
/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
/// Group sizes are read at access time from the `{1}SegmentSizes` attribute
/// of the operation.
constexpr const char *opVariadicSegmentTemplate = R"Py(
def {0}(self):
{1}_range = _ods_segmented_accessor(
self.operation.attributes["{1}SegmentSizes"], {2})
return {1}_range{3}
/// Return-value suffix appended by opVariadicSegmentTemplate when the group
/// being accessed is an optional element of an attribute-sized op. The
/// resulting expression yields the group's first element when the segment is
/// non-empty and None otherwise.
/// {0} is either 'operand' or 'result'.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    "[0] if len({0}_range) > 0 else None";
/// Template for an operation attribute getter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
/// Used by emitAttributeAccessors for required (non-optional, non-unit)
/// attributes.
constexpr const char *attributeGetterTemplate = R"Py(
def {0}(self):
return self.operation.attributes["{1}"]
/// Template for an optional operation attribute getter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
/// The generated getter returns None when the attribute is not set.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
def {0}(self):
if "{1}" not in self.operation.attributes:
return None
return self.operation.attributes["{1}"]
/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
/// {0} is the name of the attribute sanitized for Python,
/// {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
def {0}(self):
return "{1}" in self.operation.attributes
/// Template for an operation attribute setter:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
/// Mandatory attributes cannot be removed, so the generated setter raises
/// ValueError when assigned None.
constexpr const char *attributeSetterTemplate = R"Py(
def {0}(self, value):
if value is None:
raise ValueError("'None' not allowed as value for mandatory attributes")
self.operation.attributes["{1}"] = value
/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
/// Assigning None to an attribute that is already absent is a no-op.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
def {0}(self, value):
if value is not None:
self.operation.attributes["{1}"] = value
elif "{1}" in self.operation.attributes:
del self.operation.attributes["{1}"]
/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
/// Any truthy value sets a fresh `UnitAttr`; the value itself is not stored.
constexpr const char *unitAttributeSetterTemplate = R"Py(
def {0}(self, value):
if bool(value):
self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
elif "{1}" in self.operation.attributes:
del self.operation.attributes["{1}"]
/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
/// {0} is the name of the attribute sanitized for Python;
/// {1} is the original name of the attribute.
/// Unlike the setters, the generated deleter does not check for presence first.
constexpr const char *attributeDeleterTemplate = R"Py(
def {0}(self):
del self.operation.attributes["{1}"]
/// Template for a named region accessor:
/// {0} is the name of the region sanitized for Python;
/// {1} is the region index, suffixed with ':' for the trailing variadic region
/// so that the accessor returns a slice of all remaining regions
/// (see emitRegionAccessors).
constexpr const char *regionAccessorTemplate = R"Py(
def {0}(self):
return self.regions[{1}]
/// Template for a module-level function that builds an op and returns its
/// result(s) (see emitValueBuilder):
/// {0} is the function name (op name without the dialect prefix, sanitized);
/// {1} is the op class name;
/// {2} is the function parameter list (with defaults);
/// {3} is the keyword-argument list forwarded to the op class constructor;
/// {4} is the Python return type annotation.
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
return _get_op_result_or_op_results({1}({3}))
/// Command-line options of this generator, grouped under clOpPythonBindingCat.
/// The dialect-name option (clDialectName) selects which ops are emitted;
/// emitAllOps aborts when it is empty. When `dialect-extension` is non-empty,
/// emitAllOps emits dialectExtensionTemplate, which imports `_Dialect` from the
/// dialect's generated module.
/// NOTE(review): no use of AttributeClasses is visible in this file — confirm
/// it is still needed.
static llvm::cl::OptionCategory
clOpPythonBindingCat("Options for -gen-python-op-bindings");
static llvm::cl::opt<std::string>
llvm::cl::desc("The dialect to run the generator for"),
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
static llvm::cl::opt<std::string> clDialectExtensionName(
"dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
using AttributeClasses = DenseMap<StringRef, StringRef>;
/// Checks whether `str` would shadow a generated variable or attribute
/// part of the OpView API.
/// Names starting with `_ods_` (the prefix used for generated helpers and
/// imports) or ending with `_ods` are always treated as reserved, as are the
/// fixed OpView member names listed in `reserved`. The set-membership part of
/// the return expression is not visible in this view; presumably it checks
/// `reserved`.
static bool isODSReserved(StringRef str) {
static llvm::StringSet<> reserved(
{"attributes", "create", "context", "ip", "operands", "print", "get_asm",
"loc", "verify", "regions", "results", "self", "operation",
return str.starts_with("_ods_") || str.ends_with("_ods") ||
/// Modifies the `name` in a way that it becomes suitable for Python bindings
/// (does not change the `name` if it already is suitable) and returns the
/// modified version:
///  - every non-alphanumeric character is replaced with '_';
///  - a name starting with a digit gets a '_' prefix;
///  - a name rejected by isPythonReserved or isODSReserved gets a '_' suffix.
/// NOTE(review): `*processedStr.begin()` dereferences end() when `name` is
/// empty. The callers visible here pass generated or checked names; confirm
/// this, or guard with `!processedStr.empty()`.
static std::string sanitizeName(StringRef name) {
std::string processedStr = name.str();
processedStr.begin(), processedStr.end(),
[](char c) { return !llvm::isAlnum(c); }, '_');
if (llvm::isDigit(*processedStr.begin()))
return "_" + processedStr;
if (isPythonReserved(processedStr) || isODSReserved(processedStr))
return processedStr + "_";
return processedStr;
/// Returns the fully qualified name of the attribute-sized segments trait for
/// `kind` ("operand" or "result"), built from the pattern below. The formatv
/// arguments are not visible in this view; presumably they are the capitalized
/// first letter and the remainder of `kind`. TODO confirm.
static std::string attrSizedTraitForKind(const char *kind) {
return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
/// Emits accessors to "elements" of an Op definition. Currently, the supported
/// elements are operands and results, indicated by `kind`, which must be either
/// `operand` or `result` and is used verbatim in the emitted code.
///
/// Three layouts are supported:
///  - at most one variable-length group: element positions are computed from
///    the total element count (opSingle* / opOne* templates);
///  - several variable-length groups with the same-size trait
///    (`sameSizeTrait`): groups are assumed to be equally sized
///    (opVariadicEqual* templates);
///  - the attribute-sized segments trait (`attrSizedTrait`): group sizes are
///    read from the `<kind>SegmentSizes` attribute (opVariadicSegment*
///    templates).
/// Any other layout ends in PrintFatalError("unsupported <kind> structure").
///
/// \param getNumVariableLength returns the number of variable-length groups.
/// \param getNumElements returns the number of declared elements.
/// \param getElement returns the i-th declared element.
static void emitElementAccessors(
const Operator &op, raw_ostream &os, const char *kind,
llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
llvm::function_ref<int(const Operator &)> getNumElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
getElement) {
llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
"unsupported kind");
// Traits indicating how to process variadic elements.
std::string sameSizeTrait =
std::string attrSizedTrait = attrSizedTraitForKind(kind);
unsigned numVariableLength = getNumVariableLength(op);
// If there is only one variable-length element group, its size can be
// inferred from the total number of elements. If there are none, the
// generation is straightforward.
if (numVariableLength <= 1) {
bool seenVariableLength = false;
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (element.isVariableLength())
seenVariableLength = true;
if (
if (element.isVariableLength()) {
os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
: opOneVariadicTemplate,
sanitizeName(, kind,
getNumElements(op), i);
} else if (seenVariableLength) {
os << llvm::formatv(opSingleAfterVariableTemplate,
sanitizeName(, kind,
getNumElements(op), i);
} else {
os << llvm::formatv(opSingleTemplate, sanitizeName(, kind,
// Handle the operations where variadic groups have the same size.
if (op.getTrait(sameSizeTrait)) {
int numPrecedingSimple = 0;
int numPrecedingVariadic = 0;
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (! {
os << llvm::formatv(opVariadicEqualPrefixTemplate,
sanitizeName(, kind, numVariableLength,
numPrecedingSimple, numPrecedingVariadic);
os << llvm::formatv(element.isVariableLength()
? opVariadicEqualVariadicTemplate
: opVariadicEqualSimpleTemplate,
if (element.isVariableLength())
// Handle the operations where the size of groups (variadic or not) is
// provided as an attribute. For non-variadic elements, make sure to return
// an element rather than a singleton container.
if (op.getTrait(attrSizedTrait)) {
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (
std::string trailing;
if (!element.isVariableLength())
trailing = "[0]";
else if (element.isOptional())
trailing = std::string(
llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(,
kind, i, trailing);
llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
/// Free function helpers accessing Operator components.
/// Returns how many operands `op` declares (callback for emitElementAccessors).
static int getNumOperands(const Operator &op) {
  const int operandCount = op.getNumOperands();
  return operandCount;
}
/// Returns the i-th declared operand of `op`.
static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
return op.getOperand(i);
/// Returns how many results `op` declares (callback for emitElementAccessors).
static int getNumResults(const Operator &op) {
  const int resultCount = op.getNumResults();
  return resultCount;
}
/// Returns the i-th declared result of `op`.
static const NamedTypeConstraint &getResult(const Operator &op, int i) {
return op.getResult(i);
/// Emits accessors to Op operands.
/// Thin wrapper over emitElementAccessors with kind "operand".
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
auto getNumVariableLengthOperands = [](const Operator &oper) {
return oper.getNumVariableLengthOperands();
emitElementAccessors(op, os, "operand", getNumVariableLengthOperands,
getNumOperands, getOperand);
/// Emits accessors to Op results.
/// Thin wrapper over emitElementAccessors with kind "result".
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
auto getNumVariableLengthResults = [](const Operator &oper) {
return oper.getNumVariableLengthResults();
emitElementAccessors(op, os, "result", getNumVariableLengthResults,
getNumResults, getResult);
/// Emits accessors to Op attributes.
/// Derived attributes are skipped. Unit attributes (storage type
/// `::mlir::UnitAttr`) get a presence-based getter, a bool setter and a
/// deleter. Optional attributes get a None-aware getter/setter and a deleter.
/// Required attributes get a getter and a setter that rejects None, and no
/// deleter.
static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
for (const auto &namedAttr : op.getAttributes()) {
// Skip "derived" attributes because they are just C++ functions that we
// don't currently expose.
if (namedAttr.attr.isDerivedAttr())
if (
std::string sanitizedName = sanitizeName(;
// Unit attributes are handled specially.
if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName,;
os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName,;
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,;
if (namedAttr.attr.isOptional()) {
os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName,;
os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName,;
os << llvm::formatv(attributeDeleterTemplate, sanitizedName,;
} else {
os << llvm::formatv(attributeGetterTemplate, sanitizedName,;
os << llvm::formatv(attributeSetterTemplate, sanitizedName,;
// Non-optional attributes cannot be deleted.
/// Template for the default auto-generated builder.
/// {0} is a comma-separated list of builder arguments, including the trailing
/// `loc` and `ip`;
/// {1} is the code populating the `operands`, `results`, `attributes` and
/// `successors` fields;
/// {2} is the comma-separated `initArgs` list assembled in
/// emitDefaultOpBuilder.
/// `{{}` is formatv's escape for a literal `{}`, i.e. an empty Python dict.
constexpr const char *initTemplate = R"Py(
def __init__(self, {0}):
operands = []
results = []
attributes = {{}
regions = None
/// Template for appending a single element to the operand/result list.
/// {0} is the name of the builder argument holding the element.
constexpr const char *singleOperandAppendTemplate =
constexpr const char *singleResultAppendTemplate = "results.append({0})";
/// Template for appending an optional element to the operand/result list.
/// {0} is the name of the builder argument holding the element.
/// The plain variants skip the element when the argument is None. The
/// attribute-sized operand variant appends an entry even for an absent operand
/// (see its `else` branch), presumably so that each segment keeps its slot.
/// TODO confirm against the elided remainder of that expression.
constexpr const char *optionalAppendOperandTemplate =
"if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
"operands.append(_get_op_result_or_value({0}) if {0} is not None else "
constexpr const char *optionalAppendResultTemplate =
"if {0} is not None: results.append({0})";
/// Template for appending a list of elements to the operand/result list.
/// {0} is the name of the builder argument holding the elements.
/// With attribute-sized operand segments, populateBuilderLinesOperand uses the
/// "Pack" variant so the whole list is added as one entry. Otherwise the list
/// extends `operands` element by element.
constexpr const char *multiOperandAppendTemplate =
constexpr const char *multiOperandAppendPackTemplate =
constexpr const char *multiResultAppendTemplate = "results.extend({0})";
/// Template for attribute builder from raw input in the operation builder.
/// {0} is the builder argument name;
/// {1} is the attribute name (the key stored in the `attributes` dict);
/// {2} is the attribute definition name, used as the AttrBuilder lookup key.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
/// Relies on `_ods_context`, which populateBuilderLinesAttr binds first.
constexpr const char *initAttributeWithBuilderTemplate =
R"Py(attributes["{1}"] = ({0} if (
isinstance({0}, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('{2}')) else
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
/// Template for attribute builder from raw input for optional attribute in the
/// operation builder. Also used for attributes with a default value. Nothing
/// is stored when the argument is None.
/// {0} is the builder argument name;
/// {1} is the attribute name (the key stored in the `attributes` dict);
/// {2} is the attribute definition name, used as the AttrBuilder lookup key.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
isinstance({0}, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('{2}')) else
_ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";
/// Template that sets a unit attribute in the builder when the corresponding
/// builder argument is truthy:
/// {0} is the attribute name;
/// {1} is the builder argument name.
constexpr const char *initUnitAttributeTemplate =
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
/// Builder line that initializes `_ods_successors`. populateBuilderLinesSuccessors
/// instantiates it with "None" when the op has no successors and with "[]"
/// otherwise.
/// {0} is the Python expression assigned to `_ods_successors`.
constexpr const char *initSuccessorsTemplate = "_ods_successors = {0}";
/// Builder line that adds successor argument(s) to `_ods_successors`.
/// {0} is the list method: "extend" for a variadic successor, "append" otherwise;
/// {1} is the builder argument holding the successor(s).
constexpr const char *addSuccessorTemplate = "_ods_successors.{0}({1})";
/// Returns true if the SameOperandsAndResultType trait can be used to infer
/// result types of the given operation, i.e. the op has the trait and no
/// variable-length results.
static bool hasSameArgumentAndResultTypes(const Operator &op) {
return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") &&
op.getNumVariableLengthResults() == 0;
/// Returns true if the FirstAttrDerivedResultType trait can be used to infer
/// result types of the given operation, i.e. the op has the trait and no
/// variable-length results.
static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") &&
op.getNumVariableLengthResults() == 0;
/// Returns true if the InferTypeOpInterface can be used to infer result types
/// of the given operation. This is only accepted for ops without regions.
static bool hasInferTypeInterface(const Operator &op) {
return op.getTrait("::mlir::InferTypeOpInterface::Trait") &&
op.getNumRegions() == 0;
/// Returns true if there is a trait or interface that can be used to infer
/// result types of the given operation (any of the three predicates above).
static bool canInferType(const Operator &op) {
return hasSameArgumentAndResultTypes(op) ||
hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op);
/// Populates `builderArgs` with result names if the builder is expected to
/// accept them as arguments, which is not the case when canInferType(op) holds.
/// An unnamed result is called `result` when the op has exactly one result
/// (matching the built-in result accessor) and `_gen_res_<i>` otherwise. All
/// names are passed through sanitizeName.
static void
populateBuilderArgsResults(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs) {
if (canInferType(op))
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
std::string name = op.getResultName(i).str();
if (name.empty()) {
if (op.getNumResults() == 1) {
// Special case for one result, make the default name be 'result'
// to properly match the built-in result accessor.
name = "result";
} else {
name = llvm::formatv("_gen_res_{0}", i);
name = sanitizeName(name);
/// Populates `builderArgs` with the Python-compatible names of builder function
/// arguments using intermixed attributes and operands in the same order as they
/// appear in the `arguments` field of the op definition. Additionally,
/// `operandNames` is populated with names of operands in their order of
/// appearance.
/// Unnamed arguments are called `_gen_arg_<i>`. All names are passed through
/// sanitizeName.
static void
populateBuilderArgs(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs,
llvm::SmallVectorImpl<std::string> &operandNames) {
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
std::string name = op.getArgName(i).str();
if (name.empty())
name = llvm::formatv("_gen_arg_{0}", i);
name = sanitizeName(name);
if (!op.getArg(i).is<NamedAttribute *>())
/// Populates `builderArgs` with the Python-compatible names of builder function
/// successor arguments. Additionally, `successorArgNames` is also populated.
/// Unnamed successors are called `_gen_successor_<i>`. All names are passed
/// through sanitizeName.
static void populateBuilderArgsSuccessors(
const Operator &op, llvm::SmallVectorImpl<std::string> &builderArgs,
llvm::SmallVectorImpl<std::string> &successorArgNames) {
for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
NamedSuccessor successor = op.getSuccessor(i);
std::string name = std::string(;
if (name.empty())
name = llvm::formatv("_gen_successor_{0}", i);
name = sanitizeName(name);
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up operation attributes. `argNames` is expected to contain
/// the names of builder arguments that correspond to op arguments, i.e. to the
/// operands and attributes in the same order as they appear in the `arguments`
/// field.
/// The first line binds `_ods_context` from `loc`; the attribute builder
/// templates reference it. Unit attributes use initUnitAttributeTemplate.
/// Optional and default-valued attributes use the None-guarded template, and
/// required attributes are always set.
static void
populateBuilderLinesAttr(const Operator &op,
llvm::ArrayRef<std::string> argNames,
llvm::SmallVectorImpl<std::string> &builderLines) {
builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
Argument arg = op.getArg(i);
auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg);
if (!attribute)
// Unit attributes are handled specially.
if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) {
attribute->name, argNames[i]));
attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
? initOptionalAttributeWithBuilderTemplate
: initAttributeWithBuilderTemplate,
argNames[i], attribute->name, attribute->attr.getAttrDefName()));
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up successors. successorArgNames is expected to correspond
/// to the Python argument name for each successor on the op.
/// `_ods_successors` is initialized to None when there are no successors.
/// Otherwise it starts as an empty list, and each successor argument is
/// appended (or, for a variadic successor, used to extend it).
static void populateBuilderLinesSuccessors(
const Operator &op, llvm::ArrayRef<std::string> successorArgNames,
llvm::SmallVectorImpl<std::string> &builderLines) {
if (successorArgNames.empty()) {
builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None"));
builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]"));
for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
auto &argName = successorArgNames[i];
const NamedSuccessor &successor = op.getSuccessor(i);
successor.isVariadic() ? "extend" : "append", argName));
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op operands.
/// `names` holds the builder argument name of each operand, in declaration
/// order. The append template depends on whether the operand is single,
/// optional or variadic. With attribute-sized segments, optional and variadic
/// operands use variants that keep one entry per declared operand.
static void
populateBuilderLinesOperand(const Operator &op,
llvm::ArrayRef<std::string> names,
llvm::SmallVectorImpl<std::string> &builderLines) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;
// Emit one line per operand; `names` already holds its builder argument name.
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
const NamedTypeConstraint &element = op.getOperand(i);
std::string name = names[i];
// Choose the formatting string based on the element kind.
llvm::StringRef formatString;
if (!element.isVariableLength()) {
formatString = singleOperandAppendTemplate;
} else if (element.isOptional()) {
if (sizedSegments) {
formatString = optionalAppendAttrSizedOperandsTemplate;
} else {
formatString = optionalAppendOperandTemplate;
} else {
assert(element.isVariadic() && "unhandled element group type");
// If emitting with sizedSegments, then we add the actual list-typed
// element. Otherwise, we extend the actual operands.
if (sizedSegments) {
formatString = multiOperandAppendPackTemplate;
} else {
formatString = multiOperandAppendTemplate;
builderLines.push_back(llvm::formatv(, name));
/// Python code template for deriving the operation result types from its
/// attribute:
/// - {0} is the name of the attribute from which to derive the types.
/// Intended for ops where hasFirstAttrDerivedResultTypes holds; the source
/// attribute is read back from the `attributes` dict built earlier in the
/// builder.
constexpr const char *deriveTypeFromAttrTemplate =
R"Py(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
/// Builder line that appends the Python type expression {0} to `results`
/// {1} times. populateBuilderLinesResult passes the op's result count as {1},
/// i.e. one copy per result.
constexpr const char *appendSameResultsTemplate =
    R"Py(results.extend([{0}] * {1}))Py";
/// Appends the given multiline string as individual strings into
/// `builderLines`.
/// Splitting stops once nothing follows the last '\n', so a trailing newline
/// does not produce an extra empty line (assuming each iteration appends
/// `split.first`, on a line not visible in this view).
static void appendLineByLine(StringRef string,
llvm::SmallVectorImpl<std::string> &builderLines) {
std::pair<StringRef, StringRef> split = std::make_pair(string, string);
do {
split = split.second.split('\n');
} while (!split.second.empty());
/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op results.
/// `names` holds the builder argument name of each result (see
/// populateBuilderArgsResults). With SameOperandsAndResultType, `results` is
/// filled with `operands[0].type` once per result. With
/// FirstAttrDerivedResultType, the types come from the op's first attribute.
/// Otherwise one append/extend line is emitted per result.
static void
populateBuilderLinesResult(const Operator &op,
llvm::ArrayRef<std::string> names,
llvm::SmallVectorImpl<std::string> &builderLines) {
bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
if (hasSameArgumentAndResultTypes(op)) {
appendSameResultsTemplate, "operands[0].type", op.getNumResults()));
if (hasFirstAttrDerivedResultTypes(op)) {
const NamedAttribute &firstAttr = op.getAttribute(0);
assert(! && "unexpected empty name for the attribute "
"from which the type is derived");
if (hasInferTypeInterface(op))
// Emit one line per result; `names` already holds its builder argument name.
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
const NamedTypeConstraint &element = op.getResult(i);
std::string name = names[i];
// Choose the formatting string based on the element kind.
llvm::StringRef formatString;
if (!element.isVariableLength()) {
formatString = singleResultAppendTemplate;
} else if (element.isOptional()) {
formatString = optionalAppendResultTemplate;
} else {
assert(element.isVariadic() && "unhandled element group type");
// If emitting with sizedSegments, then we add the actual list-typed
// element. Otherwise, we extend the actual results.
if (sizedSegments) {
formatString = singleResultAppendTemplate;
} else {
formatString = multiResultAppendTemplate;
builderLines.push_back(llvm::formatv(, name));
/// If the operation has variadic regions, adds a builder argument to specify
/// the number of those regions and builder lines to forward it to the generic
/// constructor.
/// The argument name starts with `num_`. The emitted line sets `regions` to
/// the number of non-variadic regions plus that argument. Only the last region
/// may be variadic, which the assert below checks.
/// NOTE(review): the assert message misspells "variadic" as "varidic".
static void
populateBuilderRegions(const Operator &op,
llvm::SmallVectorImpl<std::string> &builderArgs,
llvm::SmallVectorImpl<std::string> &builderLines) {
if (op.hasNoVariadicRegions())
// This is currently enforced when Operator is constructed.
assert(op.getNumVariadicRegions() == 1 &&
op.getRegion(op.getNumRegions() - 1).isVariadic() &&
"expected the last region to be varidic");
const NamedRegion &region = op.getRegion(op.getNumRegions() - 1);
std::string name =
("num_" + +
llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name));
/// Emits the default `__init__` builder of an op class (initTemplate).
/// Builder arguments appear in this order: result types (only when they cannot
/// be inferred), operands and attributes in `arguments` order, successors, and
/// the variadic-region count if any. Optional operands, optional attributes and
/// default-valued attributes become keyword arguments after a bare `*`,
/// defaulting to None. Everything else is positional.
/// Returns the final parameter strings (including the `*` marker and `=None`
/// defaults) so that emitValueBuilder can reuse them instead of rebuilding
/// them.
static llvm::SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
raw_ostream &os) {
llvm::SmallVector<std::string> builderArgs;
llvm::SmallVector<std::string> builderLines;
llvm::SmallVector<std::string> operandArgNames;
llvm::SmallVector<std::string> successorArgNames;
builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
op.getNumNativeAttributes() + op.getNumSuccessors());
populateBuilderArgsResults(op, builderArgs);
size_t numResultArgs = builderArgs.size();
populateBuilderArgs(op, builderArgs, operandArgNames);
size_t numOperandAttrArgs = builderArgs.size() - numResultArgs;
populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);
populateBuilderLinesOperand(op, operandArgNames, builderLines);
op, llvm::ArrayRef(builderArgs).drop_front(numResultArgs), builderLines);
op, llvm::ArrayRef(builderArgs).take_front(numResultArgs), builderLines);
populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
populateBuilderRegions(op, builderArgs, builderLines);
// Layout of builderArgs vector elements:
// [ result_args operand_attr_args successor_args regions ]
// Determine whether the argument corresponding to a given index into the
// builderArgs vector is a python keyword argument or not.
auto isKeywordArgFn = [&](size_t builderArgIndex) -> bool {
// All result, successor, and region arguments are positional arguments.
if ((builderArgIndex < numResultArgs) ||
(builderArgIndex >= (numResultArgs + numOperandAttrArgs)))
return false;
// Keyword arguments:
// - optional named attributes (including unit attributes)
// - default-valued named attributes
// - optional operands
Argument a = op.getArg(builderArgIndex - numResultArgs);
if (auto *nattr = llvm::dyn_cast_if_present<NamedAttribute *>(a))
return (nattr->attr.isOptional() || nattr->attr.hasDefaultValue());
if (auto *ntype = llvm::dyn_cast_if_present<NamedTypeConstraint *>(a))
return ntype->isOptional();
return false;
// StringRefs in functionArgs refer to strings allocated by builderArgs.
llvm::SmallVector<llvm::StringRef> functionArgs;
// Add positional arguments.
for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
if (!isKeywordArgFn(i))
// Add a bare '*' to indicate that all following arguments must be keyword
// arguments.
// Add a default 'None' value to each keyword arg string, and then add to the
// function args list.
for (size_t i = 0, cnt = builderArgs.size(); i < cnt; ++i) {
if (isKeywordArgFn(i)) {
SmallVector<std::string> initArgs;
if (!hasInferTypeInterface(op))
os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
llvm::join(builderLines, "\n "),
llvm::join(initArgs, ", "));
return llvm::to_vector<8>(
llvm::map_range(functionArgs, [](llvm::StringRef s) { return s.str(); }));
/// Emits the `_ODS_<kind>_SEGMENTS` class attribute for an op with
/// attribute-sized segments. The spec has one entry per declared element,
/// classified as optional, variadic or single, following the encoding
/// documented on opClassSizedSegmentsTemplate.
/// `kind` is "OPERAND" or "RESULT" (see emitOpBindings).
static void emitSegmentSpec(
const Operator &op, const char *kind,
llvm::function_ref<int(const Operator &)> getNumElements,
llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
raw_ostream &os) {
std::string segmentSpec("[");
for (int i = 0, e = getNumElements(op); i < e; ++i) {
const NamedTypeConstraint &element = getElement(op, i);
if (element.isOptional()) {
} else if (element.isVariadic()) {
} else {
os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
/// Emits the `_ODS_REGIONS` class attribute: the number of non-variadic
/// regions and whether the op has no variadic regions.
static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
// Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions).
// Note that the base OpView class defines this as (0, True).
unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions();
os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount,
op.hasNoVariadicRegions() ? "True" : "False");
/// Emits named accessors to regions.
/// A non-variadic region accessor indexes `self.regions` directly. The variadic
/// region, which is asserted to be the last one, returns a slice from its
/// index onwards.
static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
for (const auto &en : llvm::enumerate(op.getRegions())) {
const NamedRegion &region = en.value();
if (
assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) &&
"expected only the last region to be variadic");
os << llvm::formatv(regionAccessorTemplate, sanitizeName(,
std::to_string(en.index()) +
(region.isVariadic() ? ":" : ""));
/// Emits builder that extracts results from op
/// The builder is a module-level function named after the op with the dialect
/// prefix stripped (sanitized). It constructs the op class and returns
/// `_get_op_result_or_op_results(...)` of it. Parameter names are converted
/// from camelCase to snake_case while keeping any `=default`. The bare `*`
/// marker stays in the signature but is not forwarded. The return annotation
/// is `_Sequence[_ods_ir.Value]` for several results, `_ods_ir.Value` for one,
/// and `_ods_ir.Operation` for none.
/// \param functionArgs the parameter strings returned by emitDefaultOpBuilder.
/// NOTE(review): `functionArgs` is taken by value; a const reference would
/// avoid a copy.
static void emitValueBuilder(const Operator &op,
llvm::SmallVector<std::string> functionArgs,
raw_ostream &os) {
// Params with (possibly) default args.
auto valueBuilderParams =
llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
llvm::SmallVector<llvm::StringRef> argMaybeDefault =
llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
if (argMaybeDefault.size() == 2)
return arg + "=" + argMaybeDefault[1].str();
return arg;
// Actual args passed to op builder (e.g., opParam=op_param).
auto opBuilderArgs = llvm::map_range(
[](const std::string &s) { return s != "*"; }),
[](const std::string &arg) {
auto lhs = *llvm::split(arg, "=").begin();
return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
std::string nameWithoutDialect =
op.getOperationName().substr(op.getOperationName().find('.') + 1);
os << llvm::formatv(
valueBuilderTemplate, sanitizeName(nameWithoutDialect),
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
llvm::join(opBuilderArgs, ", "),
(op.getNumResults() > 1
? "_Sequence[_ods_ir.Value]"
: (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
/// Emits bindings for a specific Op to the given output stream.
/// Emission order: the class header, segment specs (only with attribute-sized
/// segments), the region spec, the default `__init__`, then operand, attribute,
/// result and region accessors, and finally the module-level value builder.
static void emitOpBindings(const Operator &op, raw_ostream &os) {
os << llvm::formatv(opClassTemplate, op.getCppClassName(),
// Sized segments.
if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
emitRegionAttributes(op, os);
llvm::SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
emitOperandAccessors(op, os);
emitAttributeAccessors(op, os);
emitResultAccessors(op, os);
emitRegionAccessors(op, os);
emitValueBuilder(op, functionArgs, os);
/// Emits bindings for the dialect specified in the command line, including file
/// headers and utilities. Returns `false` on success to comply with Tablegen
/// registration requirements.
/// Aborts with PrintFatalError when no dialect name was given. Bindings are
/// emitted only for records derived from `Op` whose dialect name equals
/// clDialectName.
static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
if (clDialectName.empty())
llvm::PrintFatalError("dialect name not provided");
os << fileHeader;
if (!clDialectExtensionName.empty())
os << llvm::formatv(dialectExtensionTemplate, clDialectName.getValue());
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
Operator op(rec);
if (op.getDialectName() == clDialectName.getValue())
emitOpBindings(op, os);
return false;
/// Registers this generator with mlir-tblgen; emitAllOps is its entry point.
static GenRegistration
"Generate Python bindings for MLIR Ops", &emitAllOps);