| //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python |
| // binding classes wrapping a generic operation API. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/TableGen/GenInfo.h" |
| #include "mlir/TableGen/Operator.h" |
| #include "llvm/ADT/StringSet.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/TableGen/Error.h" |
| #include "llvm/TableGen/Record.h" |
| |
| using namespace mlir; |
| using namespace mlir::tblgen; |
| |
| /// File header and includes. |
| /// {0} is the dialect namespace. |
| constexpr const char *fileHeader = R"Py( |
| # Autogenerated by mlir-tblgen; don't manually edit. |
| |
| from . import _cext as _ods_cext |
| from . import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context |
| _ods_ir = _ods_cext.ir |
| |
| try: |
| from . import _{0} as _ods_ext_module |
| except ImportError: |
| _ods_ext_module = None |
| |
| )Py"; |
| |
| /// Template for dialect class: |
| /// {0} is the dialect namespace. |
| constexpr const char *dialectClassTemplate = R"Py( |
| @_ods_cext.register_dialect |
| class _Dialect(_ods_ir.Dialect): |
| DIALECT_NAMESPACE = "{0}" |
| pass |
| |
| )Py"; |
| |
| /// Template for operation class: |
| /// {0} is the Python class name; |
| /// {1} is the operation name. |
| constexpr const char *opClassTemplate = R"Py( |
| @_ods_cext.register_operation(_Dialect) |
| @_ods_extend_opview_class(_ods_ext_module) |
| class {0}(_ods_ir.OpView): |
| OPERATION_NAME = "{1}" |
| )Py"; |
| |
| /// Template for class level declarations of operand and result |
| /// segment specs. |
| /// {0} is either "OPERAND" or "RESULT" |
| /// {1} is the segment spec |
| /// Each segment spec is either None (default) or an array of integers |
| /// where: |
| /// 1 = single element (expect non sequence operand/result) |
| /// -1 = operand/result is a sequence corresponding to a variadic |
| constexpr const char *opClassSizedSegmentsTemplate = R"Py( |
| _ODS_{0}_SEGMENTS = {1} |
| )Py"; |
| |
| /// Template for class level declarations of the _ODS_REGIONS spec: |
| /// {0} is the minimum number of regions |
| /// {1} is the Python bool literal for hasNoVariadicRegions |
| constexpr const char *opClassRegionSpecTemplate = R"Py( |
| _ODS_REGIONS = ({0}, {1}) |
| )Py"; |
| |
| /// Template for single-element accessor: |
| /// {0} is the name of the accessor; |
| /// {1} is either 'operand' or 'result'; |
| /// {2} is the position in the element list. |
| constexpr const char *opSingleTemplate = R"Py( |
| @property |
| def {0}(self): |
| return self.operation.{1}s[{2}] |
| )Py"; |
| |
| /// Template for single-element accessor after a variable-length group: |
| /// {0} is the name of the accessor; |
| /// {1} is either 'operand' or 'result'; |
| /// {2} is the total number of element groups; |
| /// {3} is the position of the current group in the group list. |
| /// This works for both a single variadic group (non-negative length) and an |
| /// single optional element (zero length if the element is absent). |
| constexpr const char *opSingleAfterVariableTemplate = R"Py( |
| @property |
| def {0}(self): |
| _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1 |
| return self.operation.{1}s[{3} + _ods_variadic_group_length - 1] |
| )Py"; |
| |
| /// Template for an optional element accessor: |
| /// {0} is the name of the accessor; |
| /// {1} is either 'operand' or 'result'; |
| /// {2} is the total number of element groups; |
| /// {3} is the position of the current group in the group list. |
| constexpr const char *opOneOptionalTemplate = R"Py( |
| @property |
| def {0}(self): |
| return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2} else None |
| )Py"; |
| |
| /// Template for the variadic group accessor in the single variadic group case: |
| /// {0} is the name of the accessor; |
| /// {1} is either 'operand' or 'result'; |
| /// {2} is the total number of element groups; |
| /// {3} is the position of the current group in the group list. |
| constexpr const char *opOneVariadicTemplate = R"Py( |
| @property |
| def {0}(self): |
| _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1 |
| return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length] |
| )Py"; |
| |
| /// First part of the template for equally-sized variadic group accessor: |
| /// {0} is the name of the accessor; |
| /// {1} is either 'operand' or 'result'; |
| /// {2} is the total number of variadic groups; |
| /// {3} is the number of non-variadic groups preceding the current group; |
| /// {3} is the number of variadic groups preceding the current group. |
| constexpr const char *opVariadicEqualPrefixTemplate = R"Py( |
| @property |
| def {0}(self): |
| start, pg = _ods_equally_sized_accessor(operation.{1}s, {2}, {3}, {4}))Py"; |
| |
| /// Second part of the template for equally-sized case, accessing a single |
| /// element: |
| /// {0} is either 'operand' or 'result'. |
| constexpr const char *opVariadicEqualSimpleTemplate = R"Py( |
| return self.operation.{0}s[start] |
| )Py"; |
| |
| /// Second part of the template for equally-sized case, accessing a variadic |
| /// group: |
| /// {0} is either 'operand' or 'result'. |
| constexpr const char *opVariadicEqualVariadicTemplate = R"Py( |
| return self.operation.{0}s[start:start + pg] |
| )Py"; |
| |
| /// Template for an attribute-sized group accessor: |
| /// {0} is the name of the accessor; |
| /// {1} is either 'operand' or 'result'; |
| /// {2} is the position of the group in the group list; |
| /// {3} is a return suffix (expected [0] for single-element, empty for |
| /// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional). |
| constexpr const char *opVariadicSegmentTemplate = R"Py( |
| @property |
| def {0}(self): |
| {1}_range = _ods_segmented_accessor( |
| self.operation.{1}s, |
| self.operation.attributes["{1}_segment_sizes"], {2}) |
| return {1}_range{3} |
| )Py"; |
| |
| /// Template for a suffix when accessing an optional element in the |
| /// attribute-sized case: |
| /// {0} is either 'operand' or 'result'; |
| constexpr const char *opVariadicSegmentOptionalTrailingTemplate = |
| R"Py([0] if len({0}_range) > 0 else None)Py"; |
| |
| /// Template for an operation attribute getter: |
| /// {0} is the name of the attribute sanitized for Python; |
| /// {1} is the Python type of the attribute; |
| /// {2} os the original name of the attribute. |
| constexpr const char *attributeGetterTemplate = R"Py( |
| @property |
| def {0}(self): |
| return {1}(self.operation.attributes["{2}"]) |
| )Py"; |
| |
| /// Template for an optional operation attribute getter: |
| /// {0} is the name of the attribute sanitized for Python; |
| /// {1} is the Python type of the attribute; |
| /// {2} is the original name of the attribute. |
| constexpr const char *optionalAttributeGetterTemplate = R"Py( |
| @property |
| def {0}(self): |
| if "{2}" not in self.operation.attributes: |
| return None |
| return {1}(self.operation.attributes["{2}"]) |
| )Py"; |
| |
| /// Template for a getter of a unit operation attribute, returns True of the |
| /// unit attribute is present, False otherwise (unit attributes have meaning |
| /// by mere presence): |
| /// {0} is the name of the attribute sanitized for Python, |
| /// {1} is the original name of the attribute. |
| constexpr const char *unitAttributeGetterTemplate = R"Py( |
| @property |
| def {0}(self): |
| return "{1}" in self.operation.attributes |
| )Py"; |
| |
| /// Template for an operation attribute setter: |
| /// {0} is the name of the attribute sanitized for Python; |
| /// {1} is the original name of the attribute. |
| constexpr const char *attributeSetterTemplate = R"Py( |
| @{0}.setter |
| def {0}(self, value): |
| if value is None: |
| raise ValueError("'None' not allowed as value for mandatory attributes") |
| self.operation.attributes["{1}"] = value |
| )Py"; |
| |
| /// Template for a setter of an optional operation attribute, setting to None |
| /// removes the attribute: |
| /// {0} is the name of the attribute sanitized for Python; |
| /// {1} is the original name of the attribute. |
| constexpr const char *optionalAttributeSetterTemplate = R"Py( |
| @{0}.setter |
| def {0}(self, value): |
| if value is not None: |
| self.operation.attributes["{1}"] = value |
| elif "{1}" in self.operation.attributes: |
| del self.operation.attributes["{1}"] |
| )Py"; |
| |
| /// Template for a setter of a unit operation attribute, setting to None or |
| /// False removes the attribute: |
| /// {0} is the name of the attribute sanitized for Python; |
| /// {1} is the original name of the attribute. |
| constexpr const char *unitAttributeSetterTemplate = R"Py( |
| @{0}.setter |
| def {0}(self, value): |
| if bool(value): |
| self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get() |
| elif "{1}" in self.operation.attributes: |
| del self.operation.attributes["{1}"] |
| )Py"; |
| |
| /// Template for a deleter of an optional or a unit operation attribute, removes |
| /// the attribute from the operation: |
| /// {0} is the name of the attribute sanitized for Python; |
| /// {1} is the original name of the attribute. |
| constexpr const char *attributeDeleterTemplate = R"Py( |
| @{0}.deleter |
| def {0}(self): |
| del self.operation.attributes["{1}"] |
| )Py"; |
| |
| static llvm::cl::OptionCategory |
| clOpPythonBindingCat("Options for -gen-python-op-bindings"); |
| |
| static llvm::cl::opt<std::string> |
| clDialectName("bind-dialect", |
| llvm::cl::desc("The dialect to run the generator for"), |
| llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); |
| |
| using AttributeClasses = DenseMap<StringRef, StringRef>; |
| |
| /// Checks whether `str` is a Python keyword. |
| static bool isPythonKeyword(StringRef str) { |
| static llvm::StringSet<> keywords( |
| {"and", "as", "assert", "break", "class", "continue", |
| "def", "del", "elif", "else", "except", "finally", |
| "for", "from", "global", "if", "import", "in", |
| "is", "lambda", "nonlocal", "not", "or", "pass", |
| "raise", "return", "try", "while", "with", "yield"}); |
| return keywords.contains(str); |
| } |
| |
| /// Checks whether `str` would shadow a generated variable or attribute |
| /// part of the OpView API. |
| static bool isODSReserved(StringRef str) { |
| static llvm::StringSet<> reserved( |
| {"attributes", "create", "context", "ip", "operands", "print", "get_asm", |
| "loc", "verify", "regions", "results", "self", "operation", |
| "DIALECT_NAMESPACE", "OPERATION_NAME"}); |
| return str.startswith("_ods_") || str.endswith("_ods") || |
| reserved.contains(str); |
| } |
| |
| /// Modifies the `name` in a way that it becomes suitable for Python bindings |
| /// (does not change the `name` if it already is suitable) and returns the |
| /// modified version. |
| static std::string sanitizeName(StringRef name) { |
| if (isPythonKeyword(name) || isODSReserved(name)) |
| return (name + "_").str(); |
| return name.str(); |
| } |
| |
| static std::string attrSizedTraitForKind(const char *kind) { |
| return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", |
| llvm::StringRef(kind).take_front().upper(), |
| llvm::StringRef(kind).drop_front()); |
| } |
| |
| /// Emits accessors to "elements" of an Op definition. Currently, the supported |
| /// elements are operands and results, indicated by `kind`, which must be either |
| /// `operand` or `result` and is used verbatim in the emitted code. |
| static void emitElementAccessors( |
| const Operator &op, raw_ostream &os, const char *kind, |
| llvm::function_ref<unsigned(const Operator &)> getNumVariadic, |
| llvm::function_ref<int(const Operator &)> getNumElements, |
| llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> |
| getElement) { |
| assert(llvm::is_contained( |
| llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) && |
| "unsupported kind"); |
| |
| // Traits indicating how to process variadic elements. |
| std::string sameSizeTrait = |
| llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", |
| llvm::StringRef(kind).take_front().upper(), |
| llvm::StringRef(kind).drop_front()); |
| std::string attrSizedTrait = attrSizedTraitForKind(kind); |
| |
| unsigned numVariadic = getNumVariadic(op); |
| |
| // If there is only one variadic element group, its size can be inferred from |
| // the total number of elements. If there are none, the generation is |
| // straightforward. |
| if (numVariadic <= 1) { |
| bool seenVariableLength = false; |
| for (int i = 0, e = getNumElements(op); i < e; ++i) { |
| const NamedTypeConstraint &element = getElement(op, i); |
| if (element.isVariableLength()) |
| seenVariableLength = true; |
| if (element.name.empty()) |
| continue; |
| if (element.isVariableLength()) { |
| os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate |
| : opOneVariadicTemplate, |
| sanitizeName(element.name), kind, |
| getNumElements(op), i); |
| } else if (seenVariableLength) { |
| os << llvm::formatv(opSingleAfterVariableTemplate, |
| sanitizeName(element.name), kind, |
| getNumElements(op), i); |
| } else { |
| os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, |
| i); |
| } |
| } |
| return; |
| } |
| |
| // Handle the operations where variadic groups have the same size. |
| if (op.getTrait(sameSizeTrait)) { |
| int numPrecedingSimple = 0; |
| int numPrecedingVariadic = 0; |
| for (int i = 0, e = getNumElements(op); i < e; ++i) { |
| const NamedTypeConstraint &element = getElement(op, i); |
| if (!element.name.empty()) { |
| os << llvm::formatv(opVariadicEqualPrefixTemplate, |
| sanitizeName(element.name), kind, numVariadic, |
| numPrecedingSimple, numPrecedingVariadic); |
| os << llvm::formatv(element.isVariableLength() |
| ? opVariadicEqualVariadicTemplate |
| : opVariadicEqualSimpleTemplate, |
| kind); |
| } |
| if (element.isVariableLength()) |
| ++numPrecedingVariadic; |
| else |
| ++numPrecedingSimple; |
| } |
| return; |
| } |
| |
| // Handle the operations where the size of groups (variadic or not) is |
| // provided as an attribute. For non-variadic elements, make sure to return |
| // an element rather than a singleton container. |
| if (op.getTrait(attrSizedTrait)) { |
| for (int i = 0, e = getNumElements(op); i < e; ++i) { |
| const NamedTypeConstraint &element = getElement(op, i); |
| if (element.name.empty()) |
| continue; |
| std::string trailing; |
| if (!element.isVariableLength()) |
| trailing = "[0]"; |
| else if (element.isOptional()) |
| trailing = std::string( |
| llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); |
| os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), |
| kind, i, trailing); |
| } |
| return; |
| } |
| |
| llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); |
| } |
| |
| /// Free function helpers accessing Operator components. |
| static int getNumOperands(const Operator &op) { return op.getNumOperands(); } |
| static const NamedTypeConstraint &getOperand(const Operator &op, int i) { |
| return op.getOperand(i); |
| } |
| static int getNumResults(const Operator &op) { return op.getNumResults(); } |
| static const NamedTypeConstraint &getResult(const Operator &op, int i) { |
| return op.getResult(i); |
| } |
| |
| /// Emits accessors to Op operands. |
| static void emitOperandAccessors(const Operator &op, raw_ostream &os) { |
| auto getNumVariadic = [](const Operator &oper) { |
| return oper.getNumVariableLengthOperands(); |
| }; |
| emitElementAccessors(op, os, "operand", getNumVariadic, getNumOperands, |
| getOperand); |
| } |
| |
| /// Emits accessors Op results. |
| static void emitResultAccessors(const Operator &op, raw_ostream &os) { |
| auto getNumVariadic = [](const Operator &oper) { |
| return oper.getNumVariableLengthResults(); |
| }; |
| emitElementAccessors(op, os, "result", getNumVariadic, getNumResults, |
| getResult); |
| } |
| |
| /// Emits accessors to Op attributes. |
| static void emitAttributeAccessors(const Operator &op, |
| const AttributeClasses &attributeClasses, |
| raw_ostream &os) { |
| for (const auto &namedAttr : op.getAttributes()) { |
| // Skip "derived" attributes because they are just C++ functions that we |
| // don't currently expose. |
| if (namedAttr.attr.isDerivedAttr()) |
| continue; |
| |
| if (namedAttr.name.empty()) |
| continue; |
| |
| std::string sanitizedName = sanitizeName(namedAttr.name); |
| |
| // Unit attributes are handled specially. |
| if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { |
| os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, |
| namedAttr.name); |
| os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, |
| namedAttr.name); |
| os << llvm::formatv(attributeDeleterTemplate, sanitizedName, |
| namedAttr.name); |
| continue; |
| } |
| |
| // Other kinds of attributes need a mapping to a Python type. |
| if (!attributeClasses.count(namedAttr.attr.getStorageType().trim())) |
| continue; |
| |
| StringRef pythonType = |
| attributeClasses.lookup(namedAttr.attr.getStorageType()); |
| if (namedAttr.attr.isOptional()) { |
| os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, |
| pythonType, namedAttr.name); |
| os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, |
| namedAttr.name); |
| os << llvm::formatv(attributeDeleterTemplate, sanitizedName, |
| namedAttr.name); |
| } else { |
| os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType, |
| namedAttr.name); |
| os << llvm::formatv(attributeSetterTemplate, sanitizedName, |
| namedAttr.name); |
| // Non-optional attributes cannot be deleted. |
| } |
| } |
| } |
| |
| /// Template for the default auto-generated builder. |
| /// {0} is a comma-separated list of builder arguments, including the trailing |
| /// `loc` and `ip`; |
| /// {1} is the code populating `operands`, `results` and `attributes` fields. |
| constexpr const char *initTemplate = R"Py( |
| def __init__(self, {0}): |
| operands = [] |
| results = [] |
| attributes = {{} |
| {1} |
| super().__init__(self.build_generic( |
| attributes=attributes, results=results, operands=operands, |
| loc=loc, ip=ip)) |
| )Py"; |
| |
| /// Template for appending a single element to the operand/result list. |
| /// {0} is either 'operand' or 'result'; |
| /// {1} is the field name. |
| constexpr const char *singleElementAppendTemplate = "{0}s.append({1})"; |
| |
| /// Template for appending an optional element to the operand/result list. |
| /// {0} is either 'operand' or 'result'; |
| /// {1} is the field name. |
| constexpr const char *optionalAppendTemplate = |
| "if {1} is not None: {0}s.append({1})"; |
| |
| /// Template for appending a a list of elements to the operand/result list. |
| /// {0} is either 'operand' or 'result'; |
| /// {1} is the field name. |
| constexpr const char *multiElementAppendTemplate = "{0}s.extend({1})"; |
| |
| /// Template for setting an attribute in the operation builder. |
| /// {0} is the attribute name; |
| /// {1} is the builder argument name. |
| constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py"; |
| |
| /// Template for setting an optional attribute in the operation builder. |
| /// {0} is the attribute name; |
| /// {1} is the builder argument name. |
| constexpr const char *initOptionalAttributeTemplate = |
| R"Py(if {1} is not None: attributes["{0}"] = {1})Py"; |
| |
| constexpr const char *initUnitAttributeTemplate = |
| R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( |
| _ods_get_default_loc_context(loc)))Py"; |
| |
| /// Populates `builderArgs` with the Python-compatible names of builder function |
| /// arguments, first the results, then the intermixed attributes and operands in |
| /// the same order as they appear in the `arguments` field of the op definition. |
| /// Additionally, `operandNames` is populated with names of operands in their |
| /// order of appearance. |
| static void |
| populateBuilderArgs(const Operator &op, |
| llvm::SmallVectorImpl<std::string> &builderArgs, |
| llvm::SmallVectorImpl<std::string> &operandNames) { |
| for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
| std::string name = op.getResultName(i).str(); |
| if (name.empty()) { |
| if (op.getNumResults() == 1) { |
| // Special case for one result, make the default name be 'result' |
| // to properly match the built-in result accessor. |
| name = "result"; |
| } else { |
| name = llvm::formatv("_gen_res_{0}", i); |
| } |
| } |
| name = sanitizeName(name); |
| builderArgs.push_back(name); |
| } |
| for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
| std::string name = op.getArgName(i).str(); |
| if (name.empty()) |
| name = llvm::formatv("_gen_arg_{0}", i); |
| name = sanitizeName(name); |
| builderArgs.push_back(name); |
| if (!op.getArg(i).is<NamedAttribute *>()) |
| operandNames.push_back(name); |
| } |
| } |
| |
| /// Populates `builderLines` with additional lines that are required in the |
| /// builder to set up operation attributes. `argNames` is expected to contain |
| /// the names of builder arguments that correspond to op arguments, i.e. to the |
| /// operands and attributes in the same order as they appear in the `arguments` |
| /// field. |
| static void |
| populateBuilderLinesAttr(const Operator &op, |
| llvm::ArrayRef<std::string> argNames, |
| llvm::SmallVectorImpl<std::string> &builderLines) { |
| for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
| Argument arg = op.getArg(i); |
| auto *attribute = arg.dyn_cast<NamedAttribute *>(); |
| if (!attribute) |
| continue; |
| |
| // Unit attributes are handled specially. |
| if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { |
| builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, |
| attribute->name, argNames[i])); |
| continue; |
| } |
| |
| builderLines.push_back(llvm::formatv(attribute->attr.isOptional() |
| ? initOptionalAttributeTemplate |
| : initAttributeTemplate, |
| attribute->name, argNames[i])); |
| } |
| } |
| |
| /// Populates `builderLines` with additional lines that are required in the |
| /// builder. `kind` must be either "operand" or "result". `names` contains the |
| /// names of init arguments that correspond to the elements. |
| static void populateBuilderLines( |
| const Operator &op, const char *kind, llvm::ArrayRef<std::string> names, |
| llvm::SmallVectorImpl<std::string> &builderLines, |
| llvm::function_ref<int(const Operator &)> getNumElements, |
| llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> |
| getElement) { |
| bool sizedSegments = op.getTrait(attrSizedTraitForKind(kind)) != nullptr; |
| |
| // For each element, find or generate a name. |
| for (int i = 0, e = getNumElements(op); i < e; ++i) { |
| const NamedTypeConstraint &element = getElement(op, i); |
| std::string name = names[i]; |
| |
| // Choose the formatting string based on the element kind. |
| llvm::StringRef formatString; |
| if (!element.isVariableLength()) { |
| formatString = singleElementAppendTemplate; |
| } else if (element.isOptional()) { |
| formatString = optionalAppendTemplate; |
| } else { |
| assert(element.isVariadic() && "unhandled element group type"); |
| // If emitting with sizedSegments, then we add the actual list typed |
| // element using the singleElementAppendTemplate. Otherwise, we extend |
| // the actual operands. |
| if (sizedSegments) { |
| // Append the list as is. |
| formatString = singleElementAppendTemplate; |
| } else { |
| // Append the list elements. |
| formatString = multiElementAppendTemplate; |
| } |
| } |
| |
| // Add the lines. |
| builderLines.push_back(llvm::formatv(formatString.data(), kind, name)); |
| } |
| } |
| |
| /// Emits a default builder constructing an operation from the list of its |
| /// result types, followed by a list of its operands. |
| static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { |
| // If we are asked to skip default builders, comply. |
| if (op.skipDefaultBuilders()) |
| return; |
| |
| llvm::SmallVector<std::string, 8> builderArgs; |
| llvm::SmallVector<std::string, 8> builderLines; |
| llvm::SmallVector<std::string, 4> operandArgNames; |
| builderArgs.reserve(op.getNumOperands() + op.getNumResults() + |
| op.getNumNativeAttributes()); |
| populateBuilderArgs(op, builderArgs, operandArgNames); |
| populateBuilderLines( |
| op, "result", |
| llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()), |
| builderLines, getNumResults, getResult); |
| populateBuilderLines(op, "operand", operandArgNames, builderLines, |
| getNumOperands, getOperand); |
| populateBuilderLinesAttr( |
| op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()), |
| builderLines); |
| |
| builderArgs.push_back("*"); |
| builderArgs.push_back("loc=None"); |
| builderArgs.push_back("ip=None"); |
| os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "), |
| llvm::join(builderLines, "\n ")); |
| } |
| |
| static void constructAttributeMapping(const llvm::RecordKeeper &records, |
| AttributeClasses &attributeClasses) { |
| for (const llvm::Record *rec : |
| records.getAllDerivedDefinitions("PythonAttr")) { |
| attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(), |
| rec->getValueAsString("pythonType").trim()); |
| } |
| } |
| |
| static void emitSegmentSpec( |
| const Operator &op, const char *kind, |
| llvm::function_ref<int(const Operator &)> getNumElements, |
| llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> |
| getElement, |
| raw_ostream &os) { |
| std::string segmentSpec("["); |
| for (int i = 0, e = getNumElements(op); i < e; ++i) { |
| const NamedTypeConstraint &element = getElement(op, i); |
| if (element.isVariableLength()) { |
| segmentSpec.append("-1,"); |
| } else if (element.isOptional()) { |
| segmentSpec.append("0,"); |
| } else { |
| segmentSpec.append("1,"); |
| } |
| } |
| segmentSpec.append("]"); |
| |
| os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); |
| } |
| |
| static void emitRegionAttributes(const Operator &op, raw_ostream &os) { |
| // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). |
| // Note that the base OpView class defines this as (0, True). |
| unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); |
| os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, |
| op.hasNoVariadicRegions() ? "True" : "False"); |
| } |
| |
| /// Emits bindings for a specific Op to the given output stream. |
| static void emitOpBindings(const Operator &op, |
| const AttributeClasses &attributeClasses, |
| raw_ostream &os) { |
| os << llvm::formatv(opClassTemplate, op.getCppClassName(), |
| op.getOperationName()); |
| |
| // Sized segments. |
| if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { |
| emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); |
| } |
| if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { |
| emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); |
| } |
| |
| emitRegionAttributes(op, os); |
| emitDefaultOpBuilder(op, os); |
| emitOperandAccessors(op, os); |
| emitAttributeAccessors(op, attributeClasses, os); |
| emitResultAccessors(op, os); |
| } |
| |
| /// Emits bindings for the dialect specified in the command line, including file |
| /// headers and utilities. Returns `false` on success to comply with Tablegen |
| /// registration requirements. |
| static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { |
| if (clDialectName.empty()) |
| llvm::PrintFatalError("dialect name not provided"); |
| |
| AttributeClasses attributeClasses; |
| constructAttributeMapping(records, attributeClasses); |
| |
| os << llvm::formatv(fileHeader, clDialectName.getValue()); |
| os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); |
| |
| if (clDialectName == "builtin") |
| clDialectName = ""; |
| |
| for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { |
| Operator op(rec); |
| if (op.getDialectName() == clDialectName.getValue()) |
| emitOpBindings(op, attributeClasses, os); |
| } |
| return false; |
| } |
| |
| static GenRegistration |
| genPythonBindings("gen-python-op-bindings", |
| "Generate Python bindings for MLIR Ops", &emitAllOps); |