| # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| from typing import ( |
| List as _List, |
| Optional as _Optional, |
| Sequence as _Sequence, |
| Tuple as _Tuple, |
| Type as _Type, |
| Union as _Union, |
| ) |
| |
| from .._mlir_libs import _mlir as _cext |
| from ..ir import ( |
| ArrayAttr, |
| Attribute, |
| BoolAttr, |
| DenseI64ArrayAttr, |
| IntegerAttr, |
| IntegerType, |
| OpView, |
| Operation, |
| ShapedType, |
| Value, |
| ) |
| |
| __all__ = [ |
| "equally_sized_accessor", |
| "get_default_loc_context", |
| "get_op_result_or_value", |
| "get_op_results_or_values", |
| "get_op_result_or_op_results", |
| "segmented_accessor", |
| ] |
| |
| |
| def segmented_accessor(elements, raw_segments, idx): |
| """ |
| Returns a slice of elements corresponding to the idx-th segment. |
| |
| elements: a sliceable container (operands or results). |
| raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing |
| sizes of the segments. |
| idx: index of the segment. |
| """ |
| segments = _cext.ir.DenseI32ArrayAttr(raw_segments) |
| start = sum(segments[i] for i in range(idx)) |
| end = start + segments[idx] |
| return elements[start:end] |
| |
| |
| def equally_sized_accessor( |
| elements, n_variadic, n_preceding_simple, n_preceding_variadic |
| ): |
| """ |
| Returns a starting position and a number of elements per variadic group |
| assuming equally-sized groups and the given numbers of preceding groups. |
| |
| elements: a sequential container. |
| n_variadic: the number of variadic groups in the container. |
| n_preceding_simple: the number of non-variadic groups preceding the current |
| group. |
| n_preceding_variadic: the number of variadic groups preceding the current |
| group. |
| """ |
| |
| total_variadic_length = len(elements) - n_variadic + 1 |
| # This should be enforced by the C++-side trait verifier. |
| assert total_variadic_length % n_variadic == 0 |
| |
| elements_per_group = total_variadic_length // n_variadic |
| start = n_preceding_simple + n_preceding_variadic * elements_per_group |
| return start, elements_per_group |
| |
| |
| def get_default_loc_context(location=None): |
| """ |
| Returns a context in which the defaulted location is created. If the location |
| is None, takes the current location from the stack, raises ValueError if there |
| is no location on the stack. |
| """ |
| if location is None: |
| # Location.current raises ValueError if there is no current location. |
| return _cext.ir.Location.current.context |
| return location.context |
| |
| |
| def get_op_result_or_value( |
| arg: _Union[ |
| _cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList |
| ] |
| ) -> _cext.ir.Value: |
| """Returns the given value or the single result of the given op. |
| |
| This is useful to implement op constructors so that they can take other ops as |
| arguments instead of requiring the caller to extract results for every op. |
| Raises ValueError if provided with an op that doesn't have a single result. |
| """ |
| if isinstance(arg, _cext.ir.OpView): |
| return arg.operation.result |
| elif isinstance(arg, _cext.ir.Operation): |
| return arg.result |
| elif isinstance(arg, _cext.ir.OpResultList): |
| return arg[0] |
| else: |
| assert isinstance(arg, _cext.ir.Value) |
| return arg |
| |
| |
| def get_op_results_or_values( |
| arg: _Union[ |
| _cext.ir.OpView, |
| _cext.ir.Operation, |
| _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]], |
| ] |
| ) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]: |
| """Returns the given sequence of values or the results of the given op. |
| |
| This is useful to implement op constructors so that they can take other ops as |
| lists of arguments instead of requiring the caller to extract results for |
| every op. |
| """ |
| if isinstance(arg, _cext.ir.OpView): |
| return arg.operation.results |
| elif isinstance(arg, _cext.ir.Operation): |
| return arg.results |
| else: |
| return [get_op_result_or_value(element) for element in arg] |
| |
| |
| def get_op_result_or_op_results( |
| op: _Union[_cext.ir.OpView, _cext.ir.Operation], |
| ) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]: |
| if isinstance(op, _cext.ir.OpView): |
| op = op.operation |
| return ( |
| list(get_op_results_or_values(op)) |
| if len(op.results) > 1 |
| else get_op_result_or_value(op) |
| if len(op.results) > 0 |
| else op |
| ) |
| |
| ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value |
| ResultValueT = _Union[ResultValueTypeTuple] |
| VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]] |
| |
| StaticIntLike = _Union[int, IntegerAttr] |
| ValueLike = _Union[Operation, OpView, Value] |
| MixedInt = _Union[StaticIntLike, ValueLike] |
| |
| IntOrAttrList = _Sequence[_Union[IntegerAttr, int]] |
| OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]] |
| |
| BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]] |
| OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]] |
| |
| MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike] |
| |
| DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]] |
| |
| |
| def _dispatch_dynamic_index_list( |
| indices: _Union[DynamicIndexList, ArrayAttr], |
| ) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]: |
| """Dispatches a list of indices to the appropriate form. |
| |
| This is similar to the custom `DynamicIndexList` directive upstream: |
| provided indices may be in the form of dynamic SSA values or static values, |
| and they may be scalable (i.e., as a singleton list) or not. This function |
| dispatches each index into its respective form. It also extracts the SSA |
| values and static indices from various similar structures, respectively. |
| """ |
| dynamic_indices = [] |
| static_indices = [ShapedType.get_dynamic_size()] * len(indices) |
| scalable_indices = [False] * len(indices) |
| |
| # ArrayAttr: Extract index values. |
| if isinstance(indices, ArrayAttr): |
| indices = [idx for idx in indices] |
| |
| def process_nonscalable_index(i, index): |
| """Processes any form of non-scalable index. |
| |
| Returns False if the given index was scalable and thus remains |
| unprocessed; True otherwise. |
| """ |
| if isinstance(index, int): |
| static_indices[i] = index |
| elif isinstance(index, IntegerAttr): |
| static_indices[i] = index.value # pytype: disable=attribute-error |
| elif isinstance(index, (Operation, Value, OpView)): |
| dynamic_indices.append(index) |
| else: |
| return False |
| return True |
| |
| # Process each index at a time. |
| for i, index in enumerate(indices): |
| if not process_nonscalable_index(i, index): |
| # If it wasn't processed, it must be a scalable index, which is |
| # provided as a _Sequence of one value, so extract and process that. |
| scalable_indices[i] = True |
| assert len(index) == 1 |
| ret = process_nonscalable_index(i, index[0]) |
| assert ret |
| |
| return dynamic_indices, static_indices, scalable_indices |
| |
| |
| # Dispatches `MixedValues` that all represents integers in various forms into |
| # the following three categories: |
| # - `dynamic_values`: a list of `Value`s, potentially from op results; |
| # - `packed_values`: a value handle, potentially from an op result, associated |
| # to one or more payload operations of integer type; |
| # - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python |
| # `int`s, from `IntegerAttr`s, or from an `ArrayAttr`. |
| # The input is in the form for `packed_values`, only that result is set and the |
| # other two are empty. Otherwise, the input can be a mix of the other two forms, |
| # and for each dynamic value, a special value is added to the `static_values`. |
| def _dispatch_mixed_values( |
| values: MixedValues, |
| ) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]: |
| dynamic_values = [] |
| packed_values = None |
| static_values = None |
| if isinstance(values, ArrayAttr): |
| static_values = values |
| elif isinstance(values, (Operation, Value, OpView)): |
| packed_values = values |
| else: |
| static_values = [] |
| for size in values or []: |
| if isinstance(size, int): |
| static_values.append(size) |
| else: |
| static_values.append(ShapedType.get_dynamic_size()) |
| dynamic_values.append(size) |
| static_values = DenseI64ArrayAttr.get(static_values) |
| |
| return (dynamic_values, packed_values, static_values) |
| |
| |
| def _get_value_or_attribute_value( |
| value_or_attr: _Union[any, Attribute, ArrayAttr] |
| ) -> any: |
| if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"): |
| return value_or_attr.value |
| if isinstance(value_or_attr, ArrayAttr): |
| return _get_value_list(value_or_attr) |
| return value_or_attr |
| |
| |
| def _get_value_list( |
| sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr] |
| ) -> _Sequence[any]: |
| return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr] |
| |
| |
| def _get_int_array_attr( |
| values: _Optional[_Union[ArrayAttr, IntOrAttrList]] |
| ) -> ArrayAttr: |
| if values is None: |
| return None |
| |
| # Turn into a Python list of Python ints. |
| values = _get_value_list(values) |
| |
| # Make an ArrayAttr of IntegerAttrs out of it. |
| return ArrayAttr.get( |
| [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values] |
| ) |
| |
| |
| def _get_int_array_array_attr( |
| values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]] |
| ) -> ArrayAttr: |
| """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs. |
| |
| The input has to be a collection of a collection of integers, where any |
| Python _Sequence and ArrayAttr are admissible collections and Python ints and |
| any IntegerAttr are admissible integers. Both levels of collections are |
| turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s. |
| If the input is None, an empty ArrayAttr is returned. |
| """ |
| if values is None: |
| return None |
| |
| # Make sure the outer level is a list. |
| values = _get_value_list(values) |
| |
| # The inner level is now either invalid or a mixed sequence of ArrayAttrs and |
| # Sequences. Make sure the nested values are all lists. |
| values = [_get_value_list(nested) for nested in values] |
| |
| # Turn each nested list into an ArrayAttr. |
| values = [_get_int_array_attr(nested) for nested in values] |
| |
| # Turn the outer list into an ArrayAttr. |
| return ArrayAttr.get(values) |