|  | # Shape Inference | 
|  |  | 
|  | Shape inference as discussed here is considered a specific instance of type | 
|  | inference for [ShapedType][ShapedType]. Type constraints are along (at least) | 
|  | three axis: 1) elemental type, 2) rank (including static or dynamic), 3) | 
|  | dimensions. While some operations have no compile time fixed shape (e.g., output | 
|  | shape is dictated by data) we could still have some knowledge of | 
|  | constraints/bounds in the system for that operation (e.g., the output of a | 
|  | `tf.where` is at most the size of the input data). That is, there are additional | 
|  | valuable constraints that could be captured even without full knowledge of the | 
|  | shape. | 
|  |  | 
|  | Type inference is currently modelled executionally for operation creation using the | 
|  | [`InferTypeOpInterface`][InferTypeOpInterface], while | 
|  | `InferShapedTypeOpInterface` is used to implement the shape and element type | 
|  | inference. The return type can often be deduced from the deduced return shape | 
|  | and elemental type (queryable from `InferShapedTypeOpInterface`) and so type | 
|  | inference for tensor types can be implemented with `InferShapedTypeOpInterface`. | 
|  |  | 
|  | [TOC] | 
|  |  | 
|  | ## Shape functions | 
|  |  | 
|  | The C++ interfaces are the base mechanism whereby shape inference is queried and | 
|  | executed, but not the intended way to specify shape constraints in general. | 
|  |  | 
|  | Initially the shape inference will be declaratively specified using: | 
|  |  | 
|  | *   Constraints on the operands of an operation directly. For example | 
|  | constraining the input type to be tensor/vector elements or that the | 
|  | elemental type be of a specific type (e.g., output of computing the size | 
|  | of a value is of elemental type `i1`) or class (e.g., float-like). | 
|  | *   Constraints across operands and results of an operation. | 
|  |  | 
|  | - For example, specifying equality constraints on type/constituents of a | 
|  | type (shape and elemental type) between operands and results (e.g., the | 
|  | output type of an add is the same as those of the input operands). | 
|  |  | 
|  | NOTE: The C++ shape functions are an intermediate step until the shape dialect | 
|  | is more full-fledged, at which point the C++ functions should become the | 
|  | exceptional case. | 
|  |  | 
|  | ## Testing | 
|  |  | 
|  | Shape inference is currently tested alongside type inference by | 
|  | `TestReturnTypeDriver` in the test dialect. This driver performs two checks: | 
|  |  | 
|  | 1.  Verification that the return types specified matches the inferred types. This | 
|  | explicit check will be removed and made part of Op verification instead. | 
|  | 2.  Test the creation of Ops without specifying the return type explicitly in | 
|  | function `testCreateFunctions` by creating new binary Ops (Op classes | 
|  | specified in `TestReturnTypeDriver`) using 1) all operands to | 
|  | `testCreateFunctions` as both operands, and 2) using combinations of input | 
|  | operands of the function. | 
|  |  | 
|  | ## Shape dialect | 
|  |  | 
|  | This section details the shape type inference dialect (`shape`). The initial | 
|  | focus will be on shape functions that describe shape functions could be used in | 
|  | runtime and compiler (for constructions of ops/refinement of shapes, reification | 
|  | of dynamic allocations for dialect including TF, TFLite, XLA & tensor compute | 
|  | dialect under discussion). | 
|  |  | 
|  | This will focus on the shape functions (e.g., determine the rank and dimensions | 
|  | of the output shape). As shown in the shaped container type, shape will be one | 
|  | of 3 components, the others being elemental type and attribute (which is | 
|  | currently left open with the intention of supporting extensions such as layouts | 
|  | or bounded shapes at a later point). This allows for decoupling of these: | 
|  |  | 
|  | *   Not all the information is needed for all analysis; | 
|  | *   Not all shape functions need to provide all the information (e.g., one could | 
|  | define a base class function that only populates element type but composes | 
|  | with the others); | 
|  | *   It allows reusing the constraints between, say, Tensor and Memref | 
|  | representation of an operation; | 
|  |  | 
|  | An argument could be made that these are metadata function instead of shape | 
|  | functions, with some considering shape and elemental types different and some considering them both as | 
|  | part of shape. But `shape function` is IMHO descriptive and metadata can span | 
|  | too large a range of potential uses/values. | 
|  |  | 
|  | ### Requirements | 
|  |  | 
|  | The requirements for the shape inference functions are determined by the | 
|  | requirements of shape inference, but we believe the requirements below still | 
|  | allow freedom to consider different shape inference approaches and so we do not | 
|  | impose a particular shape inference approach here. | 
|  |  | 
|  | #### Shape inference functions | 
|  |  | 
|  | *   **Expressiveness** shape functions need to support programs where tensors | 
|  | have shapes that are not known statically (for example, `tensor<16x?xf32>` | 
|  | or `tensor<*xf32>*`); | 
|  | *   **Shape error detection** Many operations will have constraints on their | 
|  | operands. If the constraints are not satisfied or cannot be determined if | 
|  | satisfied statically, then a runtime check/assertion could be generated. | 
|  |  | 
|  | *   This also aligns with the requirement that the shape function description | 
|  | should be usable by both the compiler and runtime. | 
|  | *   Shape error functions should be easy to understand, at least what | 
|  | constraint of the operation is violated. This also requires that shape | 
|  | function error messages should be configurable by the author of the | 
|  | shape function (e.g., the author would be able to give the semantic | 
|  | constraint invalidated rather the low-level check that failed). | 
|  | *   The static analysis may be used to eliminate run-time checks that are | 
|  | guaranteed to pass. | 
|  | *   Ideally all would eventually (see section | 
|  | [Inlining shape checking](#inline)) be elided. | 
|  | *   Only reporting errors which are guaranteed to occur at runtime. If an error is only | 
|  | possible (rather than guaranteed) then we use a runtime assertion to fail and produce an error | 
|  | message with the invariant violated. | 
|  |  | 
|  | *   Shape functions usable by compiler and runtime. | 
|  |  | 
|  | *   This does not mean the exact same C++ function, but rather the | 
|  | description should be consumable by either. | 
|  | *   Shape function description should not be constrained by either runtime | 
|  | or compiler's type system to handle types only used for analysis. That | 
|  | is, these two type systems differ and both should be supported, but the | 
|  | intersection of the two should not be required. As a particular example, | 
|  | if a compiler only wants to differentiate exact shapes vs dynamic | 
|  | shapes, then it need not consider a more generic shape lattice even | 
|  | though the shape description supports it. | 
|  |  | 
|  | *   Declarative (e.g., analyzable at compile time, possible to generate | 
|  | different versions for different use cases) | 
|  |  | 
|  | *   This may not strictly be a requirement, but a way to handle the former: | 
|  | a declarative specification could be reused by both while avoiding a | 
|  | need to map to or from a 3rd representation given these two systems | 
|  | have/and will have different types. | 
|  |  | 
|  | *   Shape inference functions are expressible at runtime | 
|  |  | 
|  | *   User can define a shape function for a new operation dynamically at runtime, | 
|  | this allows for vendors to describe an operation and shape function | 
|  | dynamically. | 
|  |  | 
|  | This requirement is on the wishlist. | 
|  |  | 
|  | *   Doesn't require graph-wide shape information (e.g., only require local | 
|  | information) | 
|  |  | 
|  | *   Shape functions should be cheap to invoke on each kernel launch. | 
|  | *   Shape function can be dictated by arguments (operands, attributes and regions) | 
|  | only (e.g., same operands as the corresponding operation could be | 
|  | constructed & invoked with). | 
|  | *   Shape information that needs higher-level/graph information should use | 
|  | richer types (e.g., `TensorList<F32>`); | 
|  | *   The function should be invocable before/while constructing an op (e.g., | 
|  | can't rely on the op being constructed). | 
|  |  | 
|  | *   Shape functions should be pure functions. | 
|  |  | 
|  | *   Should support functions whose type is only known dynamically (e.g., | 
|  | `read_from_file` op) | 
|  |  | 
|  | *   Without needing to invoke the op (e.g., reading a file once for | 
|  | determining the shape & then post to be able to actually consume the | 
|  | output of the file). | 
|  |  | 
|  | *   The shape function operation dialect should be interoperable with non-shape function dialect operations. | 
|  |  | 
|  | *   There may be a common set of operations that satisfy most uses (e.g., merge, | 
|  | equal_type, arithmetic expressions, slice, concat, pattern matching on | 
|  | attributes such as padding etc.) that will be discovered and could cover | 
|  | a large percentage of the use cases. Among these there will be some | 
|  | which carry extra semantic info that could be used for symbolic | 
|  | constraints (e.g., checking equality of two dimensions resulting in | 
|  | setting an equality constraint) and higher-order interpretation for | 
|  | constraint solving. | 
|  |  | 
|  | It is therefore beneficial (but not required) to reuse operations, | 
|  | especially as for statically known shapes, arbitrary arithmetic | 
|  | computations could still be performed. This means that the computations | 
|  | performed statically may or may not be supported by an arbitrary solver, | 
|  | but would still be allowed. | 
|  |  | 
|  | *   The shape function should be expandable such that symbolic equality and | 
|  | upper bound constraints (say) could be represented and may be propagated by | 
|  | shape inference. | 
|  |  | 
|  | *   E.g., the shape functions may contain more information that is only | 
|  | useful when used from shape inference; | 
|  |  | 
|  | *   Shape functions are allowed to fail and report an error. The error reporting | 
|  | should report the location of the operation that failed with, where | 
|  | possible, a user actionable error message. | 
|  |  | 
|  | *   These failures could become inlined and become runtime failures with | 
|  | runtime values and error messages. | 
|  | *   Reporting errors should be optional. E.g., The same function | 
|  | may be used as to query validity without reporting an error. | 
|  |  | 
|  | #### Non-goals | 
|  |  | 
|  | 1.  The shape dialect is an IR representations and not a programming language; | 
|  | *   While the functions should be readable, it doesn't carry the | 
|  | conveniences of a programming language. Deciding how people write these | 
|  | things, e.g. a mini dsl, a C++ API that generates them, extracting them | 
|  | programmatically from `SetShapeFn` calls, etc., is still TBD. | 
|  | 1.  Describe the shape inference approach that will use the shape functions; | 
|  | *   The goal is that the shape functions and the constraints one could | 
|  | obtain from them are general enough that they would be useful for | 
|  | various analysis. But whether we follow very simple (e.g., only fully | 
|  | static information is used for shape output, unranked for everything | 
|  | else) to very advance (e.g., expression trees of symbolic constants) can | 
|  | be evaluated independently of this proposal and with concrete benefit | 
|  | analysis. | 
|  | 1.  Describe the approach whereby error messages will be generated; | 
|  | *   While the shape functions will be able to emit errors optionally, it | 
|  | will be possible to dictate when they emit an error. This enables | 
|  | deciding whether or which error to emit: there have been proposals in | 
|  | the literature that the iteration order for shape inference affect the | 
|  | quality of the error message produced, and the shape functions do not | 
|  | mandate that. | 
|  | 1.  Flow sensitive shape functions; | 
|  | *   To enable scalable/cheap shape inference, the shape functions do not | 
|  | intend to provide flow sensitive information. This facility could | 
|  | potentially be built as part of some higher order analysis that reuse | 
|  | the shape functions/constraints due to the shape functions. | 
|  | 1.  All static functions are usable for dynamic/unknown shapes; | 
|  | *   More involved computations can be performed with statically known shapes | 
|  | than what can be sensibly analyzed with unknown/symbolic variables. | 
|  |  | 
|  | ### Discussion | 
|  |  | 
|  | #### Inline shape inference checks {#inline} | 
|  |  | 
|  | Shape functions should be lowerable to runtime checks for validity. E.g. verify | 
|  | as much as possible statically, but enable generating instructions to compute the | 
|  | shape dynamically and or falling back to runtime checks for attributes not | 
|  | verifiable at compile time. These checks inserted should ideally only check that | 
|  | which could not have been verified statically. | 
|  |  | 
|  | These inlined calls could interfere with optimization patterns/passes (e.g., | 
|  | shape inference should not insert constructs that interfere with optimization | 
|  | patterns) and so could be delayed until later (with another round of | 
|  | optimizations, constant folding, CSE, etc., that should remove redundant runtime | 
|  | operations). | 
|  |  | 
|  | ### Possibly Asked Questions | 
|  |  | 
|  | #### What about ODS specifications of operations? | 
|  |  | 
|  | In ODS we have been recording the constraints for the operands & attributes of | 
|  | an operation. Where these are sufficient to constrain the output shape (e.g., | 
|  | `SameOperandAndResultType` or broadcastable) we should generate the shape | 
|  | function from those. Where not, an explicit shape function should be specified | 
|  | (spelling TBD but currently considering using the MLIR textual form as | 
|  | serialization approach). | 
|  |  | 
|  | #### Why not extract the shape function from reference implementation? | 
|  |  | 
|  | This could be done in future! The extracted shape function would use the shape | 
|  | inference dialect, so we are starting there. Especially for operations described in a | 
|  | structured way, one could autogenerate the shape function. | 
|  |  | 
|  | #### How/in what language will the shape functions be authored? | 
|  |  | 
|  | TBD. open to many approaches and suggestions, starting on the IR produced by | 
|  | whatever language is the priority of this proposal. | 
|  |  | 
|  | #### What shape inference approach is being suggested here? | 
|  |  | 
|  | None. There are multiple different shape inference approaches that we could | 
|  | layer on top of these. From the most basic (always return unranked), to more | 
|  | useful (return fixed shape for constant inputs/arguments) to the more advanced | 
|  | (create logical conjunctions of algebraic statements between symbolic named | 
|  | values). | 
|  |  | 
|  | ### Open points | 
|  |  | 
|  | 1.  Should shape functions that produce dynamic outputs given all statically | 
|  | shaped inputs be marked specially? E.g., read from file. | 
|  |  | 
|  | TODO: Add examples here. | 
|  |  | 
|  | ## WIP/Future considerations | 
|  |  | 
|  | Shape functions are determined by attributes and could be arbitrarily | 
|  | complicated with a wide-range of specification possibilities. Equality | 
|  | relationships are common (e.g., the elemental type of the output matches the | 
|  | primitive type of the inputs, both inputs have exactly the same type [primitive | 
|  | type and shape]) and so these should be easy to specify. Algebraic relationships | 
|  | would also be common (e.g., a concat of `[n,m]` and `[n,m]` matrix along axis 0 | 
|  | is `[n+n, m]` matrix), while some ops only have defined shapes under certain | 
|  | cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if `b | 
|  | == c`). | 
|  |  | 
|  | Instead of specifying an additional mechanism to specify a shape transfer | 
|  | function, the reference implementation of the operation will be used to derive | 
|  | the shape function. The reference implementation is general and can support the | 
|  | arbitrary computations needed to specify output shapes. | 
|  |  | 
|  | [InferTypeOpInterface]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/Interfaces/InferTypeOpInterface.td | 
|  | [ShapedType]: https://github.com/llvm/llvm-project/tree/main/mlir/include/mlir/IR/BuiltinTypes.h |