# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

try:
from typing import Optional, Sequence, Union
from ..ir import *
from ._ods_common import get_default_loc_context
from .._mlir_libs._mlir.dialects.linalg import fill_builtin_region
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

from ._ods_common import get_op_result_or_value as _get_op_result_or_value


def isa(cls: Type, ty: Type):
  """Returns True if `ty` can be cast to the MLIR type subclass `cls`."""
  try:
    cls(ty)
    return True
  except ValueError:
    return False
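

# Illustrative only: `isa` mirrors the C++ `isa<RankedTensorType>(ty)` idiom.
# It relies on the bindings' convention that constructing a type subclass from
# a mismatched type raises ValueError. For example, inside a Context:
#
#   isa(RankedTensorType, RankedTensorType.get([4], F32Type.get()))  # True
#   isa(RankedTensorType, F32Type.get())                             # False

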
class FillOp:
  """Extends the linalg.fill op."""

  def __init__(self, output: Value, value: Value, *, loc=None, ip=None):
    results = []
    if isa(RankedTensorType, output.type):
      # Only the tensor form of linalg.fill has a result; the memref form
      # writes to its output buffer in place.
      results = [output.type]
    op = self.build_generic(
        results=results,
        operands=[_get_op_result_or_value(o) for o in [value, output]],
        attributes=None,
        loc=loc,
        ip=ip)
    OpView.__init__(self, op)
    # linalg.fill carries a region that yields the fill value; it is built
    # natively so Python- and C++-constructed ops are identical.
    linalgDialect = Context.current.get_dialect_descriptor("linalg")
    fill_builtin_region(linalgDialect, self.operation)
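

# A minimal usage sketch for FillOp. Illustrative only: the helper name
# `_example_fill` and the 4x8/f32 shapes are assumptions, not part of this
# module's API.
def _example_fill():
  from mlir.ir import (Context, Location, Module, InsertionPoint, F32Type,
                       FloatAttr)
  from mlir.dialects import arith, linalg
  with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
      f32 = F32Type.get()
      # The scalar to broadcast and the destination tensor.
      zero = arith.ConstantOp(f32, FloatAttr.get(f32, 0.0)).result
      init = linalg.InitTensorOp([4, 8], f32).result
      # On a tensor output the op carries one result (the filled tensor); on
      # a memref output it has none and fills the buffer in place.
      linalg.FillOp(output=init, value=zero)
    return module

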
class InitTensorOp:
  """Extends the linalg.init_tensor op."""

def __init__(self,
sizes: Union[Sequence[int], Sequence[Value]],
element_type: Type,
*,
loc=None,
ip=None):
"""Constructs an `init_tensor` with either static or dynamic sizes."""
context = get_default_loc_context(loc)
operands = []
attributes = {}
# TODO: Refactor the InitTensorOp to take an element type attribute and
# then use normal result type inference, unifying the Python and C++ side
# with a standard mechanism (versus stashing that in builders).
    if sizes and isinstance(sizes[0], Value):
      # Dynamic sizes: the SSA values become operands, and every dimension of
      # the result type is marked dynamic with the -1 sentinel.
      operands.extend(sizes)
      static_size_ints = [-1] * len(sizes)
      result_type = RankedTensorType.get(static_size_ints, element_type)
    else:
      # Static sizes: no operands; the shape is carried by the result type and
      # mirrored in the `static_sizes` attribute.
      result_type = RankedTensorType.get(sizes, element_type)
      static_size_ints = sizes
    i64_type = IntegerType.get_signless(64)
    attributes["static_sizes"] = ArrayAttr.get(
        [IntegerAttr.get(i64_type, s) for s in static_size_ints],
        context=context)
op = self.build_generic(results=[result_type],
operands=operands,
attributes=attributes,
loc=loc,
ip=ip)
OpView.__init__(self, op)
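

# Sketch of the two supported size forms. Illustrative only:
# `_example_init_tensor` is an assumed name, not part of this module's API.
def _example_init_tensor():
  from mlir.ir import (Context, Location, Module, InsertionPoint, F32Type,
                       IndexType, IntegerAttr)
  from mlir.dialects import arith, linalg
  with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
      f32 = F32Type.get()
      # All-static sizes: builds tensor<4x8xf32> with no operands.
      linalg.InitTensorOp([4, 8], f32)
      # All-dynamic sizes: index operands, result type tensor<?x?xf32>.
      index = IndexType.get()
      dim = arith.ConstantOp(index, IntegerAttr.get(index, 4)).result
      linalg.InitTensorOp([dim, dim], f32)
    return module

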
class StructuredOpMixin:
  """All structured ops use the same mixin class."""

  def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
    # Structured ops share one generic form: two variadic operand groups
    # (ins and outs) plus optional result tensor types.
    super().__init__(
        self.build_generic(results=list(results),
                           operands=[list(inputs), list(outputs)],
                           loc=loc,
                           ip=ip))


def select_opview_mixin(parent_opview_cls):
  """Returns the mixin for a generated linalg OpView class, or None."""
  # TODO: This shouldn't be a heuristic: we should have a way to annotate
  # the OpView to note that it is a structured op.
  if ("__init__" not in parent_opview_cls.__dict__ and
      hasattr(parent_opview_cls, "inputs") and
      hasattr(parent_opview_cls, "outputs") and
      hasattr(parent_opview_cls, "result_tensors")):
    return StructuredOpMixin
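

# How this hook is consumed (a sketch of the mechanism, not a verbatim call
# site): the generated classes in _linalg_ops_gen.py are wrapped with
# _ods_common.extend_opview_class, which calls select_opview_mixin on each
# OpView subclass. Structured ops thus pick up StructuredOpMixin's generic
# constructor, e.g. (MatmulOp and the values a, b, c are assumed here):
#
#   op = linalg.MatmulOp(results=[c.type], inputs=[a, b], outputs=[c])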