| //===-- OpBase.td - Base op definition file ----------------*- tablegen -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This is the base operation definition file. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef OP_BASE |
| #define OP_BASE |
| |
| include "mlir/IR/Constraints.td" |
| include "mlir/IR/DialectBase.td" |
| include "mlir/IR/Interfaces.td" |
| include "mlir/IR/Properties.td" |
| include "mlir/IR/Traits.td" |
| include "mlir/IR/Utils.td" |
| include "mlir/IR/AttrTypeBase.td" |
| |
| //===----------------------------------------------------------------------===// |
| // OpTrait definitions |
| //===----------------------------------------------------------------------===// |
| |
// Marker class for traits that describe the structure of an operation (e.g.
// region/block shape, parent op, segment sizes). Traits marked with
// `StructuralOpTrait` are verified before other traits.
class StructuralOpTrait;
| |
// These classes are used to define operation-specific traits. Each one
// carries a `dependentTraits` list naming the traits that must be verified
// before the trait itself.

// A trait backed by a native C++ trait class named `name`. Op-specific
// declarations and definitions can be injected through the
// `extraOpDeclaration` and `extraOpDefinition` template arguments.
class NativeOpTrait<string name, list<Trait> traits = [],
                    code extraOpDeclaration = [{}],
                    code extraOpDefinition = [{}]>
    : NativeTrait<name, "Op", extraOpDeclaration, extraOpDefinition> {
  // Specify the list of traits that need to be verified before the verification
  // of this NativeOpTrait.
  list<Trait> dependentTraits = traits;
}

// A native C++ trait that takes parameters. `params` is a string of
// parameters; for example, `ParentOneOf` passes a ", "-joined list of op names.
class ParamNativeOpTrait<string prop, string params,
                         list<Trait> traits = []>
    : ParamNativeTrait<prop, params, "Op"> {
  // Specify the list of traits that need to be verified before the verification
  // of this ParamNativeOpTrait.
  list<Trait> dependentTraits = traits;
}

// A trait interpreted by the ODS generator itself rather than by a user-facing
// C++ trait (see the TODO above `SameVariadicOperandSize`).
class GenInternalOpTrait<string prop, list<Trait> traits = []>
    : GenInternalTrait<prop, "Op"> {
  // Specify the list of traits that need to be verified before the verification
  // of this GenInternalOpTrait.
  list<Trait> dependentTraits = traits;
}

// A trait expressed as the predicate `pred`, described by `descr`.
class PredOpTrait<string descr, Pred pred, list<Trait> traits = []>
    : PredTrait<descr, pred> {
  // Specify the list of traits that need to be verified before the verification
  // of this PredOpTrait.
  list<Trait> dependentTraits = traits;
}
| |
// Simple native op traits. The string argument is the name of the backing
// C++ trait, which does not always match the TableGen name (e.g.
// `Commutative` -> `IsCommutative`).

// Op defines an affine scope.
def AffineScope : NativeOpTrait<"AffineScope">;
// Op defines an automatic allocation scope.
def AutomaticAllocationScope :
    NativeOpTrait<"AutomaticAllocationScope">;
// Op supports operand broadcast behavior.
def ResultsBroadcastableShape :
    NativeOpTrait<"ResultsBroadcastableShape">;
// X op Y == Y op X
def Commutative : NativeOpTrait<"IsCommutative">;
// op op X == op X (unary) / X op X == X (binary)
// FIXME: Idempotent should depend on SameOperandsAndResultType
def Idempotent : NativeOpTrait<"IsIdempotent">;
// op op X == X
// FIXME: Involution should depend on SameOperandsAndResultType
def Involution : NativeOpTrait<"IsInvolution">;
// Op behaves like a constant.
def ConstantLike : NativeOpTrait<"ConstantLike">;
// Op is isolated from above.
def IsolatedFromAbove : NativeOpTrait<"IsIsolatedFromAbove">;
// Op results are float or vectors/tensors thereof.
def ResultsAreFloatLike : NativeOpTrait<"ResultsAreFloatLike">;
// All operands of the op have the same type.
def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
// Op has same shape for all operands.
def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
// Op has same operand and result shape.
def SameOperandsAndResultShape :
    NativeOpTrait<"SameOperandsAndResultShape">;
// Op has the same element type (or type itself, if scalar) for all operands.
def SameOperandsElementType :
    NativeOpTrait<"SameOperandsElementType">;
// Op has the same operand and result element type (or type itself, if scalar).
def SameOperandsAndResultElementType :
    NativeOpTrait<"SameOperandsAndResultElementType">;
// Op is a terminator.
def Terminator : NativeOpTrait<"IsTerminator">;
// Op can be safely normalized in the presence of MemRefs with
// non-identity maps.
def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
// Op is elementwise on tensor/vector operands and results.
def Elementwise : NativeOpTrait<"Elementwise">;
// Elementwise op can be applied to scalars instead of tensor/vector operands.
// Depends on `Elementwise`.
def Scalarizable : NativeOpTrait<"Scalarizable", [Elementwise]>;
// Elementwise op can be applied to all-vector operands. Depends on
// `Elementwise`.
def Vectorizable : NativeOpTrait<"Vectorizable", [Elementwise]>;
// Elementwise op can be applied to all-tensor operands. Depends on
// `Elementwise`.
def Tensorizable : NativeOpTrait<"Tensorizable", [Elementwise]>;
| |
// Convenience bundle: attaching `ElementwiseMappable` is equivalent to listing
// `Elementwise`, `Scalarizable`, `Vectorizable` and `Tensorizable` on the op.
def ElementwiseMappable
    : TraitList<[Elementwise, Scalarizable, Vectorizable, Tensorizable]>;
| |
// Op's regions have a single block.
def SingleBlock : NativeOpTrait<"SingleBlock">, StructuralOpTrait;

// Implementation half of `SingleBlockImplicitTerminator`: the parameterized
// native trait, declared as depending on `SingleBlock`.
// `SingleBlockImplicitTerminator` below bundles it together with `SingleBlock`.
class SingleBlockImplicitTerminatorImpl<string op>
    : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op, [SingleBlock]>,
      StructuralOpTrait;

// Op's regions have a single block with the specified terminator.
class SingleBlockImplicitTerminator<string op>
    : TraitList<[SingleBlock, SingleBlockImplicitTerminatorImpl<op>]>;

// Op's regions don't have a terminator.
def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait;

// Op's parent operation is the provided one.
class HasParent<string op>
    : ParamNativeOpTrait<"HasParent", op>, StructuralOpTrait;

// Op's parent operation is one of the provided ones. Reuses the `HasParent`
// native trait, passing the op names joined with ", " as its parameters.
class ParentOneOf<list<string> ops>
    : ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>,
      StructuralOpTrait;
| |
// Op result type is derived from the first attribute. If the attribute is a
// subclass of `TypeAttrBase`, its value is used; otherwise, the type of the
// attribute content is used.
def FirstAttrDerivedResultType :
    GenInternalOpTrait<"FirstAttrDerivedResultType">;
| |
// TODO: Turn the following into normal traits and generate verification for
// them.

// All variadic operands of the op have the same number of values.
// A variadic operand contains an array of values whose array size is only
// known at runtime. This trait requires all variadic operands of an op
// to have the same array size.
def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">;
// All variadic results of the op have the same number of values.
// A variadic result contains an array of values whose array size is only
// known at runtime. This trait requires all variadic results of an op
// to have the same array size.
def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
| |
// Uses an attribute named `operandSegmentSizes` to specify how many actual
// operands each ODS-declared operand (variadic or not) corresponds to.
// This trait is used for ops that have multiple variadic operands but do
// not know their size relationship statically. The attribute must be a 1D
// vector that has the same number of elements as the number of ODS-declared
// operands. That means even if some operands are non-variadic, the attribute
// still needs an element for their size, which is always 1.
def AttrSizedOperandSegments :
    NativeOpTrait<"AttrSizedOperandSegments">, StructuralOpTrait;
// Similar to AttrSizedOperandSegments, but used for results. The attribute
// should be named `resultSegmentSizes`.
def AttrSizedResultSegments :
    NativeOpTrait<"AttrSizedResultSegments">, StructuralOpTrait;
| |
// The op's attached regions have no arguments.
def NoRegionArguments : NativeOpTrait<"NoRegionArguments">, StructuralOpTrait;
| |
| //===----------------------------------------------------------------------===// |
| // Successor definitions |
| //===----------------------------------------------------------------------===// |
| |
// Base class for successor constraints: `condition` is the predicate a
// successor must satisfy and `descr` its human-readable description.
class Successor<Pred condition, string descr = ""> :
    SuccessorConstraint<condition, descr>;

// Any successor. The predicate is intentionally left unset (`?`).
def AnySuccessor : Successor<?, "any successor">;

// A variadic successor constraint. It expands to zero or more of the base
// successor, reusing the base constraint's predicate and summary.
class VariadicSuccessor<Successor successor>
    : Successor<successor.predicate, successor.summary>;
| |
| //===----------------------------------------------------------------------===// |
| // Region definitions |
| //===----------------------------------------------------------------------===// |
| |
// Base class for region constraints: `condition` is the predicate a region
// must satisfy and `descr` its human-readable description.
class Region<Pred condition, string descr = ""> :
    RegionConstraint<condition, descr>;

// Any region (the predicate is always `true`).
def AnyRegion : Region<CPred<"true">, "any region">;

// A region with exactly the given number of blocks.
class SizedRegion<int numBlocks> : Region<
    CPred<"::llvm::hasNItems($_self, " # numBlocks # ")">,
    "region with " # numBlocks # " blocks">;

// A region with at least the given number of blocks.
class MinSizedRegion<int numBlocks> : Region<
    CPred<"::llvm::hasNItemsOrMore($_self, " # numBlocks # ")">,
    "region with at least " # numBlocks # " blocks">;

// A region with at most the given number of blocks.
class MaxSizedRegion<int numBlocks> : Region<
    CPred<"::llvm::hasNItemsOrLess($_self, " # numBlocks # ")">,
    "region with at most " # numBlocks # " blocks">;

// A variadic region constraint. It expands to zero or more of the base region,
// reusing the base constraint's predicate and summary.
class VariadicRegion<Region region>
    : Region<region.predicate, region.summary>;
| |
| //===----------------------------------------------------------------------===// |
| // Markers |
| //===----------------------------------------------------------------------===// |
| |
// Marker used to identify the region list (the dag operator of `Op.regions`).
def region;

// Marker used to identify the successor list (the dag operator of
// `Op.successors`).
def successor;
| |
| //===----------------------------------------------------------------------===// |
| // Op definitions |
| //===----------------------------------------------------------------------===// |
| |
// Class for defining a custom builder.
//
// TableGen generates several generic builders for each op by default (see
// comment in the `Op` class). If the default generated ones cannot cover
// some use case, custom builders can be defined using instances of this class.
//
// The signature of the builder is always
//
// ```c++
//   static void build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
//                     <other-parameters>...) {
//     <body>...
//   }
// ```
//
// To define a custom builder, the parameter list (*excluding* the
// `OpBuilder &builder, OperationState &state` part) and body should be passed
// in as separate template arguments to this class. The parameter list is a
// TableGen DAG with `ins` operation with named arguments, which has either:
//   - string initializers ("Type":$name) to represent a typed parameter, or
//   - CArg-typed initializers (CArg<"Type", "default">:$name) to represent a
//     typed parameter that may have a default value.
// The type string is used verbatim to produce code and, therefore, must be a
// valid C++ type. It is used inside the C++ namespace of the parent Op's
// dialect; explicit namespace qualification like `::mlir` may be necessary if
// Ops are not placed inside the `mlir` namespace. The default value string is
// used verbatim to produce code and must be a valid C++ initializer for the
// given type. For example, the following signature specification
//
// ```
//   OpBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)>
// ```
//
// has an integer parameter and a float parameter with a default value.
//
// If an empty string is passed in for `body`, then *only* the builder
// declaration will be generated; this provides a way to define complicated
// builders entirely in C++.
class OpBuilder<dag p, code b = ""> {
  // The builder's extra parameters, as an `(ins ...)` dag.
  dag dagParams = p;
  // The builder's C++ body; empty means declaration only.
  code body = b;
}

// OpBuilder like the above, but the emitted 'build' method is marked as
// deprecated in C++. Use of it will emit a warning by the C++ compiler
// with the given reason.
class DeprecatedOpBuilder<string reason, dag p, code b = "">
    : OpBuilder<p, b>, CppDeprecated<reason>;
| |
// A base decorator class that may optionally be added to OpVariables.
class OpVariableDecorator;

// Class for providing additional information on the variables, i.e. arguments
// and results, of an operation.
class OpVariable<Constraint varConstraint, string desc = "",
                 list<OpVariableDecorator> varDecorators = []> {
  // The constraint, either attribute or type, of the argument.
  Constraint constraint = varConstraint;

  // One-line human-readable description of the argument.
  string summary = desc;

  // The list of decorators for this variable, e.g. side effects.
  list<OpVariableDecorator> decorators = varDecorators;
}
// An OpVariable used in an op's `arguments` list.
class Arg<Constraint constraint, string desc = "",
          list<OpVariableDecorator> decorators = []>
    : OpVariable<constraint, desc, decorators>;
// An OpVariable used in an op's `results` list.
class Res<Constraint constraint, string desc = "",
          list<OpVariableDecorator> decorators = []>
    : OpVariable<constraint, desc, decorators>;
| |
// Marker to group ops together for documentation purposes. An op joins a group
// through its `opDocGroup` field.
class OpDocGroup {
  // Single line summary of the group of ops.
  string summary;

  // Longer description of documentation group.
  string description;
}
| |
// Base class for all ops.
//
// `dialect` is the owning dialect, `mnemonic` the op name within it, and
// `props` the list of traits attached to the op.
class Op<Dialect dialect, string mnemonic, list<Trait> props = []> {
  // The dialect of the op.
  Dialect opDialect = dialect;

  // The mnemonic of the op.
  string opName = mnemonic;

  // The C++ namespace to use for this op. Defaults to the dialect's namespace.
  string cppNamespace = dialect.cppNamespace;

  // One-line human-readable description of what the op does.
  string summary = "";

  // Additional, longer human-readable description of what the op does.
  string description = "";

  // Optional. The group of ops this op is part of.
  OpDocGroup opDocGroup = ?;

  // Dag containing the arguments of the op. Defaults to 0 arguments.
  dag arguments = (ins);

  // The list of results of the op. Defaults to 0 results.
  dag results = (outs);

  // The list of regions of the op. Defaults to 0 regions.
  dag regions = (region);

  // The list of successors of the op. Defaults to 0 successors.
  dag successors = (successor);

  // Attribute getters can be added to the op by adding an Attr member
  // with the name and type of the attribute. E.g., adding int attribute
  // with name "value" and type "i32":
  //   I32Attr value;

  // Define the hooks used for building, parsing, printing, verification.

  // Custom builder.
  // In addition to the custom builder provided here, and unless
  // skipDefaultBuilders is set, two default builders are generated, with the
  // following signatures:
  //
  // ```c++
  // static void build(OpBuilder &, OperationState &odsState,
  //                   Type <result0-name>, Type <result1-name>, ...,
  //                   Value <arg0-name>, Value <arg1-name>, ...,
  //                   Attribute <attr0-name>, Attribute <attr1-name>, ...);
  // ```
  //   * where the attributes follow the same declaration order as in the op.
  //
  // ```c++
  // static void build(OpBuilder &, OperationState &odsState,
  //                   TypeRange resultTypes,
  //                   ValueRange operands,
  //                   ArrayRef<NamedAttribute> attributes);
  // ```
  list<OpBuilder> builders = ?;

  // Avoid generating default build functions. Custom builders must be
  // provided.
  bit skipDefaultBuilders = 0;

  // Custom assembly format.
  // This field corresponds to a declarative description of the assembly format
  // for this operation. If populated, the `hasCustomAssemblyFormat` field is
  // ignored.
  string assemblyFormat = ?;
  // This field indicates that the operation has a custom assembly format
  // implemented in C++. When set to `1`, a `parse` and a `print` method are
  // generated on the operation class. The operation should implement these
  // methods to support the custom format of the operation. The methods have
  // the form:
  //   * ParseResult parse(OpAsmParser &parser, OperationState &result)
  //   * void print(OpAsmPrinter &p)
  bit hasCustomAssemblyFormat = 0;

  // A bit indicating if the operation has additional invariants that need to
  // be verified (aside from those verified by other ODS constructs). If set to
  // `1`, an additional `LogicalResult verify()` declaration will be generated
  // on the operation class. The operation should implement this method and
  // verify the additional necessary invariants. This verifier shouldn't access
  // any nested operations because those operations may be ill-formed. Use the
  // `hasRegionVerifier` below instead.
  bit hasVerifier = 0;

  // A bit indicating if the operation has additional invariants that need to
  // be verified and which are associated with regions (aside from those
  // verified by the traits). If set to `1`, an additional
  // `LogicalResult verifyRegions()` declaration will be generated on the
  // operation class. The operation should implement this method and verify the
  // additional necessary invariants associated with regions. Note that this
  // method is invoked after all the region ops are verified.
  bit hasRegionVerifier = 0;

  // Whether this op has associated canonicalization patterns.
  bit hasCanonicalizer = 0;

  // Whether this op has a static "canonicalize" method to perform "match and
  // rewrite patterns".
  bit hasCanonicalizeMethod = 0;

  // Whether this op has a folder.
  bit hasFolder = 0;

  // Whether to let ops implement their custom `readProperties` and
  // `writeProperties` methods to emit bytecode.
  bit useCustomPropertiesEncoding = 0;

  // Op traits.
  // Note: The list of traits will be uniqued by ODS.
  list<Trait> traits = props;

  // Additional code that will be added to the public part of the generated
  // C++ code of the op declaration.
  code extraClassDeclaration = ?;

  // Additional code that will be added to the generated source file. The
  // generated code is placed inside the op's C++ namespace. `$cppClass` is
  // replaced by the op's C++ class name.
  code extraClassDefinition = ?;
}
| |
// The arguments of an op. Mix into an `Op` subclass to set its `arguments`
// field.
class Arguments<dag args> {
  dag arguments = args;
}

// The results of an op. Mix into an `Op` subclass to set its `results` field.
class Results<dag rets> {
  dag results = rets;
}
| |
| //===----------------------------------------------------------------------===// |
| // Common promised interface constraints |
| //===----------------------------------------------------------------------===// |
| |
// This constraint represents a promise or an implementation of an attr
// interface. The generated check calls
// `$_self.hasPromiseOrImplementsInterface<Ns::Iface>()`, omitting the `Ns::`
// qualifier when the interface has an empty `cppNamespace`.
class PromisedAttrInterface<AttrInterface interface> : AttrConstraint<
  CPred<"$_self.hasPromiseOrImplementsInterface<" #
    !if(!empty(interface.cppNamespace),
        "",
        interface.cppNamespace # "::") # interface.cppInterfaceName #">()">,
  "promising or implementing the `" # interface.cppInterfaceName # "` attr interface">;

// This predicate checks if the type promises or implements a type interface.
// The interface name is namespace-qualified unless `cppNamespace` is empty.
class HasPromiseOrImplementsTypeInterface<TypeInterface interface> :
  CPred<"$_self.hasPromiseOrImplementsInterface<" #
    !if(!empty(interface.cppNamespace),
        "",
        interface.cppNamespace # "::") # interface.cppInterfaceName #">()">;

// This constraint represents a promise or an implementation of a type
// interface.
class PromisedTypeInterface<TypeInterface interface> : TypeConstraint<
  HasPromiseOrImplementsTypeInterface<interface>,
  "promising or implementing the `" # interface.cppInterfaceName # "` type interface">;
| |
| //===----------------------------------------------------------------------===// |
| // Common op type constraints |
| //===----------------------------------------------------------------------===// |
| |
| // These traits are for verifying properties of an op that require knowledge of |
| // multiple arguments or results. For verifying properties of a single argument |
| // or result, prefer operand type constraints. |
| |
| // These traits often require including "mlir/IR/TypeUtilities.h". |
| |
| // TODO: Improve the autogenerated error messages. |
| |
// C++ expression builders over a named operand/result. `name` is emitted
// after a `$` (so `Rank<"foo">` refers to `$foo`). Passing "_self" produces an
// expression over `$_self` that can later be substituted per value.
// `::llvm::cast` requires the value's type to be a `ShapedType`.

// Rank of the shaped type of `$name`.
class Rank<string name> :
    StrFunc<"::llvm::cast<::mlir::ShapedType>($" # name # ".getType()).getRank()">;

// Shape of the shaped type of `$name`.
class Shape<string name> :
    StrFunc<"::llvm::cast<::mlir::ShapedType>($" # name # ".getType()).getShape()">;

// Number of elements of the shaped type of `$name`. The cast is fully
// qualified (`::llvm::cast`), matching `Rank` and `Shape`. An unqualified
// `llvm::cast` would bind to any `llvm` namespace nested inside the op's C++
// namespace.
class ElementCount<string name> :
    StrFunc<"::llvm::cast<::mlir::ShapedType>($" # name # ".getType())"
            ".getNumElements()">;

// Element type of `$name` (or its type itself, if scalar).
class ElementType<string name> : StrFunc<"getElementTypeOrSelf($" # name # ")">;
| |
// C++ disjunction over `values`: every entry is wrapped in parentheses and the
// wrapped entries are joined with " || ", e.g. ["a", "b"] -> "(a) || (b)".
// An empty list produces the literal "false".
class AnyPred<list<string> values> :
  CPred<!if(!empty(values),
            "false",
            !interleave(!foreach(entry, values, "(" # entry # ")"), " || "))>;
| |
// C++ predicate that all `values` are equal. Consecutive entries are compared
// pairwise and the last entry is compared back to the first, e.g.
//   ["a", "b", "c"] -> "(a) == (b) && (b) == (c) && (c) == (a)"
// Lists with fewer than two entries produce the literal "true".
class AllMatchPred<list<string> values> :
  CPred<!if(!lt(!size(values), 2),
            "true",
            !foldl("(" # !head(values) # ")", !tail(values), acc, v,
                   acc # " == (" # v # ") && (" # v # ")")
              # " == (" # !head(values) # ")")>;
| |
// Op trait requiring all C++ expressions in `values` to be equal.
class AllMatch<list<string> values, string summary> :
    PredOpTrait<summary, AllMatchPred<values>>;

// Applies `operator` to each named operand/result (every `$_self` in
// `operator` is replaced by `$<name>`) and requires all results to be equal.
// TODO: Only works for non-variadic.
class AllMatchSameOperatorPred<list<string> names, string operator> :
    AllMatchPred<!foreach(n, names, !subst("$_self", "$" # n, operator))>;

// Trait form of `AllMatchSameOperatorPred`, with the summary
// "all of {<names>} have same <summary>".
class AllMatchSameOperatorTrait<list<string> names, string operator,
                                string summary> :
    PredOpTrait<
        "all of {" # !interleave(names, ", ") # "} have same " # summary,
        AllMatchSameOperatorPred<names, operator>> {
  list<string> values = names;
}

// Applies `operator` to each named operand/result (every `$_self` replaced by
// `$<name>`) and requires at least one result to hold.
class AnyMatchOperatorPred<list<string> names, string operator> :
    AnyPred<!foreach(n, names, !subst("$_self", "$" # n, operator))>;

// Trait form of `AnyMatchOperatorPred`, with the summary
// "any of {<names>} has <summary>".
class AnyMatchOperatorTrait<list<string> names, string operator,
                            string summary> :
    PredOpTrait<
        "any of {" # !interleave(names, ", ") # "} has " # summary,
        AnyMatchOperatorPred<names, operator>> {
  list<string> values = names;
}
| |
// Traits requiring that all named operands/results agree on some property.
// Each one instantiates the corresponding `_self`-based expression and lets
// `AllMatchSameOperatorTrait` substitute it per name.

// All named values have the same number of elements (shaped types required).
class AllElementCountsMatch<list<string> names> :
    AllMatchSameOperatorTrait<names, ElementCount<"_self">.result,
                              "element count">;

// All named values have the same element type (or type itself, if scalar).
class AllElementTypesMatch<list<string> names> :
    AllMatchSameOperatorTrait<names, ElementType<"_self">.result,
                              "element type">;

// All named values have the same rank (shaped types required).
class AllRanksMatch<list<string> names> :
    AllMatchSameOperatorTrait<names, Rank<"_self">.result, "rank">;

// All named values have the same shape (shaped types required).
class AllShapesMatch<list<string> names> :
    AllMatchSameOperatorTrait<names, Shape<"_self">.result, "shape">;

// All named values have exactly the same type.
class AllTypesMatch<list<string> names> :
    AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
| |
// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
// An optional comparator function may be provided that changes the above form
// into: `comparator(transform(lhs.getType()), rhs.getType())`.
//
// `transform` is a C++ expression in which every `$_self` is replaced by
// `$<lhsArg>.getType()`. The generated predicate is
//   <comparator>(<transform>, $<rhsArg>.getType())
class TypesMatchWith<string summary, string lhsArg, string rhsArg,
                     string transform, string comparator = "std::equal_to<>()">
    : PredOpTrait<summary, CPred<
        comparator # "(" #
        !subst("$_self", "$" # lhsArg # ".getType()", transform) #
        ", $" # rhsArg # ".getType())">> {
  // The compared argument/result names and the raw (unsubstituted) transform,
  // exposed as record fields.
  string lhs = lhsArg;
  string rhs = rhsArg;
  string transformer = transform;
}

// The same as TypesMatchWith, but if either `lhsArg` or `rhsArg` is optional
// and not present, the check succeeds. Presence is tested via the generated
// getters `get<CamelCaseName>()`, which are prepended to the comparator:
//   !get<Lhs>() || !get<Rhs>() || <comparator>(...)
class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
                             string transform, string comparator = "std::equal_to<>()">
    : TypesMatchWith<summary, lhsArg, rhsArg, transform,
        "!get" # snakeCaseToCamelCase<lhsArg>.ret # "()"
        # " || !get" # snakeCaseToCamelCase<rhsArg>.ret # "() || " # comparator>;

// Special variant of `TypesMatchWith` that provides a comparator suitable for
// ranged arguments: the two type ranges are compared with `::llvm::equal`.
// The comparator is fully qualified, matching the `::llvm::cast` uses
// elsewhere in this file. An unqualified `llvm::equal` would bind to any
// `llvm` namespace nested inside the op's C++ namespace.
class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,
                           string transform>
    : TypesMatchWith<summary, lhsArg, rhsArg, transform, "::llvm::equal">;
| |
// Type Constraint operand `idx`'s Element type is `type`.
// Holds when the op has more than `idx` operands, operand `idx` has a shaped
// type (`IsShapedTypePred`), and its element type satisfies `type.predicate`.
class TCopVTEtIs<int idx, Type type> : And<[
    CPred<"$_op.getNumOperands() > " # idx>,
    SubstLeaves<"$_self", "$_op.getOperand(" # idx # ").getType()",
                IsShapedTypePred>,
    SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # idx # "))",
                type.predicate>]>;
| |
// Predicate to verify that a named argument or result's type (the type itself,
// not its element type) satisfies the given type constraint.
class TypeIsPred<string name, Type type> :
    SubstLeaves<"$_self", "$" # name # ".getType()", type.predicate>;
// Trait form of `TypeIsPred`, summarized as "'<name>' is <type summary>".
class TypeIs<string name, Type type> : PredOpTrait<
    "'" # name # "' is " # type.summary, TypeIsPred<name, type>>;
| |
// Predicate to verify that a named argument or result has a shaped type whose
// element type satisfies the given type constraint.
class ElementTypeIsPred<string name, Type type> : And<[
    SubstLeaves<"$_self", "$" # name # ".getType()", IsShapedTypePred>,
    SubstLeaves<"$_self", "getElementTypeOrSelf($" # name # ")",
                type.predicate>]>;
// Trait form of `ElementTypeIsPred`, summarized as "'<name>' is <type summary>".
class ElementTypeIs<string name, Type type> : PredOpTrait<
    "'" # name # "' is " # type.summary, ElementTypeIsPred<name, type>>;
| |
// Predicate to verify that the i'th operand and the j'th operand have the same
// elemental type.
// Type Constraint operand `i`'s Element type is Same As operand `j`'s Element
// type.
// Also requires that both operands exist (bounds-checked against max(i, j))
// and have shaped types.
class TCopVTEtIsSameAs<int i, int j> : And<[
    CPred<"$_op.getNumOperands() > " # !if(!gt(i,j),i,j)>,
    SubstLeaves<"$_self", "$_op.getOperand(" # i # ").getType()",
                IsShapedTypePred>,
    SubstLeaves<"$_self", "$_op.getOperand(" # j # ").getType()",
                IsShapedTypePred>,
    CPred<"::mlir::getElementTypeOrSelf($_op.getOperand(" # i # ")) == "
          "::mlir::getElementTypeOrSelf($_op.getOperand(" # j # "))">]>;
| |
// Predicate to verify that the i'th result and the j'th operand exist and have
// shaped types.
// Note the argument order: the RESULT index comes first, the OPERAND index
// second.
class TCOpResIsShapedTypePred<int i, int j> : And<[
    CPred<"$_op.getNumResults() > " # i>,
    CPred<"$_op.getNumOperands() > " # j>,
    SubstLeaves<"$_self", "$_op.getResult(" # i # ").getType()",
                IsShapedTypePred>,
    SubstLeaves<"$_self", "$_op.getOperand(" # j # ").getType()",
                IsShapedTypePred>]>;
| |
// Predicate to verify that the i'th result and the j'th operand have the same
// type. Performs no bounds or shape checks of its own.
class TCresIsSameAsOpBase<int i, int j> :
    CPred<"$_op.getResult(" # i # ").getType() == "
          "$_op.getOperand(" # j # ").getType()">;

// Basic Predicate to verify that the i'th result and the j'th operand have the
// same elemental type. Performs no bounds or shape checks of its own.
class TCresVTEtIsSameAsOpBase<int i, int j> :
    CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")) == "
          "getElementTypeOrSelf($_op.getOperand(" # j # "))">;

// Predicate to verify that the i'th result and the j'th operand have the same
// elemental type.
// Type Constraint result`i`'s Element type is Same As Operand `j`'s Element
// type.
// Unlike the Base form, it first checks that both values exist and have shaped
// types.
class TCresVTEtIsSameAsOp<int i, int j> : And<[
    TCOpResIsShapedTypePred<i, j>,
    TCresVTEtIsSameAsOpBase<i, j>]>;
| |
// Predicate to verify that the opId'th operand can be broadcast to the type
// of the resId'th result.
//
// `TCOpResIsShapedTypePred<i, j>` takes the RESULT index first and the OPERAND
// index second, so the indices are passed as <resId, opId>. Passing them in
// this class's parameter order would bounds-check and shape-check result
// #opId and operand #resId instead of the two values compared below.
//
// The second conjunct converts the `Type` returned by `getBroadcastedType` to
// bool. NOTE(review): presumably it is null when the two types are not
// broadcast-compatible — confirm against the OpTrait::util documentation.
class TCOpIsBroadcastableToRes<int opId, int resId> : And<[
    TCOpResIsShapedTypePred<resId, opId>,
    CPred<"::mlir::OpTrait::util::getBroadcastedType("
          "$_op.getOperand(" # opId # ").getType(), "
          "$_op.getResult(" # resId # ").getType())">]>;
| |
// Predicate to verify that all the operands at the given `indices`
// have the same element type.
// Type Constraint operands' Element type are all Same At the given `indices`.
// We query the operands' types into a list and check they are all the same.
// Precondition:
//   1) all operands involved are of shaped type and
//   2) the indices are not out of range.
// NOTE(review): unlike the other predicates here, the generated lambda
// captures `this` and calls `this->getOperand(i)` instead of using `$_op`.
// It therefore only compiles where `this` is the op instance.
class TCopVTEtAreSameAt<list<int> indices> : CPred<
  "::llvm::all_equal(::llvm::map_range("
    "::mlir::ArrayRef<unsigned>({" # !interleave(indices, ", ") # "}), "
    "[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); "
    "}))">;
| |
// Holds if any of the named values has a "scalar type".
// NOTE(review): the generated check is `$_self.getType().isSignlessInteger(1)`.
// Despite the summary text, only a signless `i1` qualifies.
class AnyScalarTypeMatch<list<string> names> :
  AnyMatchOperatorTrait<names, "$_self.getType().isSignlessInteger(1)",
  "scalar type">;

// Holds if the first named value is a "scalar" (an `i1`, per
// `AnyScalarTypeMatch`) or if all named values have matching shapes.
class ScalarConditionOrMatchingShape<list<string> names> :
  PredOpTrait<
    !head(names) # " is scalar or has matching shape",
    Or<[AnyScalarTypeMatch<[!head(names)]>.predicate,
        AllShapesMatch<names>.predicate]>> {
  list<string> values = names;
}
| |
| #endif // OP_BASE |