blob: 3f2defadf79412a600bc9c4bc0ca8d53c951cfb2 [file] [log] [blame]
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import inspect
from functools import wraps
from ..dialects._ods_common import get_op_result_or_op_results
from ..ir import Type, InsertionPoint
def op_region_builder(op, op_region, terminator=None):
    """Return a decorator that fills ``op_region`` by calling a Python body builder.

    Args:
      op: The operation that owns ``op_region``. The decorator's final return
        value is ``get_op_result_or_op_results(op)``.
      op_region: The region to populate. If it has no blocks yet, an entry
        block is appended whose argument types are the type annotations of the
        decorated function's parameters.
      terminator: Optional callable. If given, it is called with a (possibly
        empty) list built from the body builder's return value, with the
        insertion point at the end of the region's last block.

    Returns:
      ``builder_wrapper(body_builder)``. It:
        1. calls ``body_builder`` with the entry block's arguments passed
           positionally, inside an insertion point on that block;
        2. emits the terminator, if any;
        3. returns ``get_op_result_or_op_results(op)``.

    Raises (from the returned decorator):
      ValueError: if the region has no blocks and any parameter of
        ``body_builder`` is unannotated or annotated with something that is
        not an ``mlir.ir.Type`` instance.
    """

    def builder_wrapper(body_builder):
        # Add a block with block args having types determined by type hints on
        # the wrapped function.
        if len(op_region.blocks) == 0:
            sig = inspect.signature(body_builder)
            types = [p.annotation for p in sig.parameters.values()]
            # An unannotated parameter reports ``inspect.Parameter.empty``,
            # which is not a ``Type`` instance. This single check therefore
            # covers both the "missing annotation" and the "not an MLIR type"
            # cases.
            if not all(isinstance(t, Type) for t in types):
                raise ValueError(
                    f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
                )
            op_region.blocks.append(*types)

        entry_block = op_region.blocks[0]
        with InsertionPoint(entry_block):
            results = body_builder(*entry_block.arguments)

        if terminator is not None:
            # Normalize the builder's return value into the list the terminator
            # expects:
            #   - a tuple or list is splatted;
            #   - ``None`` means "no operands";
            #   - anything else is a single operand.
            if isinstance(results, (tuple, list)):
                operands = list(results)
            elif results is None:
                operands = []
            else:
                operands = [results]
            # Terminate the *last* block; the region may hold more than one
            # block at this point (e.g. if the body builder added some).
            # NOTE(review): the block list is materialized before indexing,
            # presumably because it does not support negative indices - confirm.
            with InsertionPoint(list(op_region.blocks)[-1]):
                terminator(operands)

        return get_op_result_or_op_results(op)

    return builder_wrapper
def region_op(op_constructor, terminator=None):
    """Turn an MLIR op constructor into a decorator that builds the op's first region.

    An ``mlir.ir.InsertionPoint`` and an ``mlir.ir.Location`` must be active on
    the current thread (for example, established by enclosing ``with`` blocks).

    Usage forms:
      - Bare (``@deco``, no parentheses) when the op constructor needs no
        arguments.
      - Called (``@deco(arg, ...)``) to forward arguments to ``op_constructor``.

    Decorating a Python function builds the op and constructs an entry block
    for its first region. The block's argument types are taken **from the type
    hints on the function's parameters**, and the block arguments are passed to
    the function positionally.

    If ``terminator`` is given, it runs as the last statement of the region. It
    receives a (possibly empty) list holding the decorated function's return
    value(s). A terminator that accepts a single value must therefore be
    wrapped, e.g. ``lambda args: term(args[0])``.

    After decoration, the decorated function's name is bound to:
      1. the single result value, if the op returns one value;
      2. the op's results (as a list), if it returns several values;
      3. the operation itself, if it returns no results.

    See tensor.py and transform.extras for usage examples.
    """

    def op_decorator(*args, **kwargs):
        # Construct the op right away; the decorated function then fills in the
        # op's first region.
        new_op = op_constructor(*args, **kwargs)
        return op_region_builder(new_op, new_op.regions[0], terminator)

    @wraps(op_decorator)
    def dispatch(*args, **kwargs):
        # Bare usage (``@deco`` with no parentheses) hands us the decorated
        # function as the sole positional argument, with no keywords.
        is_bare = not kwargs and len(args) == 1 and callable(args[0])
        if not is_bare:
            return op_decorator(*args, **kwargs)
        (body_fn,) = args
        return op_decorator()(body_fn)

    return dispatch