blob: 24fdcbcd85b29fe80877372af960e32dd82ef0d3 [file] [log] [blame]
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._func_ops_gen import *
from ._func_ops_gen import _Dialect
try:
from ..ir import *
from ._ods_common import (
get_default_loc_context as _get_default_loc_context,
_cext as _ods_cext,
)
import inspect
from typing import Any, List, Optional, Sequence, Union
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e
# Attribute names under which FuncOp keeps its per-argument and per-result
# attribute arrays (read/written by `FuncOp.arg_attrs` / `FuncOp.result_attrs`).
ARGUMENT_ATTRIBUTE_NAME, RESULT_ATTRIBUTE_NAME = "arg_attrs", "res_attrs"
@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
    """Python-side extension of the generated constant op class.

    Registered in place of the generated class and adds a `type` convenience
    property.
    """

    @property
    def type(self):
        """Return the type of the first result produced by this op."""
        first_result = self.results[0]
        return first_result.type
@_ods_cext.register_operation(_Dialect, replace=True)
class FuncOp(FuncOp):
    """Specialization for the func op class.

    Registered in place of the generated class. Adds a convenience
    constructor, accessors for the signature, body and argument/result
    attributes, and the `from_py_func` decorator.
    """

    def __init__(
        self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
    ):
        """
        Create a FuncOp with the provided `name`, `type`, and `visibility`.
        - `name` is converted with `str()` and stored as the `sym_name`
          attribute.
        - `type` is either a FunctionType or a tuple `(inputs, results)` of two
          lists of types, from which a FunctionType is built. Note that the
          pair must be a `tuple`; other sequences are passed through as-is.
        - `visibility` is a string matching `public`, `private`, or `nested`.
          None leaves `sym_visibility` unset, which MLIR treats as public.
        - `body_builder` is an optional callback; when provided, a new entry
          block is created and the callback is invoked with the new op as its
          argument within an InsertionPoint context already set for the block.
          The callback is expected to insert a terminator in the block.
        - `loc` and `ip` are forwarded to the generated builder.
        """
        sym_name = StringAttr.get(str(name))
        # If the type is passed as a tuple, build a FunctionType on the fly.
        if isinstance(type, tuple):
            type = FunctionType.get(inputs=type[0], results=type[1])
        type = TypeAttr.get(type)
        sym_visibility = (
            StringAttr.get(str(visibility)) if visibility is not None else None
        )
        super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
        if body_builder:
            entry_block = self.add_entry_block()
            with InsertionPoint(entry_block):
                body_builder(self)

    @property
    def is_external(self):
        """True when the body region has no blocks (i.e. a declaration)."""
        return len(self.regions[0].blocks) == 0

    @property
    def body(self):
        """The function body region (region 0)."""
        return self.regions[0]

    @property
    def type(self):
        """The FunctionType stored in the `function_type` attribute."""
        return FunctionType(TypeAttr(self.attributes["function_type"]).value)

    @property
    def visibility(self):
        """The raw `sym_visibility` attribute.

        The lookup fails if the attribute is not set (e.g. the op was built
        with `visibility=None`).
        """
        return self.attributes["sym_visibility"]

    @property
    def name(self) -> StringAttr:
        """The `sym_name` attribute as a StringAttr."""
        return StringAttr(self.attributes["sym_name"])

    @property
    def entry_block(self):
        """The first block of the body.

        Raises:
          IndexError: if the function is external (has no body).
        """
        if self.is_external:
            raise IndexError("External function does not have a body")
        return self.regions[0].blocks[0]

    def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
        """
        Add an entry block to the function body using the function signature to
        infer block arguments.

        `arg_locs` optionally gives one location per block argument and is
        forwarded to `Region.blocks.append`.

        Returns the newly created block.

        Raises:
          IndexError: if the function already has an entry block.
        """
        if not self.is_external:
            raise IndexError("The function already has an entry block!")
        self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
        return self.body.blocks[0]

    @property
    def arg_attrs(self):
        """The per-argument attribute array, as an ArrayAttr."""
        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

    @arg_attrs.setter
    def arg_attrs(self, attribute: Union[ArrayAttr, list]):
        """Set the per-argument attributes from an ArrayAttr or a plain list."""
        if isinstance(attribute, ArrayAttr):
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
        else:
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
                attribute, context=self.context
            )

    @property
    def arguments(self):
        """Block arguments of the entry block (IndexError if external)."""
        return self.entry_block.arguments

    @property
    def result_attrs(self):
        """The per-result attribute array, as an ArrayAttr."""
        # Wrapped like `arg_attrs` so both accessors return the same kind.
        return ArrayAttr(self.attributes[RESULT_ATTRIBUTE_NAME])

    @result_attrs.setter
    def result_attrs(self, attribute: Union[ArrayAttr, list]):
        """Set the per-result attributes from an ArrayAttr or a plain list."""
        # Accept plain lists too, mirroring the `arg_attrs` setter.
        if not isinstance(attribute, ArrayAttr):
            attribute = ArrayAttr.get(attribute, context=self.context)
        self.attributes[RESULT_ATTRIBUTE_NAME] = attribute

    @classmethod
    def from_py_func(
        cls,
        *inputs: Type,
        results: Optional[Sequence[Type]] = None,
        name: Optional[str] = None,
    ):
        """Decorator to define an MLIR FuncOp specified as a python function.

        Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
        active for the current thread (i.e. established in a `with` block).

        When applied as a decorator to a Python function, an entry block will
        be constructed for the FuncOp with types as specified in `*inputs`. The
        block arguments will be passed positionally to the Python function. In
        addition, if the Python function accepts keyword arguments generally or
        has a corresponding keyword argument, the following will be passed:
          * `func_op`: The `func` op being defined.

        By default, the function name will be the Python function `__name__`.
        This can be overridden by passing the `name` argument to the decorator.

        If `results` is not specified, then the decorator will implicitly
        insert a `ReturnOp` with the `Value`s returned from the decorated
        function. It will also set the `FuncOp` type with the actual return
        value types. If `results` is specified, then the decorated function
        must return `None` and no implicit `ReturnOp` is added (nor are the
        result types updated). The implicit behavior is intended for simple,
        single-block cases, and users should specify result types explicitly
        for any complicated cases.

        The decorated function can further be called from Python and will
        insert a `CallOp` at the then-current insertion point, returning None
        (if there are no results), a single Value (for one result), or the
        call op's results (for several). This mechanism cannot be used to emit
        recursive calls (by construction).
        """

        def decorator(f):
            # Bind `func` to this dialect module: the module-level name `func`
            # is an alias for `from_py_func`, not the module.
            from . import func

            # Introspect the callable for optional features: pass `func_op`
            # if `f` takes **kwargs or a keyword-capable `func_op` parameter.
            sig = inspect.signature(f)
            has_arg_func_op = any(
                param.kind == param.VAR_KEYWORD
                or (
                    param.name == "func_op"
                    and param.kind
                    in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
                )
                for param in sig.parameters.values()
            )

            # Emit the FuncOp.
            implicit_return = results is None
            symbol_name = name or f.__name__
            function_type = FunctionType.get(
                inputs=inputs, results=[] if implicit_return else results
            )
            func_op = cls(name=symbol_name, type=function_type)
            with InsertionPoint(func_op.add_entry_block()):
                func_args = func_op.entry_block.arguments
                func_kwargs = {"func_op": func_op} if has_arg_func_op else {}
                return_values = f(*func_args, **func_kwargs)
                if not implicit_return:
                    return_types = list(results)
                    assert return_values is None, (
                        "Capturing a python function with explicit `results=` "
                        "requires that the wrapped function returns None."
                    )
                else:
                    # Coerce return values, add ReturnOp and rewrite func type.
                    if return_values is None:
                        return_values = []
                    elif isinstance(return_values, tuple):
                        return_values = list(return_values)
                    elif isinstance(return_values, Value):
                        # A single value is fine; coerce it into a list.
                        return_values = [return_values]
                    elif isinstance(return_values, OpView):
                        # A single operation is fine; return all its results.
                        return_values = return_values.operation.results
                    elif isinstance(return_values, Operation):
                        # A single operation is fine; return all its results.
                        return_values = return_values.results
                    else:
                        return_values = list(return_values)
                    func.ReturnOp(return_values)
                    # Recompute the function type from the actual results.
                    return_types = [v.type for v in return_values]
                    function_type = FunctionType.get(
                        inputs=inputs, results=return_types
                    )
                    func_op.attributes["function_type"] = TypeAttr.get(function_type)

            def emit_call_op(*call_args):
                call_op = func.CallOp(
                    return_types, FlatSymbolRefAttr.get(symbol_name), call_args
                )
                # `return_types` is always a list here (never None), so test
                # for emptiness: a call with no results yields None, as
                # documented above.
                if not return_types:
                    return None
                if len(return_types) == 1:
                    return call_op.result
                return call_op.results

            wrapped = emit_call_op
            wrapped.__name__ = f.__name__
            wrapped.func_op = func_op
            return wrapped

        return decorator


# Shorthand alias for `FuncOp.from_py_func`.
func = FuncOp.from_py_func
@_ods_cext.register_operation(_Dialect, replace=True)
class CallOp(CallOp):
    """Specialization for the call op class."""

    def __init__(
        self,
        calleeOrResults: Union[FuncOp, List[Type]],
        argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
        arguments: Optional[List] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a call operation.

        The constructor accepts three different forms:

          1. A function op to be called followed by a list of arguments.
          2. A list of result types, followed by the name of the function to be
             called as string, followed by a list of arguments.
          3. A list of result types, followed by the name of the function to be
             called as symbol reference attribute, followed by a list of
             arguments.

        For example

            f = func.FuncOp("foo", ...)
            func.CallOp(f, [args])
            func.CallOp([result_types], "foo", [args])

        In all cases, the location and insertion point may be specified as
        keyword arguments if not provided by the surrounding context managers.

        Raises:
          ValueError: if the positional arguments match none of the forms above.
        """
        # TODO: consider supporting constructor "overloads", e.g., through a custom
        # or pybind-provided metaclass.
        if isinstance(calleeOrResults, FuncOp):
            # Form 1: result types and symbol name come from the callee op.
            if not isinstance(argumentsOrCallee, list):
                raise ValueError(
                    "when constructing a call to a function, expected "
                    + "the second argument to be a list of call arguments, "
                    + f"got {type(argumentsOrCallee)}"
                )
            if arguments is not None:
                raise ValueError(
                    "unexpected third argument when constructing a call "
                    + "to a function"
                )
            super().__init__(
                calleeOrResults.type.results,
                FlatSymbolRefAttr.get(
                    calleeOrResults.name.value, context=_get_default_loc_context(loc)
                ),
                argumentsOrCallee,
                loc=loc,
                ip=ip,
            )
            return

        # Forms 2 and 3: explicit result types plus a callee reference.
        if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
            callee = argumentsOrCallee
        elif isinstance(argumentsOrCallee, str):
            callee = FlatSymbolRefAttr.get(
                argumentsOrCallee, context=_get_default_loc_context(loc)
            )
        else:
            # Reject lists and every other unsupported type; otherwise the op
            # would be returned without its base class ever being initialized.
            raise ValueError(
                "when constructing a call to a function by name, "
                + "expected the second argument to be a string or a "
                + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
            )
        super().__init__(calleeOrResults, callee, arguments, loc=loc, ip=ip)