Shape inference as discussed here is considered a specific instance of type inference for `ShapedType`. Type constraints are along (at least) three axes: 1) elemental type, 2) rank (including static or dynamic), 3) dimensions. While some operations have no compile-time fixed shape (e.g., the output shape is dictated by data), we could still have some knowledge of constraints/bounds in the system for that operation (e.g., the output of a `tf.where` is at most the size of the input data). That is, there are additional valuable constraints that could be captured even without full knowledge of the shape.
Type inference is currently modelled executionally for operation creation using the `InferTypeOpInterface`, while `InferShapedTypeOpInterface` is used to implement the shape and element type inference. The return type can often be deduced from the inferred return shape and elemental type (queryable from `InferShapedTypeOpInterface`), and so type inference for tensor types can be implemented with `InferShapedTypeOpInterface`.
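As a rough illustration of that layering, below is a hedged C++ sketch; `MyTensorOp` is a hypothetical op, and the exact interface signatures vary across MLIR versions:

```c++
// Sketch: derive the full return type from the inferred shape components.
// Assumes the usual MLIR includes and `using namespace mlir;`. MyTensorOp is
// hypothetical; the signatures follow one version of InferTypeOpInterface /
// InferShapedTypeOpInterface and may need adjusting for other versions.
LogicalResult MyTensorOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  SmallVector<ShapedTypeComponents, 1> components;
  // Query the shape and element type inference from
  // InferShapedTypeOpInterface.
  if (failed(MyTensorOp::inferReturnTypeComponents(
          context, location, operands, attributes, regions, components)))
    return failure();
  for (const ShapedTypeComponents &c : components) {
    // Build a tensor type from the inferred rank/dims and element type
    // (assumes the element type component was populated).
    if (c.hasRank())
      inferredReturnTypes.push_back(
          RankedTensorType::get(c.getDims(), c.getElementType()));
    else
      inferredReturnTypes.push_back(
          UnrankedTensorType::get(c.getElementType()));
  }
  return success();
}
```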
The C++ interfaces are the base mechanism whereby shape inference is queried and executed, but not the intended way to specify shape constraints in general.
Initially the shape inference will be declaratively specified using:
*   Constraints on the operands of an operation directly. For example, constraining the input type to be tensor/vector elements, or that the elemental type be of a specific type (e.g., the output of computing the size of a value is of elemental type `i1`) or class (e.g., float-like).
*   Constraints across operands and results of an operation.
NOTE: The C++ shape functions are an intermediate step until the shape dialect is more full-fledged, at which point the C++ functions should become the exceptional case.
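As an illustration of how the constraint kinds above are expressed in today's C++ form, consider the hedged sketch below; `MyPredicateOp` is hypothetical, and signatures vary across MLIR versions:

```c++
// Sketch: an op whose result has the same shape as its input but a fixed
// `i1` elemental type (a "compute a size/predicate per element" style op).
// Assumes the usual MLIR includes and `using namespace mlir;`.
LogicalResult MyPredicateOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  // Constraint on the operand directly: it must be a shaped type.
  auto input = operands.front().getType().dyn_cast<ShapedType>();
  if (!input)
    return emitOptionalError(location, "expected a shaped operand");
  // Constraint across operand and result: same shape, elemental type i1.
  auto i1 = IntegerType::get(context, 1);
  if (input.hasRank())
    inferredReturnTypes.push_back(
        RankedTensorType::get(input.getShape(), i1));
  else
    inferredReturnTypes.push_back(UnrankedTensorType::get(i1));
  return success();
}
```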
Shape inference is currently tested alongside type inference by `TestReturnTypeDriver` in the test dialect. This driver performs two checks: it invokes `testCreateFunctions` by creating new binary Ops (Op classes specified in `TestReturnTypeDriver`) using 1) all operands to `testCreateFunctions` as both operands, and 2) combinations of the input operands of the function.

This section details the shape type inference dialect (`shape`). The initial focus will be on shape functions that could be used at runtime and in the compiler (for construction of ops, refinement of shapes, and reification of dynamic allocations) for dialects including TF, TFLite, XLA and the tensor compute dialect under discussion.
This will focus on the shape functions (e.g., determining the rank and dimensions of the output shape). As shown in the shaped container type, shape will be one of three components, the others being elemental type and attribute (the latter currently left open with the intention of supporting extensions such as layouts or bounded shapes at a later point). This allows these components to be decoupled.
An argument could be made that these are metadata functions instead of shape functions, with some considering shape and elemental type different and some considering them both as part of shape. But `shape function` is IMHO descriptive, and `metadata` can span too large a range of potential uses/values.
The requirements for the shape inference functions are determined by the requirements of shape inference, but we believe the requirements below still allow freedom to consider different shape inference approaches and so we do not impose a particular shape inference approach here.
*   Expressiveness: shape functions need to support programs where tensors have shapes that are not known statically (for example, `tensor<16x?xf32>` or `tensor<*xf32>`);
*   Shape error detection: many operations will have constraints on their operands. If the constraints are not satisfied, or it cannot be determined statically whether they are satisfied, then a runtime check/assertion could be generated;
*   Shape functions usable by both compiler and runtime;
*   Declarative (e.g., analyzable at compile time, possible to generate different versions for different use cases);
*   Shape inference functions are expressible at runtime:
    *   Users can define a shape function for a new operation dynamically at runtime; this allows vendors to describe an operation and its shape function dynamically. This requirement is on the wishlist;
*   Doesn't require graph-wide shape information (e.g., only requires local information); shape information that needs higher-level/graph information should use richer types (e.g., `TensorList<F32>`);
*   Shape functions should be pure functions;
*   Should support operations whose type is only known dynamically (e.g., a `read_from_file` op);
*   The shape function operation dialect should be interoperable with operations from non-shape-function dialects:
    *   There may be a common set of operations that satisfies most uses (e.g., merge, equal_type, arithmetic expressions, slice, concat, pattern matching on attributes such as padding, etc.) that will be discovered and could cover a large percentage of the use cases. Among these there will be some that carry extra semantic information usable for symbolic constraints (e.g., checking equality of two dimensions resulting in an equality constraint) and for higher-order interpretation during constraint solving.
    *   It is therefore beneficial (but not required) to reuse operations, especially as, for statically known shapes, arbitrary arithmetic computations could still be performed. This means the computations performed statically may or may not be supported by an arbitrary solver, but would still be allowed;
*   The shape function should be expandable such that symbolic equality and upper-bound constraints (say) could be represented and propagated by shape inference;
*   Shape functions are allowed to fail and report an error. The error reporting should include the location of the failing operation and, where possible, a user-actionable error message.
The exact mechanism for specifying these shape functions (e.g., analogous to TensorFlow's `SetShapeFn` calls, etc.) is still TBD.

Shape functions should be lowerable to runtime checks for validity, e.g., verify as much as possible statically, but enable generating instructions to compute the shape dynamically and/or fall back to runtime checks for attributes not verifiable at compile time. The checks inserted should ideally only verify that which could not have been verified statically.
These inlined calls could interfere with optimization patterns/passes (e.g., shape inference should not insert constructs that interfere with optimization patterns) and so could be delayed until later (with another round of optimizations, constant folding, CSE, etc., that should remove redundant runtime operations).
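One way to read "only check what could not be verified statically" is sketched below; the helper and enum are purely illustrative assumptions, not existing MLIR APIs (only `ShapedType::isDynamic` is real):

```c++
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Illustrative-only helper (not an existing MLIR API): classify whether a
// dims-equal constraint can be discharged at compile time or must be
// lowered to a runtime check.
enum class CheckResult { StaticallyTrue, StaticallyFalse, NeedsRuntimeCheck };

CheckResult checkDimsEqual(int64_t lhs, int64_t rhs) {
  // A dynamic dimension cannot be compared at compile time, so the
  // constraint must survive as an inserted runtime check.
  if (ShapedType::isDynamic(lhs) || ShapedType::isDynamic(rhs))
    return CheckResult::NeedsRuntimeCheck;
  // Both dimensions are static: verify now and insert no runtime check.
  return lhs == rhs ? CheckResult::StaticallyTrue
                    : CheckResult::StaticallyFalse;
}
```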
In ODS we have been recording the constraints for the operands and attributes of an operation. Where these are sufficient to constrain the output shape (e.g., `SameOperandsAndResultType` or broadcastable), we should generate the shape function from them. Where they are not, an explicit shape function should be specified (spelling TBD, but currently considering using the MLIR textual form as the serialization approach).
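For instance, for an op constrained by the `SameOperandsAndResultType` trait, the generated shape function could be as simple as the hedged sketch below (hypothetical `MyEltwiseOp`; signatures vary across MLIR versions):

```c++
// Sketch of a shape function derivable purely from ODS constraints: with
// SameOperandsAndResultType, the result type is simply an operand's type.
// Assumes the usual MLIR includes and `using namespace mlir;`.
LogicalResult MyEltwiseOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (operands.empty())
    return emitOptionalError(location, "expected at least one operand");
  inferredReturnTypes.push_back(operands.front().getType());
  return success();
}
```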
This could be done in future! The extracted shape function would use the shape inference dialect, so we are starting there. Especially for operations described in a structured way, one could autogenerate the shape function.
TBD; this is open to many approaches and suggestions. The IR produced, rather than whatever surface language produces it, is the priority of this proposal.
None. There are multiple different shape inference approaches that could be layered on top of these, from the most basic (always return unranked), to the more useful (return a fixed shape for constant inputs/arguments), to the more advanced (create logical conjunctions of algebraic statements between symbolic named values).
TODO: Add examples here.
Shape functions are determined by attributes and could be arbitrarily complicated, with a wide range of specification possibilities. Equality relationships are common (e.g., the elemental type of the output matches the primitive type of the inputs, or both inputs have exactly the same type [primitive type and shape]) and so these should be easy to specify. Algebraic relationships are also common (e.g., the concat of an `[n,m]` and an `[n,m]` matrix along axis 0 is an `[n+n, m]` matrix), while some ops only have defined shapes under certain cases (e.g., matrix multiplication of `[a,b]` and `[c,d]` is only defined if `b == c`).
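To make the matrix multiplication case concrete, here is a hedged C++ sketch of such a shape function, including user-actionable error reporting (hypothetical `MyMatMulOp`; signatures vary across MLIR versions):

```c++
// Sketch: matmul of [a,b] x [c,d] yields [a,d], defined only when b == c.
// Assumes the usual MLIR includes and `using namespace mlir;`. Errors are
// reported via emitOptionalError so the failure carries the op's location
// and an actionable message.
LogicalResult MyMatMulOp::inferReturnTypes(
    MLIRContext *context, Optional<Location> location, ValueRange operands,
    DictionaryAttr attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto lhs = operands[0].getType().dyn_cast<RankedTensorType>();
  auto rhs = operands[1].getType().dyn_cast<RankedTensorType>();
  if (!lhs || !rhs || lhs.getRank() != 2 || rhs.getRank() != 2)
    return emitOptionalError(location, "expected two ranked 2D tensors");
  int64_t b = lhs.getDimSize(1), c = rhs.getDimSize(0);
  // Only reject when both inner dimensions are statically known and unequal;
  // a dynamic dimension defers the check to runtime.
  if (!ShapedType::isDynamic(b) && !ShapedType::isDynamic(c) && b != c)
    return emitOptionalError(location, "inner dimensions must match: ", b,
                             " vs ", c);
  inferredReturnTypes.push_back(RankedTensorType::get(
      {lhs.getDimSize(0), rhs.getDimSize(1)}, lhs.getElementType()));
  return success();
}
```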
Instead of specifying an additional mechanism for defining a shape transfer function, the reference implementation of the operation will be used to derive the shape function. The reference implementation is general and can support the arbitrary computations needed to specify output shapes.