| # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| from ._affine_ops_gen import * |
| from ._affine_ops_gen import _Dialect |
| |
| try: |
| from ..ir import * |
| from ._ods_common import ( |
| get_op_result_or_value as _get_op_result_or_value, |
| get_op_results_or_values as _get_op_results_or_values, |
| _cext as _ods_cext, |
| ResultValueTypeTuple as _ResultValueTypeTuple, |
| ResultValueT as _ResultValueT, |
| VariadicResultValueT as _VariadicResultValueT, |
| ) |
| except ImportError as e: |
| raise RuntimeError("Error loading imports from extension module") from e |
| |
| from typing import Optional, Sequence, Union |
| |
| |
@_ods_cext.register_operation(_Dialect, replace=True)
class AffineForOp(AffineForOp):
    """Specialization for the Affine for op class.

    Adds a convenience constructor that accepts ints, SSA values / single-result
    ops, or affine maps as loop bounds, plus accessors for the loop body block
    and its arguments.
    """

    def __init__(
        self,
        lower_bound: Union[int, _ResultValueT, AffineMap],
        upper_bound: Optional[Union[int, _ResultValueT, AffineMap]],
        step: Optional[Union[int, Attribute]] = None,
        iter_args: Optional[_VariadicResultValueT] = None,
        *,
        lower_bound_operands: Optional[_VariadicResultValueT] = None,
        upper_bound_operands: Optional[_VariadicResultValueT] = None,
        loc=None,
        ip=None,
    ):
        """Creates an Affine `for` operation.

        - `lower_bound` / `upper_bound` is one of the following:
          - an int, which is turned into a constant affine map;
          - a single SSA value or single-result op, which is used through a
            one-dimensional identity map;
          - an `AffineMap`.
          `upper_bound` must not be None despite its `Optional` annotation;
          None raises ValueError.
        - `step` is the value to use as loop step; defaults to 1.
        - `iter_args` is one of the following, describing additional
          loop-carried arguments:
          - a sequence of values;
          - a single value;
          - an operation producing them as results.
        - `lower_bound_operands` / `upper_bound_operands` are the arguments to
          substitute the dimensions, then symbols, of the corresponding affine
          map, in an increasing order.
          - They can only be combined with an `AffineMap` bound.
          - Caller-provided sequences are copied, never mutated.

        Raises:
          ValueError: if any of the following holds:
            - a bound has an unsupported type;
            - a value bound is combined with explicit bound operands;
            - an op bound does not have exactly one result;
            - the number of bound operands does not match the map's inputs.
        """

        def as_value_list(values):
            # Normalize None / a single value / an op / a sequence into a fresh
            # list that can be measured, appended to and iterated repeatedly
            # without touching the caller's object.
            if values is None:
                return []
            if isinstance(values, Value):
                return [values]
            return list(_get_op_results_or_values(values))

        if step is None:
            step = 1

        bounds = []
        bounds_operands = []
        for name, bound, operands in (
            ("lower", lower_bound, as_value_list(lower_bound_operands)),
            ("upper", upper_bound, as_value_list(upper_bound_operands)),
        ):
            if isinstance(bound, int):
                bound = AffineMap.get_constant(bound)
            elif isinstance(bound, _ResultValueTypeTuple):
                if operands:
                    raise ValueError(
                        f"Either a concrete {name} bound or an AffineMap in combination "
                        f"with {name} bound operands, but not both, is supported."
                    )
                # An op used as a bound must provide exactly one SSA value.
                if isinstance(bound, (OpView, Operation)) and len(bound.results) != 1:
                    raise ValueError(
                        f"Only a single concrete value is supported for {name} bound."
                    )
                # Feed the concrete value through a 1-input identity map.
                operands.append(_get_op_result_or_value(bound))
                bound = AffineMap.get_identity(1)

            if not isinstance(bound, AffineMap):
                raise ValueError(
                    f"{name} bound must be int | ResultValueT | AffineMap."
                )
            if len(operands) != bound.n_inputs:
                raise ValueError(
                    f"Wrong number of {name} bound operands passed to AffineForOp; "
                    f"Expected {bound.n_inputs}, got {len(operands)}."
                )
            bounds.append(bound)
            bounds_operands.append(operands)

        lower_bound, upper_bound = bounds
        lower_bound_operands, upper_bound_operands = bounds_operands

        # Materialize once: a one-shot iterable would otherwise be consumed
        # while computing the result types, leaving `inits` empty.
        iter_args = as_value_list(iter_args)
        results = [arg.type for arg in iter_args]
        super().__init__(
            results_=results,
            lowerBoundOperands=lower_bound_operands,
            upperBoundOperands=upper_bound_operands,
            inits=iter_args,
            lowerBoundMap=AffineMapAttr.get(lower_bound),
            upperBoundMap=AffineMapAttr.get(upper_bound),
            step=step,
            loc=loc,
            ip=ip,
        )
        # The body block takes the index-typed induction variable followed by
        # one argument per loop-carried value.
        self.regions[0].blocks.append(IndexType.get(), *results)

    @property
    def body(self):
        """Returns the body (block) of the loop."""
        return self.regions[0].blocks[0]

    @property
    def induction_variable(self):
        """Returns the induction variable of the loop."""
        return self.body.arguments[0]

    @property
    def inner_iter_args(self):
        """Returns the loop-carried arguments usable within the loop.

        These are the body block arguments after the induction variable. The
        loop-carried operands themselves are the op's `inits` operands (the
        values given as `iter_args` to the constructor).
        """
        return self.body.arguments[1:]
| |
| |
def for_(
    start,
    stop,
    step=None,
    iter_args: Optional[Sequence[Value]] = None,
    *,
    loc=None,
    ip=None,
):
    """Build an `AffineForOp` and yield its body handles with the body as insertion point.

    Presumably meant to be driven by a Python ``for`` statement. While the
    generator is suspended at its single ``yield``, the active insertion point
    is the loop body block, so ops created by the caller are placed there.

    Yields:
      - ``iv`` when the loop has no loop-carried values;
      - ``(iv, arg)`` when it has exactly one;
      - ``(iv, args)`` with ``args`` a tuple when it has several.
    """
    loop = AffineForOp(start, stop, step, iter_args=iter_args, loc=loc, ip=ip)
    induction_var = loop.induction_variable
    carried = tuple(loop.inner_iter_args)
    with InsertionPoint(loop.body):
        if not carried:
            yield induction_var
        elif len(carried) == 1:
            yield induction_var, carried[0]
        else:
            yield induction_var, carried