| //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python |
| // binding classes wrapping a generic operation API. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/TableGen/GenInfo.h" |
| #include "mlir/TableGen/Operator.h" |
| #include "llvm/ADT/StringSet.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/TableGen/Error.h" |
| #include "llvm/TableGen/Record.h" |
| |
| using namespace mlir; |
| using namespace mlir::tblgen; |
| |
/// File header and includes.
///   {0} is the dialect namespace.
///
/// The header imports the `_ods_common` helpers under `_ods_`/`_get_` prefixed
/// aliases and tries to import an optional hand-written `_{0}_ops_ext` module
/// (bound to `None` when absent), which is passed to
/// `_ods_extend_opview_class` for every generated op class. `builtins` is
/// imported so generated code can refer to `builtins.property` explicitly
/// (presumably so that an accessor named `property` cannot shadow it).
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import extend_opview_class as _ods_extend_opview_class, segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values
_ods_ir = _ods_cext.ir

try:
  from . import _{0}_ops_ext as _ods_ext_module
except ImportError:
  _ods_ext_module = None

import builtins

)Py";

/// Template for dialect class:
///   {0} is the dialect namespace.
/// The class is registered with the extension module through the
/// `register_dialect` decorator; generated op classes refer to it as
/// `_Dialect`.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
  pass

)Py";

/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name.
/// The class body is left open: accessors and builders emitted afterwards are
/// indented to become its members.
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
@_ods_extend_opview_class(_ods_ext_module)
class {0}(_ods_ir.OpView):
  OPERATION_NAME = "{1}"
)Py";
| |
/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///  -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";

/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
/// Only valid when no variable-length group precedes the element, so its
/// position is known statically.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent).
/// The group length is the number of actual elements minus the other
/// {2} - 1 single-element groups.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
/// Returns the slice of elements belonging to the variadic group.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self):
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";
| |
/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of variadic groups;
///   {3} is the number of non-variadic groups preceding the current group;
///   {4} is the number of variadic groups preceding the current group.
/// The emitted code binds `start` (the index of the current group's first
/// element) and `pg` (the length of the group's slice). The property body is
/// completed by appending either opVariadicEqualSimpleTemplate or
/// opVariadicEqualVariadicTemplate. The elements must be read through
/// `self.operation`: there is no `operation` name in scope in the generated
/// method.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self):
    start, pg = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}))Py";
| |
/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
/// Relies on `start` being bound by opVariadicEqualPrefixTemplate.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
/// Relies on `start` and `pg` being bound by opVariadicEqualPrefixTemplate.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + pg]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///   variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
/// Group sizes are read from the `{1}_segment_sizes` attribute of the
/// operation.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self):
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}_segment_sizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case:
///   {0} is either 'operand' or 'result';
/// Yields the single element of the segment, or None when the segment is empty.
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";
| |
/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
/// The raw attribute is wrapped in the Python class {1}.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for an optional operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the Python type of the attribute;
///   {2} is the original name of the attribute.
/// Returns None when the attribute is not set on the operation.
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    if "{2}" not in self.operation.attributes:
      return None
    return {1}(self.operation.attributes["{2}"])
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///   {0} is the name of the attribute sanitized for Python,
///   {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self):
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
/// Mandatory attributes cannot be cleared, so None raises ValueError.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute, setting to None or
/// False removes the attribute:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
/// Any truthy value sets the attribute to a fresh UnitAttr.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";

/// Template for a region accessor:
///   {0} is the name of the accessor (presumably the sanitized region name;
///       the emitter using this template is outside this view — confirm);
///   {1} is the index of the region in the operation's region list.
constexpr const char *regionAccessorTemplate = R"PY(
  @builtins.property
  def {0}(self):
    return self.regions[{1}]
)PY";
| |
/// Command-line option category grouping the options of this generator.
static llvm::cl::OptionCategory
    clOpPythonBindingCat("Options for -gen-python-op-bindings");

/// `-bind-dialect=<name>`: the dialect to run the generator for (empty by
/// default).
static llvm::cl::opt<std::string>
    clDialectName("bind-dialect",
                  llvm::cl::desc("The dialect to run the generator for"),
                  llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));

/// Maps the (whitespace-trimmed) C++ storage type of an attribute to the name
/// of the Python class used to wrap it in generated attribute getters.
using AttributeClasses = DenseMap<StringRef, StringRef>;
| |
| /// Checks whether `str` is a Python keyword. |
| static bool isPythonKeyword(StringRef str) { |
| static llvm::StringSet<> keywords( |
| {"and", "as", "assert", "break", "class", "continue", |
| "def", "del", "elif", "else", "except", "finally", |
| "for", "from", "global", "if", "import", "in", |
| "is", "lambda", "nonlocal", "not", "or", "pass", |
| "raise", "return", "try", "while", "with", "yield"}); |
| return keywords.contains(str); |
| } |
| |
| /// Checks whether `str` would shadow a generated variable or attribute |
| /// part of the OpView API. |
| static bool isODSReserved(StringRef str) { |
| static llvm::StringSet<> reserved( |
| {"attributes", "create", "context", "ip", "operands", "print", "get_asm", |
| "loc", "verify", "regions", "results", "self", "operation", |
| "DIALECT_NAMESPACE", "OPERATION_NAME"}); |
| return str.startswith("_ods_") || str.endswith("_ods") || |
| reserved.contains(str); |
| } |
| |
| /// Modifies the `name` in a way that it becomes suitable for Python bindings |
| /// (does not change the `name` if it already is suitable) and returns the |
| /// modified version. |
| static std::string sanitizeName(StringRef name) { |
| if (isPythonKeyword(name) || isODSReserved(name)) |
| return (name + "_").str(); |
| return name.str(); |
| } |
| |
| static std::string attrSizedTraitForKind(const char *kind) { |
| return llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", |
| llvm::StringRef(kind).take_front().upper(), |
| llvm::StringRef(kind).drop_front()); |
| } |
| |
/// Emits accessors to "elements" of an Op definition. Currently, the supported
/// elements are operands and results, indicated by `kind`, which must be either
/// `operand` or `result` and is used verbatim in the emitted code.
///
/// The callbacks abstract over the element kind:
///   - `getNumVariableLength` returns the number of optional/variadic groups;
///   - `getNumElements` returns the number of element groups declared in ODS;
///   - `getElement` returns the group at the given index.
///
/// Three layouts are supported, tried in this order:
///   1. at most one variable-length group: positions are computed from the
///      total element count at access time;
///   2. the SameVariadic{Operand,Result}Size trait: all variadic groups have
///      the same length, computed by `_ods_equally_sized_accessor`;
///   3. the AttrSized{Operand,Result}Segments trait: group sizes are read from
///      the `{kind}_segment_sizes` attribute by `_ods_segmented_accessor`.
/// Any other layout is a fatal error. Unnamed groups get no accessor.
static void emitElementAccessors(
    const Operator &op, raw_ostream &os, const char *kind,
    llvm::function_ref<unsigned(const Operator &)> getNumVariableLength,
    llvm::function_ref<int(const Operator &)> getNumElements,
    llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
        getElement) {
  assert(llvm::is_contained(
             llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
         "unsupported kind");

  // Traits indicating how to process variadic elements.
  std::string sameSizeTrait =
      llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
                    llvm::StringRef(kind).take_front().upper(),
                    llvm::StringRef(kind).drop_front());
  std::string attrSizedTrait = attrSizedTraitForKind(kind);

  unsigned numVariableLength = getNumVariableLength(op);

  // If there is only one variable-length element group, its size can be
  // inferred from the total number of elements. If there are none, the
  // generation is straightforward.
  if (numVariableLength <= 1) {
    // Once the variable-length group has been passed, the positions of the
    // following single elements depend on its runtime length.
    bool seenVariableLength = false;
    for (int i = 0, e = getNumElements(op); i < e; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      // Record the variable-length group before skipping unnamed elements:
      // even an unnamed group shifts the positions of the elements after it.
      if (element.isVariableLength())
        seenVariableLength = true;
      if (element.name.empty())
        continue;
      if (element.isVariableLength()) {
        os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
                                                 : opOneVariadicTemplate,
                            sanitizeName(element.name), kind,
                            getNumElements(op), i);
      } else if (seenVariableLength) {
        os << llvm::formatv(opSingleAfterVariableTemplate,
                            sanitizeName(element.name), kind,
                            getNumElements(op), i);
      } else {
        os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
                            i);
      }
    }
    return;
  }

  // Handle the operations where variadic groups have the same size.
  if (op.getTrait(sameSizeTrait)) {
    // The runtime helper locates a group from the counts of simple and
    // variadic groups that precede it.
    int numPrecedingSimple = 0;
    int numPrecedingVariadic = 0;
    for (int i = 0, e = getNumElements(op); i < e; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (!element.name.empty()) {
        os << llvm::formatv(opVariadicEqualPrefixTemplate,
                            sanitizeName(element.name), kind, numVariableLength,
                            numPrecedingSimple, numPrecedingVariadic);
        os << llvm::formatv(element.isVariableLength()
                                ? opVariadicEqualVariadicTemplate
                                : opVariadicEqualSimpleTemplate,
                            kind);
      }
      // Unnamed groups get no accessor but still count as preceding groups.
      if (element.isVariableLength())
        ++numPrecedingVariadic;
      else
        ++numPrecedingSimple;
    }
    return;
  }

  // Handle the operations where the size of groups (variadic or not) is
  // provided as an attribute. For non-variadic elements, make sure to return
  // an element rather than a singleton container.
  if (op.getTrait(attrSizedTrait)) {
    for (int i = 0, e = getNumElements(op); i < e; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (element.name.empty())
        continue;
      // Suffix applied to the segment slice: `[0]` for a single element, a
      // None-if-empty conditional for an optional one, nothing for variadic.
      std::string trailing;
      if (!element.isVariableLength())
        trailing = "[0]";
      else if (element.isOptional())
        trailing = std::string(
            llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
      os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
                          kind, i, trailing);
    }
    return;
  }

  llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
}
| |
| /// Free function helpers accessing Operator components. |
| static int getNumOperands(const Operator &op) { return op.getNumOperands(); } |
| static const NamedTypeConstraint &getOperand(const Operator &op, int i) { |
| return op.getOperand(i); |
| } |
| static int getNumResults(const Operator &op) { return op.getNumResults(); } |
| static const NamedTypeConstraint &getResult(const Operator &op, int i) { |
| return op.getResult(i); |
| } |
| |
| /// Emits accessors to Op operands. |
| static void emitOperandAccessors(const Operator &op, raw_ostream &os) { |
| auto getNumVariableLengthOperands = [](const Operator &oper) { |
| return oper.getNumVariableLengthOperands(); |
| }; |
| emitElementAccessors(op, os, "operand", getNumVariableLengthOperands, |
| getNumOperands, getOperand); |
| } |
| |
| /// Emits accessors Op results. |
| static void emitResultAccessors(const Operator &op, raw_ostream &os) { |
| auto getNumVariableLengthResults = [](const Operator &oper) { |
| return oper.getNumVariableLengthResults(); |
| }; |
| emitElementAccessors(op, os, "result", getNumVariableLengthResults, |
| getNumResults, getResult); |
| } |
| |
| /// Emits accessors to Op attributes. |
| static void emitAttributeAccessors(const Operator &op, |
| const AttributeClasses &attributeClasses, |
| raw_ostream &os) { |
| for (const auto &namedAttr : op.getAttributes()) { |
| // Skip "derived" attributes because they are just C++ functions that we |
| // don't currently expose. |
| if (namedAttr.attr.isDerivedAttr()) |
| continue; |
| |
| if (namedAttr.name.empty()) |
| continue; |
| |
| std::string sanitizedName = sanitizeName(namedAttr.name); |
| |
| // Unit attributes are handled specially. |
| if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { |
| os << llvm::formatv(unitAttributeGetterTemplate, sanitizedName, |
| namedAttr.name); |
| os << llvm::formatv(unitAttributeSetterTemplate, sanitizedName, |
| namedAttr.name); |
| os << llvm::formatv(attributeDeleterTemplate, sanitizedName, |
| namedAttr.name); |
| continue; |
| } |
| |
| // Other kinds of attributes need a mapping to a Python type. |
| if (!attributeClasses.count(namedAttr.attr.getStorageType().trim())) |
| continue; |
| |
| StringRef pythonType = |
| attributeClasses.lookup(namedAttr.attr.getStorageType()); |
| if (namedAttr.attr.isOptional()) { |
| os << llvm::formatv(optionalAttributeGetterTemplate, sanitizedName, |
| pythonType, namedAttr.name); |
| os << llvm::formatv(optionalAttributeSetterTemplate, sanitizedName, |
| namedAttr.name); |
| os << llvm::formatv(attributeDeleterTemplate, sanitizedName, |
| namedAttr.name); |
| } else { |
| os << llvm::formatv(attributeGetterTemplate, sanitizedName, pythonType, |
| namedAttr.name); |
| os << llvm::formatv(attributeSetterTemplate, sanitizedName, |
| namedAttr.name); |
| // Non-optional attributes cannot be deleted. |
| } |
| } |
| } |
| |
/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating `operands`, `results` and `attributes`,
///       `successors` fields.
/// `{{}` is the formatv escape for a literal `{`, so the emitted code reads
/// `attributes = {}`. The lines substituted for {1} must bind
/// `_ods_successors`, which is forwarded to `build_generic`; they are presumably
/// joined with the method-body indentation by the caller (outside this view).
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    results = []
    attributes = {{}
    regions = None
    {1}
    super().__init__(self.build_generic(
      attributes=attributes, results=results, operands=operands,
      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
)Py";

/// Template for appending a single element to the operand/result list.
///   {0} is the field name.
/// Operands go through `_get_op_result_or_value` (from `_ods_common`),
/// presumably to accept either a Value or a single-result op — confirm.
constexpr const char *singleOperandAppendTemplate =
    "operands.append(_get_op_result_or_value({0}))";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
///   {0} is the field name.
/// The AttrSizedOperands variant always appends exactly one entry (None when
/// the operand is absent) so the entries line up with the segment spec.
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
    "None)";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The "Pack" variant appends the whole list as a single entry (used with
/// attribute-sized segments); the others flatten it into the list.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for setting an attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py";

/// Template for setting an optional attribute in the operation builder.
///   {0} is the attribute name;
///   {1} is the builder argument name.
constexpr const char *initOptionalAttributeTemplate =
    R"Py(if {1} is not None: attributes["{0}"] = {1})Py";

/// Template for setting a unit attribute in the operation builder when the
/// builder argument is truthy; the UnitAttr is created in the context obtained
/// from `loc` via `_ods_get_default_loc_context`.
///   {0} is the attribute name;
///   {1} is the builder argument name.
/// The line break falls inside the call parentheses, so the template is valid
/// Python however the continuation line ends up indented.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the successors list in the builder if there are any
/// successors.
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";
| |
| /// Returns true if the SameArgumentAndResultTypes trait can be used to infer |
| /// result types of the given operation. |
| static bool hasSameArgumentAndResultTypes(const Operator &op) { |
| return op.getTrait("::mlir::OpTrait::SameOperandsAndResultType") && |
| op.getNumVariableLengthResults() == 0; |
| } |
| |
| /// Returns true if the FirstAttrDerivedResultType trait can be used to infer |
| /// result types of the given operation. |
| static bool hasFirstAttrDerivedResultTypes(const Operator &op) { |
| return op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType") && |
| op.getNumVariableLengthResults() == 0; |
| } |
| |
| /// Returns true if the InferTypeOpInterface can be used to infer result types |
| /// of the given operation. |
| static bool hasInferTypeInterface(const Operator &op) { |
| return op.getTrait("::mlir::InferTypeOpInterface::Trait") && |
| op.getNumRegions() == 0; |
| } |
| |
| /// Returns true if there is a trait or interface that can be used to infer |
| /// result types of the given operation. |
| static bool canInferType(const Operator &op) { |
| return hasSameArgumentAndResultTypes(op) || |
| hasFirstAttrDerivedResultTypes(op) || hasInferTypeInterface(op); |
| } |
| |
| /// Populates `builderArgs` with result names if the builder is expected to |
| /// accept them as arguments. |
| static void |
| populateBuilderArgsResults(const Operator &op, |
| llvm::SmallVectorImpl<std::string> &builderArgs) { |
| if (canInferType(op)) |
| return; |
| |
| for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
| std::string name = op.getResultName(i).str(); |
| if (name.empty()) { |
| if (op.getNumResults() == 1) { |
| // Special case for one result, make the default name be 'result' |
| // to properly match the built-in result accessor. |
| name = "result"; |
| } else { |
| name = llvm::formatv("_gen_res_{0}", i); |
| } |
| } |
| name = sanitizeName(name); |
| builderArgs.push_back(name); |
| } |
| } |
| |
| /// Populates `builderArgs` with the Python-compatible names of builder function |
| /// arguments using intermixed attributes and operands in the same order as they |
| /// appear in the `arguments` field of the op definition. Additionally, |
| /// `operandNames` is populated with names of operands in their order of |
| /// appearance. |
| static void |
| populateBuilderArgs(const Operator &op, |
| llvm::SmallVectorImpl<std::string> &builderArgs, |
| llvm::SmallVectorImpl<std::string> &operandNames, |
| llvm::SmallVectorImpl<std::string> &successorArgNames) { |
| |
| for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
| std::string name = op.getArgName(i).str(); |
| if (name.empty()) |
| name = llvm::formatv("_gen_arg_{0}", i); |
| name = sanitizeName(name); |
| builderArgs.push_back(name); |
| if (!op.getArg(i).is<NamedAttribute *>()) |
| operandNames.push_back(name); |
| } |
| |
| for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { |
| NamedSuccessor successor = op.getSuccessor(i); |
| std::string name = std::string(successor.name); |
| if (name.empty()) |
| name = llvm::formatv("_gen_successor_{0}", i); |
| name = sanitizeName(name); |
| builderArgs.push_back(name); |
| successorArgNames.push_back(name); |
| } |
| } |
| |
| /// Populates `builderLines` with additional lines that are required in the |
| /// builder to set up operation attributes. `argNames` is expected to contain |
| /// the names of builder arguments that correspond to op arguments, i.e. to the |
| /// operands and attributes in the same order as they appear in the `arguments` |
| /// field. |
| static void |
| populateBuilderLinesAttr(const Operator &op, |
| llvm::ArrayRef<std::string> argNames, |
| llvm::SmallVectorImpl<std::string> &builderLines) { |
| for (int i = 0, e = op.getNumArgs(); i < e; ++i) { |
| Argument arg = op.getArg(i); |
| auto *attribute = arg.dyn_cast<NamedAttribute *>(); |
| if (!attribute) |
| continue; |
| |
| // Unit attributes are handled specially. |
| if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { |
| builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, |
| attribute->name, argNames[i])); |
| continue; |
| } |
| |
| builderLines.push_back(llvm::formatv(attribute->attr.isOptional() |
| ? initOptionalAttributeTemplate |
| : initAttributeTemplate, |
| attribute->name, argNames[i])); |
| } |
| } |
| |
| /// Populates `builderLines` with additional lines that are required in the |
| /// builder to set up successors. successorArgNames is expected to correspond |
| /// to the Python argument name for each successor on the op. |
| static void populateBuilderLinesSuccessors( |
| const Operator &op, llvm::ArrayRef<std::string> successorArgNames, |
| llvm::SmallVectorImpl<std::string> &builderLines) { |
| if (successorArgNames.empty()) { |
| builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "None")); |
| return; |
| } |
| |
| builderLines.push_back(llvm::formatv(initSuccessorsTemplate, "[]")); |
| for (int i = 0, e = successorArgNames.size(); i < e; ++i) { |
| auto &argName = successorArgNames[i]; |
| const NamedSuccessor &successor = op.getSuccessor(i); |
| builderLines.push_back( |
| llvm::formatv(addSuccessorTemplate, |
| successor.isVariadic() ? "extend" : "append", argName)); |
| } |
| } |
| |
| /// Populates `builderLines` with additional lines that are required in the |
| /// builder to set up op operands. |
| static void |
| populateBuilderLinesOperand(const Operator &op, |
| llvm::ArrayRef<std::string> names, |
| llvm::SmallVectorImpl<std::string> &builderLines) { |
| bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr; |
| |
| // For each element, find or generate a name. |
| for (int i = 0, e = op.getNumOperands(); i < e; ++i) { |
| const NamedTypeConstraint &element = op.getOperand(i); |
| std::string name = names[i]; |
| |
| // Choose the formatting string based on the element kind. |
| llvm::StringRef formatString; |
| if (!element.isVariableLength()) { |
| formatString = singleOperandAppendTemplate; |
| } else if (element.isOptional()) { |
| if (sizedSegments) { |
| formatString = optionalAppendAttrSizedOperandsTemplate; |
| } else { |
| formatString = optionalAppendOperandTemplate; |
| } |
| } else { |
| assert(element.isVariadic() && "unhandled element group type"); |
| // If emitting with sizedSegments, then we add the actual list-typed |
| // element. Otherwise, we extend the actual operands. |
| if (sizedSegments) { |
| formatString = multiOperandAppendPackTemplate; |
| } else { |
| formatString = multiOperandAppendTemplate; |
| } |
| } |
| |
| builderLines.push_back(llvm::formatv(formatString.data(), name)); |
| } |
| } |
| |
/// Python code template for deriving the operation result types from its
/// attribute:
///   - {0} is the name of the attribute from which to derive the types.
/// If the attribute is a TypeAttr, the type it holds is used; otherwise the
/// attribute's own type is used. The template spans several lines and is
/// meant to be split with appendLineByLine; the continuation lines sit inside
/// parentheses, so their indentation is not significant.
constexpr const char *deriveTypeFromAttrTemplate =
    R"PY(_ods_result_type_source_attr = attributes["{0}"]
_ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type))PY";

/// Python code template appending {0} type {1} times to the results list.
constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";

/// Python code template for inferring the operation results using the
/// corresponding interface:
///   - {0} is the name of the class for which the types are inferred.
/// The `results` list is replaced by the inferred types; operands and
/// attributes must already be populated. The context comes from `loc`.
constexpr const char *inferTypeInterfaceTemplate =
    R"PY(_ods_context = _ods_get_default_loc_context(loc)
results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
    operands=operands,
    attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
    context=_ods_context,
    loc=loc)
)PY";
| |
| /// Appends the given multiline string as individual strings into |
| /// `builderLines`. |
| static void appendLineByLine(StringRef string, |
| llvm::SmallVectorImpl<std::string> &builderLines) { |
| |
| std::pair<StringRef, StringRef> split = std::make_pair(string, string); |
| do { |
| split = split.second.split('\n'); |
| builderLines.push_back(split.first.str()); |
| } while (!split.second.empty()); |
| } |
| |
| /// Populates `builderLines` with additional lines that are required in the |
| /// builder to set up op results. |
| static void |
| populateBuilderLinesResult(const Operator &op, |
| llvm::ArrayRef<std::string> names, |
| llvm::SmallVectorImpl<std::string> &builderLines) { |
| bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; |
| |
| if (hasSameArgumentAndResultTypes(op)) { |
| builderLines.push_back(llvm::formatv( |
| appendSameResultsTemplate, "operands[0].type", op.getNumResults())); |
| return; |
| } |
| |
| if (hasFirstAttrDerivedResultTypes(op)) { |
| const NamedAttribute &firstAttr = op.getAttribute(0); |
| assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " |
| "from which the type is derived"); |
| appendLineByLine( |
| llvm::formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), |
| builderLines); |
| builderLines.push_back(llvm::formatv(appendSameResultsTemplate, |
| "_ods_derived_result_type", |
| op.getNumResults())); |
| return; |
| } |
| |
| if (hasInferTypeInterface(op)) { |
| appendLineByLine( |
| llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(), |
| builderLines); |
| return; |
| } |
| |
| // For each element, find or generate a name. |
| for (int i = 0, e = op.getNumResults(); i < e; ++i) { |
| const NamedTypeConstraint &element = op.getResult(i); |
| std::string name = names[i]; |
| |
| // Choose the formatting string based on the element kind. |
| llvm::StringRef formatString; |
| if (!element.isVariableLength()) { |
| formatString = singleResultAppendTemplate; |
| } else if (element.isOptional()) { |
| formatString = optionalAppendResultTemplate; |
| } else { |
| assert(element.isVariadic() && "unhandled element group type"); |
| // If emitting with sizedSegments, then we add the actual list-typed |
| // element. Otherwise, we extend the actual operands. |
| if (sizedSegments) { |
| formatString = singleResultAppendTemplate; |
| } else { |
| formatString = multiResultAppendTemplate; |
| } |
| } |
| |
| builderLines.push_back(llvm::formatv(formatString.data(), name)); |
| } |
| } |
| |
| /// If the operation has variadic regions, adds a builder argument to specify |
| /// the number of those regions and builder lines to forward it to the generic |
| /// constructor. |
| static void |
| populateBuilderRegions(const Operator &op, |
| llvm::SmallVectorImpl<std::string> &builderArgs, |
| llvm::SmallVectorImpl<std::string> &builderLines) { |
| if (op.hasNoVariadicRegions()) |
| return; |
| |
| // This is currently enforced when Operator is constructed. |
| assert(op.getNumVariadicRegions() == 1 && |
| op.getRegion(op.getNumRegions() - 1).isVariadic() && |
| "expected the last region to be varidic"); |
| |
| const NamedRegion ®ion = op.getRegion(op.getNumRegions() - 1); |
| std::string name = |
| ("num_" + region.name.take_front().lower() + region.name.drop_front()) |
| .str(); |
| builderArgs.push_back(name); |
| builderLines.push_back( |
| llvm::formatv("regions = {0} + {1}", op.getNumRegions() - 1, name)); |
| } |
| |
| /// Emits a default builder constructing an operation from the list of its |
| /// result types, followed by a list of its operands. |
| static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { |
| // If we are asked to skip default builders, comply. |
| if (op.skipDefaultBuilders()) |
| return; |
| |
| llvm::SmallVector<std::string> builderArgs; |
| llvm::SmallVector<std::string> builderLines; |
| llvm::SmallVector<std::string> operandArgNames; |
| llvm::SmallVector<std::string> successorArgNames; |
| builderArgs.reserve(op.getNumOperands() + op.getNumResults() + |
| op.getNumNativeAttributes() + op.getNumSuccessors()); |
| populateBuilderArgsResults(op, builderArgs); |
| size_t numResultArgs = builderArgs.size(); |
| populateBuilderArgs(op, builderArgs, operandArgNames, successorArgNames); |
| |
| populateBuilderLinesOperand(op, operandArgNames, builderLines); |
| populateBuilderLinesAttr( |
| op, llvm::makeArrayRef(builderArgs).drop_front(numResultArgs), |
| builderLines); |
| populateBuilderLinesResult( |
| op, llvm::makeArrayRef(builderArgs).take_front(numResultArgs), |
| builderLines); |
| populateBuilderLinesSuccessors(op, successorArgNames, builderLines); |
| populateBuilderRegions(op, builderArgs, builderLines); |
| |
| builderArgs.push_back("*"); |
| builderArgs.push_back("loc=None"); |
| builderArgs.push_back("ip=None"); |
| os << llvm::formatv(initTemplate, llvm::join(builderArgs, ", "), |
| llvm::join(builderLines, "\n ")); |
| } |
| |
| static void constructAttributeMapping(const llvm::RecordKeeper &records, |
| AttributeClasses &attributeClasses) { |
| for (const llvm::Record *rec : |
| records.getAllDerivedDefinitions("PythonAttr")) { |
| attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(), |
| rec->getValueAsString("pythonType").trim()); |
| } |
| } |
| |
| static void emitSegmentSpec( |
| const Operator &op, const char *kind, |
| llvm::function_ref<int(const Operator &)> getNumElements, |
| llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)> |
| getElement, |
| raw_ostream &os) { |
| std::string segmentSpec("["); |
| for (int i = 0, e = getNumElements(op); i < e; ++i) { |
| const NamedTypeConstraint &element = getElement(op, i); |
| if (element.isOptional()) { |
| segmentSpec.append("0,"); |
| } else if (element.isVariadic()) { |
| segmentSpec.append("-1,"); |
| } else { |
| segmentSpec.append("1,"); |
| } |
| } |
| segmentSpec.append("]"); |
| |
| os << llvm::formatv(opClassSizedSegmentsTemplate, kind, segmentSpec); |
| } |
| |
| static void emitRegionAttributes(const Operator &op, raw_ostream &os) { |
| // Emit _ODS_REGIONS = (min_region_count, has_no_variadic_regions). |
| // Note that the base OpView class defines this as (0, True). |
| unsigned minRegionCount = op.getNumRegions() - op.getNumVariadicRegions(); |
| os << llvm::formatv(opClassRegionSpecTemplate, minRegionCount, |
| op.hasNoVariadicRegions() ? "True" : "False"); |
| } |
| |
| /// Emits named accessors to regions. |
| static void emitRegionAccessors(const Operator &op, raw_ostream &os) { |
| for (auto en : llvm::enumerate(op.getRegions())) { |
| const NamedRegion ®ion = en.value(); |
| if (region.name.empty()) |
| continue; |
| |
| assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && |
| "expected only the last region to be variadic"); |
| os << llvm::formatv(regionAccessorTemplate, sanitizeName(region.name), |
| std::to_string(en.index()) + |
| (region.isVariadic() ? ":" : "")); |
| } |
| } |
| |
| /// Emits bindings for a specific Op to the given output stream. |
| static void emitOpBindings(const Operator &op, |
| const AttributeClasses &attributeClasses, |
| raw_ostream &os) { |
| os << llvm::formatv(opClassTemplate, op.getCppClassName(), |
| op.getOperationName()); |
| |
| // Sized segments. |
| if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { |
| emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os); |
| } |
| if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) { |
| emitSegmentSpec(op, "RESULT", getNumResults, getResult, os); |
| } |
| |
| emitRegionAttributes(op, os); |
| emitDefaultOpBuilder(op, os); |
| emitOperandAccessors(op, os); |
| emitAttributeAccessors(op, attributeClasses, os); |
| emitResultAccessors(op, os); |
| emitRegionAccessors(op, os); |
| } |
| |
| /// Emits bindings for the dialect specified in the command line, including file |
| /// headers and utilities. Returns `false` on success to comply with Tablegen |
| /// registration requirements. |
| static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { |
| if (clDialectName.empty()) |
| llvm::PrintFatalError("dialect name not provided"); |
| |
| AttributeClasses attributeClasses; |
| constructAttributeMapping(records, attributeClasses); |
| |
| os << llvm::formatv(fileHeader, clDialectName.getValue()); |
| os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); |
| |
| for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { |
| Operator op(rec); |
| if (op.getDialectName() == clDialectName.getValue()) |
| emitOpBindings(op, attributeClasses, os); |
| } |
| return false; |
| } |
| |
| static GenRegistration |
| genPythonBindings("gen-python-op-bindings", |
| "Generate Python bindings for MLIR Ops", &emitAllOps); |