| # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| # See https://llvm.org/LICENSE.txt for license information. |
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| |
| |
| from ._scf_ops_gen import * |
| from ._scf_ops_gen import _Dialect |
| from .arith import constant |
| |
| try: |
| from ..ir import * |
| from ._ods_common import ( |
| get_op_result_or_value as _get_op_result_or_value, |
| get_op_results_or_values as _get_op_results_or_values, |
| _cext as _ods_cext, |
| ) |
| except ImportError as e: |
| raise RuntimeError("Error loading imports from extension module") from e |
| |
| from typing import Optional, Sequence, Union |
| |
| |
@_ods_cext.register_operation(_Dialect, replace=True)
class ForOp(ForOp):
    """SCF `for` op wrapper that also creates the loop body block."""

    def __init__(
        self,
        lower_bound,
        upper_bound,
        step,
        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Builds an SCF `for` operation together with its body block.

        - `lower_bound` is the value used as the loop's lower bound.
        - `upper_bound` is the value used as the loop's upper bound.
        - `step` is the value used as the loop's step.
        - `iter_args` holds the extra loop-carried values, given either as a
          sequence of values or as an operation whose results are used.
          Omitting it (or passing `None`) means no loop-carried values.
        - `loc` and `ip` are forwarded unchanged to the generated builder.
        """
        carried = _get_op_results_or_values([] if iter_args is None else iter_args)

        # One op result per loop-carried value, each with that value's type.
        carried_types = [value.type for value in carried]
        super().__init__(
            carried_types,
            lower_bound,
            upper_bound,
            step,
            carried,
            loc=loc,
            ip=ip,
        )

        # The body block takes the induction variable (typed like the first
        # operand, i.e. the lower bound) followed by the loop-carried values.
        block_arg_types = (self.operands[0].type, *carried_types)
        self.regions[0].blocks.append(*block_arg_types)

    @property
    def body(self):
        """The first block of the loop's first region, i.e. the loop body."""
        body_region = self.regions[0]
        return body_region.blocks[0]

    @property
    def induction_variable(self):
        """The loop induction variable: the body block's first argument."""
        return self.body.arguments[0]

    @property
    def inner_iter_args(self):
        """The loop-carried values as seen from inside the loop body.

        These are the body block arguments after the induction variable. For
        the loop-carried operands of the op itself, use `iter_args`.
        """
        return self.body.arguments[1:]
| |
| |
@_ods_cext.register_operation(_Dialect, replace=True)
class IfOp(IfOp):
    """Specialization for the SCF if op class."""

    def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
        """Creates an SCF `if` operation.

        - `cond` is a MLIR value of 'i1' type to determine which regions of
          code will be executed.
        - `results_` is the optional sequence forwarded to the generated
          builder as the op's results; `None` means the op has no results.
        - `hasElse` determines whether the if operation has the else branch.
        - `loc` and `ip` are the location and insertion point forwarded to the
          generated builder.
        """
        # Copy so the caller's sequence is never shared with the builder.
        results = [] if results_ is None else list(results_)

        # Forward `loc`/`ip`: they were previously accepted but dropped, so the
        # op ignored an explicitly requested location or insertion point.
        super().__init__(results, cond, loc=loc, ip=ip)

        # The then region always gets an (argument-less) block; the else
        # region only gets one when requested.
        self.regions[0].blocks.append()
        if hasElse:
            self.regions[1].blocks.append()

    @property
    def then_block(self):
        """Returns the then block of the if operation."""
        return self.regions[0].blocks[0]

    @property
    def else_block(self):
        """Returns the else block of the if operation.

        The constructor only adds this block when `hasElse=True`; without it
        the else region has no first block to return.
        """
        return self.regions[1].blocks[0]
| |
| |
def for_(
    start,
    stop=None,
    step=None,
    iter_args: Optional[Sequence[Value]] = None,
    *,
    loc=None,
    ip=None,
):
    """Generator that builds an SCF `for` loop with `range`-like arguments.

    - `start`, `stop`, `step` are the loop bounds and step. With only `start`
      given, it is used as the upper bound and the lower bound becomes 0; a
      missing `step` becomes 1. Python `int`s are turned into `index`
      constants; a Python `float` raises `ValueError`; anything else is passed
      through unchanged.
    - `iter_args` are the optional loop-carried values handed to `ForOp`.
    - `loc` and `ip` are forwarded to `ForOp` (the `index` constants are
      created without them).

    Yields exactly once, with the insertion point set to the loop body:
    - no loop-carried values: the induction variable;
    - one loop-carried value: `(induction variable, inner iter arg, result)`;
    - several: `(induction variable, inner iter args tuple, op results)`.
    """
    # Mirror `range`: a lone positional argument is the upper bound.
    if stop is None:
        start, stop = 0, start
    if step is None:
        step = 1

    bounds = []
    # NOTE: the loop variable must stay named `p`; the `{p=}` f-string below
    # embeds that name in the error message.
    for p in (start, stop, step):
        if isinstance(p, int):
            bounds.append(constant(IndexType.get(), p))
            continue
        if isinstance(p, float):
            raise ValueError(f"{p=} must be int.")
        bounds.append(p)

    lower, upper, stride = bounds
    loop = ForOp(lower, upper, stride, iter_args, loc=loc, ip=ip)
    induction = loop.induction_variable
    carried = tuple(loop.inner_iter_args)

    # Everything the caller builds while the generator is suspended at the
    # yield lands inside the loop body.
    with InsertionPoint(loop.body):
        if not carried:
            yield induction
        elif len(carried) == 1:
            yield induction, carried[0], loop.results[0]
        else:
            yield induction, carried, loop.results