# RUN: %PYTHON %s | FileCheck %s
import sys
import typing
from typing import Union, Optional

from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.python_test as test
import mlir.dialects.tensor as tensor
import mlir.dialects.arith as arith

from mlir._mlir_libs._mlirPythonTestNanobind import (
    TestAttr,
    TestType,
    TestTensorValue,
    TestIntegerRankedTensorType,
)

# Register the python_test dialect in the registry returned by
# get_dialect_registry(); presumably this is what lets the Contexts created in
# the tests below load python_test ops without per-test registration.
test.register_python_test_dialect(get_dialect_registry())


def run(f):
    """Decorator that announces and immediately executes a test function.

    Prints a blank line followed by ``TEST: <function name>`` so each test's
    output section can be located, calls ``f`` with no arguments, and returns
    ``f`` unchanged so the decorated name stays bound at module level.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    return f


# CHECK-LABEL: TEST: testAttributes
@run
def testAttributes():
    """Check construction, generic access, and typed accessors of AttributedOp.

    Builds AttributedOp with mandatory, optional and unit attributes, mutates
    the generic ``attributes`` mapping, and then exercises the generated
    per-attribute getters, setters and deleters.
    """
    with Context() as ctx, Location.unknown():
        #
        # Check op construction with attributes.
        #

        i32 = IntegerType.get_signless(32)
        one = IntegerAttr.get(i32, 1)
        two = IntegerAttr.get(i32, 2)
        unit = UnitAttr.get()

        # CHECK: python_test.attributed_op  {
        # CHECK-DAG: mandatory_i32 = 1 : i32
        # CHECK-DAG: optional_i32 = 2 : i32
        # CHECK-DAG: unit
        # CHECK: }
        op = test.AttributedOp(one, optional_i32=two, unit=unit)
        print(f"{op}")

        # op2 only gets the mandatory attribute; it is the target of the
        # generic attribute-dictionary mutations below.
        # CHECK: python_test.attributed_op  {
        # CHECK: mandatory_i32 = 2 : i32
        # CHECK: }
        op2 = test.AttributedOp(two)
        print(f"{op2}")

        #
        # Check generic "attributes" access and mutation.
        #

        assert "additional" not in op.attributes

        # Adding a new discardable attribute through the generic mapping.
        # CHECK: python_test.attributed_op  {
        # CHECK-DAG: additional = 1 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = one
        print(f"{op2}")

        # Overwriting an existing entry replaces its value.
        # CHECK: python_test.attributed_op  {
        # CHECK-DAG: additional = 2 : i32
        # CHECK-DAG: mandatory_i32 = 2 : i32
        # CHECK: }
        op2.attributes["additional"] = two
        print(f"{op2}")

        # Deleting the entry removes it from the printed op.
        # CHECK: python_test.attributed_op  {
        # CHECK-NOT: additional = 2 : i32
        # CHECK:     mandatory_i32 = 2 : i32
        # CHECK: }
        del op2.attributes["additional"]
        print(f"{op2}")

        # Indexing a missing key in the generic mapping must raise KeyError.
        try:
            print(op.attributes["additional"])
        except KeyError:
            pass
        else:
            assert False, "expected KeyError on unknown attribute key"

        #
        # Check accessors to defined attributes.
        #

        # CHECK: Mandatory: 1
        # CHECK: Optional: 2
        # CHECK: Unit: True
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32.value}")
        print(f"Unit: {op.unit}")

        # Absent optional attributes read back as None, absent unit as False.
        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        print(f"Mandatory: {op2.mandatory_i32.value}")
        print(f"Optional: {op2.optional_i32}")
        print(f"Unit: {op2.unit}")

        # Assigning None to an optional attribute and False to a unit attribute
        # removes them from the attribute dictionary (asserted just below).
        # CHECK: Mandatory: 2
        # CHECK: Optional: None
        # CHECK: Unit: False
        op.mandatory_i32 = two
        op.optional_i32 = None
        op.unit = False
        print(f"Mandatory: {op.mandatory_i32.value}")
        print(f"Optional: {op.optional_i32}")
        print(f"Unit: {op.unit}")
        assert "optional_i32" not in op.attributes
        assert "unit" not in op.attributes

        # A mandatory attribute cannot be cleared: assigning None must raise
        # ValueError.
        try:
            op.mandatory_i32 = None
        except ValueError:
            pass
        else:
            assert False, "expected ValueError on setting a mandatory attribute to None"

        # CHECK: Optional: 2
        op.optional_i32 = two
        print(f"Optional: {op.optional_i32.value}")

        # `del` on an optional accessor also clears the attribute.
        # CHECK: Optional: None
        del op.optional_i32
        print(f"Optional: {op.optional_i32}")

        # Assigning None to a unit attribute removes it as well.
        # CHECK: Unit: False
        op.unit = None
        print(f"Unit: {op.unit}")
        assert "unit" not in op.attributes

        # CHECK: Unit: True
        op.unit = True
        print(f"Unit: {op.unit}")

        # CHECK: Unit: False
        del op.unit
        print(f"Unit: {op.unit}")


# CHECK-LABEL: TEST: attrBuilder
@run
def attrBuilder():
    """Build AttributesOp from plain Python values via the attribute builders.

    The first part constructs the op with one keyword per attribute kind and
    prints it. The second part checks, for every listed accessor, that the
    setter's ``value`` hint and the getter's ``return`` hint both name the
    expected concrete attribute class, and that the getter returns an object
    whose type is exactly that class.
    """
    with Context() as ctx, Location.unknown():
        # CHECK: python_test.attributes_op
        op = test.AttributesOp(
            # CHECK-DAG: x_affinemap = affine_map<() -> (2)>
            x_affinemap=AffineMap.get_constant(2),
            # CHECK-DAG: x_affinemaparr = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
            x_affinemaparr=[AffineMap.get_identity(3)],
            # CHECK-DAG: x_arr = [true, "x"]
            x_arr=[BoolAttr.get(True), StringAttr.get("x")],
            x_boolarr=[False, True],  # CHECK-DAG: x_boolarr = [false, true]
            x_bool=True,  # CHECK-DAG: x_bool = true
            x_dboolarr=[True, False],  # CHECK-DAG: x_dboolarr = array<i1: true, false>
            x_df16arr=[21, 22],  # CHECK-DAG: x_df16arr = array<i16: 21, 22>
            # CHECK-DAG: x_df32arr = array<f32: 2.300000e+01, 2.400000e+01>
            x_df32arr=[23, 24],
            # CHECK-DAG: x_df64arr = array<f64: 2.500000e+01, 2.600000e+01>
            x_df64arr=[25, 26],
            x_di32arr=[0, 1],  # CHECK-DAG: x_di32arr = array<i32: 0, 1>
            # CHECK-DAG: x_di64arr = array<i64: 1, 2>
            x_di64arr=[1, 2],
            x_di8arr=[2, 3],  # CHECK-DAG: x_di8arr = array<i8: 2, 3>
            # CHECK-DAG: x_dictarr = [{a = false}]
            x_dictarr=[{"a": BoolAttr.get(False)}],
            x_dict={"b": BoolAttr.get(True)},  # CHECK-DAG: x_dict = {b = true}
            x_f32=-2.25,  # CHECK-DAG: x_f32 = -2.250000e+00 : f32
            # CHECK-DAG: x_f32arr = [2.000000e+00 : f32, 3.000000e+00 : f32]
            x_f32arr=[2.0, 3.0],
            x_f64=4.25,  # CHECK-DAG: x_f64 = 4.250000e+00 : f64
            x_f64arr=[4.0, 8.0],  # CHECK-DAG: x_f64arr = [4.000000e+00, 8.000000e+00]
            # CHECK-DAG: x_f64elems = dense<[8.000000e+00, 1.600000e+01]> : tensor<2xf64>
            x_f64elems=[8.0, 16.0],
            # CHECK-DAG: x_flatsymrefarr = [@symbol1, @symbol2]
            x_flatsymrefarr=["symbol1", "symbol2"],
            x_flatsymref="symbol3",  # CHECK-DAG: x_flatsymref = @symbol3
            x_i1=0,  # CHECK-DAG: x_i1 = false
            x_i16=42,  # CHECK-DAG: x_i16 = 42 : i16
            x_i32=6,  # CHECK-DAG: x_i32 = 6 : i32
            x_i32arr=[4, 5],  # CHECK-DAG: x_i32arr = [4 : i32, 5 : i32]
            x_i32elems=[5, 6],  # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xi32>
            x_i64=9,  # CHECK-DAG: x_i64 = 9 : i64
            x_i64arr=[7, 8],  # CHECK-DAG: x_i64arr = [7, 8]
            x_i64elems=[8, 9],  # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xi64>
            x_i64svecarr=[10, 11],  # CHECK-DAG: x_i64svecarr = [10, 11]
            x_i8=11,  # CHECK-DAG: x_i8 = 11 : i8
            x_idx=10,  # CHECK-DAG: x_idx = 10 : index
            # CHECK-DAG: x_idxelems = dense<[11, 12]> : tensor<2xindex>
            x_idxelems=[11, 12],
            # CHECK-DAG: x_idxlistarr = [{{\[}}13], [14, 15]]
            x_idxlistarr=[[13], [14, 15]],
            x_si1=-1,  # CHECK-DAG: x_si1 = -1 : si1
            x_si16=-2,  # CHECK-DAG: x_si16 = -2 : si16
            x_si32=-3,  # CHECK-DAG: x_si32 = -3 : si32
            x_si64=-123,  # CHECK-DAG: x_si64 = -123 : si64
            x_si8=-4,  # CHECK-DAG: x_si8 = -4 : si8
            x_strarr=["hello", "world"],  # CHECK-DAG: x_strarr = ["hello", "world"]
            x_str="hello world!",  # CHECK-DAG: x_str = "hello world!"
            # CHECK-DAG: x_symrefarr = [@flatsym, @deep::@sym]
            x_symrefarr=["flatsym", ["deep", "sym"]],
            x_symref=["deep", "sym2"],  # CHECK-DAG: x_symref = @deep::@sym2
            x_sym="symbol",  # CHECK-DAG: x_sym = "symbol"
            x_typearr=[F32Type.get()],  # CHECK-DAG: x_typearr = [f32]
            x_type=F64Type.get(),  # CHECK-DAG: x_type = f64
            x_ui1=1,  # CHECK-DAG: x_ui1 = 1 : ui1
            x_ui16=2,  # CHECK-DAG: x_ui16 = 2 : ui16
            x_ui32=3,  # CHECK-DAG: x_ui32 = 3 : ui32
            x_ui64=4,  # CHECK-DAG: x_ui64 = 4 : ui64
            x_ui8=5,  # CHECK-DAG: x_ui8 = 5 : ui8
            x_unit=True,  # CHECK-DAG: x_unit
        )
        op.verify()
        op.print(use_local_scope=True)

    # Expected concrete attribute class for each generated accessor. A single
    # table replaces one hand-written assertion triple per attribute and lets
    # every failure report which accessor was wrong.
    expected_attr_classes = {
        "x_affinemaparr": ArrayAttr,
        "x_affinemap": AffineMapAttr,
        "x_arr": ArrayAttr,
        "x_boolarr": ArrayAttr,
        "x_bool": BoolAttr,
        "x_dboolarr": DenseBoolArrayAttr,
        "x_df32arr": DenseF32ArrayAttr,
        "x_df64arr": DenseF64ArrayAttr,
        "x_df16arr": DenseI16ArrayAttr,
        "x_di32arr": DenseI32ArrayAttr,
        "x_di64arr": DenseI64ArrayAttr,
        "x_di8arr": DenseI8ArrayAttr,
        "x_dictarr": ArrayAttr,
        "x_dict": DictAttr,
        "x_f32arr": ArrayAttr,
        "x_f32": FloatAttr,
        "x_f64arr": ArrayAttr,
        "x_f64": FloatAttr,
        "x_f64elems": DenseFPElementsAttr,
        "x_flatsymrefarr": ArrayAttr,
        "x_flatsymref": FlatSymbolRefAttr,
        "x_i16": IntegerAttr,
        "x_i1": BoolAttr,
        "x_i32arr": ArrayAttr,
        "x_i32": IntegerAttr,
        "x_i32elems": DenseIntElementsAttr,
        "x_i64arr": ArrayAttr,
        "x_i64": IntegerAttr,
        "x_i64elems": DenseIntElementsAttr,
        "x_i64svecarr": ArrayAttr,
        "x_i8": IntegerAttr,
        "x_idx": IntegerAttr,
        "x_idxelems": DenseIntElementsAttr,
        "x_idxlistarr": ArrayAttr,
        "x_si16": IntegerAttr,
        "x_si1": IntegerAttr,
        "x_si32": IntegerAttr,
        "x_si64": IntegerAttr,
        "x_si8": IntegerAttr,
        "x_strarr": ArrayAttr,
        "x_str": StringAttr,
        "x_sym": StringAttr,
        "x_symrefarr": ArrayAttr,
        "x_symref": SymbolRefAttr,
        "x_typearr": ArrayAttr,
        "x_type": TypeAttr,
        "x_ui16": IntegerAttr,
        "x_ui1": IntegerAttr,
        "x_ui32": IntegerAttr,
        "x_ui64": IntegerAttr,
        "x_ui8": IntegerAttr,
    }
    for attr_name, attr_class in expected_attr_classes.items():
        accessor = getattr(test.AttributesOp, attr_name)
        setter_hint = typing.get_type_hints(accessor.fset)["value"]
        assert setter_hint is attr_class, (
            f"{attr_name}: setter hint {setter_hint!r}, expected {attr_class!r}"
        )
        getter_hint = typing.get_type_hints(accessor.fget)["return"]
        assert getter_hint is attr_class, (
            f"{attr_name}: getter hint {getter_hint!r}, expected {attr_class!r}"
        )
        # The runtime object must be exactly the annotated class, not merely
        # a generic Attribute.
        actual_type = type(getattr(op, attr_name))
        assert actual_type is getter_hint, (
            f"{attr_name}: getter returned {actual_type!r}, "
            f"annotated as {getter_hint!r}"
        )


# CHECK-LABEL: TEST: inferReturnTypes
@run
def inferReturnTypes():
    """Check InferTypeOpInterface built from an op instance and from an op class.

    Both forms must infer the same result types. Only the instance-based
    interface exposes an ``opview``, and an op that does not implement the
    interface (DummyOp) must be rejected in either form with ValueError.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.InferResultsOp()
            dummy = test.DummyOp()

        # CHECK: [Type(i32), Type(i64)]
        iface = InferTypeOpInterface(op)
        print(iface.inferReturnTypes())

        # CHECK: [Type(i32), Type(i64)]
        iface_static = InferTypeOpInterface(test.InferResultsOp)
        # Query the class-based interface itself so its inference path is
        # actually exercised (not the instance-based one again).
        print(iface_static.inferReturnTypes())

        assert isinstance(iface.opview, test.InferResultsOp)
        assert iface.opview == iface.operation.opview

        # A class-based interface has no underlying operation to view.
        try:
            iface_static.opview
        except TypeError:
            pass
        else:
            assert False, (
                "not expected to be able to obtain an opview from a static interface"
            )

        try:
            InferTypeOpInterface(dummy)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op to implement the interface"

        try:
            InferTypeOpInterface(test.DummyOp)
        except ValueError:
            pass
        else:
            assert False, "not expected dummy op class to implement the interface"


# CHECK-LABEL: TEST: resultTypesDefinedByTraits
@run
def resultTypesDefinedByTraits():
    """Check result types derived from op traits/attributes, and explicit overrides.

    Covers ops whose result types come from an inference method, from operand
    types, from a type attribute, from an attribute's type, and from implied
    types. It also checks that passing ``results=`` explicitly skips inference.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            inferred = test.InferResultsOp()

            # CHECK: i32 i64
            print(inferred.single.type, inferred.doubled.type)

            # Results take the type of the operand (the i32 result above).
            same = test.SameOperandAndResultTypeOp([inferred.results[0]])
            # CHECK-COUNT-2: i32
            print(same.one.type)
            print(same.two.type)
            assert (
                typing.get_type_hints(test.SameOperandAndResultTypeOp.one.fget)[
                    "return"
                ]
                is OpResult
            )
            assert type(same.one) is OpResult

            # Results take the type carried by the TypeAttr (index), not the
            # operand's type.
            first_type_attr = test.FirstAttrDeriveTypeAttrOp(
                inferred.results[1], TypeAttr.get(IndexType.get())
            )
            # CHECK-COUNT-2: index
            print(first_type_attr.one.type)
            print(first_type_attr.two.type)

            # Results take the type of the attribute's value (f32).
            first_attr = test.FirstAttrDeriveAttrOp(FloatAttr.get(F32Type.get(), 3.14))
            # CHECK-COUNT-3: f32
            print(first_attr.one.type)
            print(first_attr.two.type)
            print(first_attr.three.type)

            implied = test.InferResultsImpliedOp()
            # CHECK: i32
            print(implied.integer.type)
            # CHECK: f64
            print(implied.flt.type)
            # CHECK: index
            print(implied.index.type)

            # provide the result types to avoid inferring them
            f64 = F64Type.get()
            no_imply = test.InferResultsImpliedOp(results=[f64, f64, f64])
            # CHECK-COUNT-3: f64
            print(no_imply.integer.type, no_imply.flt.type, no_imply.index.type)

            no_infer = test.InferResultsOp(results=[F32Type.get(), IndexType.get()])
            # CHECK: f32 index
            print(no_infer.single.type, no_infer.doubled.type)


# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
    """Optional operand accessor yields None when omitted and a value when given.

    Also checks the generated type hints of the ``input`` and ``result``
    accessors and the runtime type of the result.
    """
    with Context() as ctx, Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            without_input = test.OptionalOperandOp()
            # CHECK: op1.input is None: True
            print(f"op1.input is None: {without_input.input is None}")

            op_class = test.OptionalOperandOp
            input_hint = typing.get_type_hints(op_class.input.fget)["return"]
            assert input_hint is Optional[Value]
            result_hint = typing.get_type_hints(op_class.result.fget)["return"]
            assert result_hint == OpResult[IntegerType]
            assert type(without_input.result) is OpResult

            with_input = test.OptionalOperandOp(input=without_input)
            # CHECK: op2.input is None: False
            print(f"op2.input is None: {with_input.input is None}")


# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
    """Exercise the dialect-defined attribute class TestAttr from Python.

    Covers creation, attaching it to an op and reading it back, repr,
    casting from a generic Attribute, and the errors raised for invalid casts
    or bad constructor arguments.
    """
    with Context() as ctx, Location.unknown():
        a = TestAttr.get()
        # CHECK: #python_test.test_attr
        print(a)

        # CHECK: python_test.custom_attributed_op  {
        # CHECK: #python_test.test_attr
        # CHECK: }
        op2 = test.CustomAttributedOp(a)
        print(f"{op2}")

        # CHECK: #python_test.test_attr
        print(f"{op2.test_attr}")

        # CHECK: TestAttr(#python_test.test_attr)
        print(repr(op2.test_attr))

        # The following cast must not assert.
        b = TestAttr(a)

        # Casting an attribute of a different kind must raise ValueError.
        # A bare `raise` in the else-branch would have no active exception and
        # fail with a misleading RuntimeError, so raise AssertionError instead
        # (which, unlike `assert False`, is not stripped under `python -O`).
        unit = UnitAttr.get()
        try:
            TestAttr(unit)
        except ValueError as e:
            assert "Cannot cast attribute to TestAttr" in str(e)
        else:
            raise AssertionError("expected ValueError when casting UnitAttr to TestAttr")

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            TestAttr(42)
        except TypeError as e:
            assert (
                "__init__(): incompatible function arguments. The following argument types are supported"
                in str(e)
            )
            assert (
                "__init__(self, cast_from_attr: mlir._mlir_libs._mlir.ir.Attribute) -> None"
                in str(e)
            )
            assert (
                "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestAttr, int"
                in str(e)
            )
        else:
            raise AssertionError("expected TypeError when constructing TestAttr(42)")

        # The following must trigger a TypeError from the nanobind binding layer
        # (therefore, not checking its message) and must not crash.
        try:
            TestAttr(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError("expected TypeError when constructing TestAttr(42, 56)")


# CHECK-LABEL: TEST: testCustomType
@run
def testCustomType():
    """Exercise the dialect-defined type class TestType from Python.

    Covers creation, casting from a generic Type (including its typeid), and
    the errors raised for invalid casts or bad constructor arguments.
    """
    with Context() as ctx:
        a = TestType.get()
        # CHECK: !python_test.test_type
        print(a)

        # The following cast must not assert.
        b = TestType(a)
        # Instance custom types should have typeids
        assert isinstance(b.typeid, TypeID)

        # Casting a type of a different kind must raise ValueError. The
        # else-branches raise AssertionError explicitly: a bare `raise` there
        # has no active exception and would fail with a misleading RuntimeError.
        i8 = IntegerType.get_signless(8)
        try:
            TestType(i8)
        except ValueError as e:
            assert "Cannot cast type to TestType" in str(e)
        else:
            raise AssertionError("expected ValueError when casting i8 to TestType")

        # The following must trigger a TypeError from our adaptors and must not
        # crash.
        try:
            TestType(42)
        except TypeError as e:
            assert (
                "__init__(): incompatible function arguments. The following argument types are supported"
                in str(e)
            )
            assert (
                "__init__(self, cast_from_type: mlir._mlir_libs._mlir.ir.Type) -> None"
                in str(e)
            )
            assert (
                "Invoked with types: mlir._mlir_libs._mlirPythonTestNanobind.TestType, int"
                in str(e)
            )
        else:
            raise AssertionError("expected TypeError when constructing TestType(42)")

        # The following must trigger a TypeError from the nanobind binding layer
        # (therefore, not checking its message) and must not crash.
        try:
            TestType(42, 56)
        except TypeError:
            pass
        else:
            raise AssertionError("expected TypeError when constructing TestType(42, 56)")


# CHECK-LABEL: TEST: testValue
@run
def testValue():
    """Value must define __class_getitem__ so it can be subscripted at runtime.

    That makes generic spellings such as ``Value[SomeType]`` usable in
    annotations.
    """
    generic_hook = "__class_getitem__"
    assert hasattr(Value, generic_hook)


@run
# CHECK-LABEL: TEST: testTensorValue
def testTensorValue():
    """Check wrapping tensor results in Value and in TestTensorValue subclasses.

    Also checks that TestIntegerRankedTensorType exposes a static_typeid equal
    to the typeid of the built-in ranked tensor type it derives from.
    """
    with Context() as ctx, Location.unknown():
        i8 = IntegerType.get_signless(8)

        # Python-side subclass that only changes how the value renders in str().
        class Tensor(TestTensorValue):
            def __str__(self):
                return super().__str__().replace("Value", "Tensor")

        module = Module.create()
        with InsertionPoint(module.body):
            t = tensor.EmptyOp([10, 10], i8).result

            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(Value(t))

            tt = Tensor(t)
            # CHECK: Tensor(%{{.*}} = tensor.empty() : tensor<10x10xi8>)
            print(tt)

            # CHECK: False
            print(tt.is_null())

            # Classes of custom types that inherit from concrete types should have
            # static_typeid
            assert isinstance(TestIntegerRankedTensorType.static_typeid, TypeID)
            # And it should be equal to the in-tree concrete type
            assert TestIntegerRankedTensorType.static_typeid == t.type.typeid

            # The repr of a freshly created i5 tensor result is expected to
            # mention TestTensorValue; presumably the test extension registers
            # an automatic cast for such values — confirm against the extension.
            d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result
            # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>)
            print(d)
            # CHECK: TestTensorValue
            print(repr(d))


# CHECK-LABEL: TEST: inferReturnTypeComponents
@run
def inferReturnTypeComponents():
    """Query InferShapedTypeOpInterface for ranked and unranked operands.

    Builds a function taking one ranked and one unranked i32 tensor. It applies
    InferShapedTypeComponentsOp to each argument, then prints the shaped-type
    components the interface infers for each op.
    """

    def dump_components(op):
        # Infer components from the op's own operand and print the first entry
        # of the returned sequence.
        iface = InferShapedTypeOpInterface(op)
        components = iface.inferReturnTypeComponents(operands=[op.operand])[0]
        print("has rank:", components.has_rank)
        print("rank:", components.rank)
        print("element type:", components.element_type)
        print("shape:", components.shape)

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            result_type = UnrankedTensorType.get(i32)
            arg_types = [
                RankedTensorType.get([1, 3, 10, 10], i32),
                UnrankedTensorType.get(i32),
            ]
            fn = func.FuncOp(
                "test_inferReturnTypeComponents", (arg_types, [result_type])
            )
            body = Block.create_at_start(fn.operation.regions[0], arg_types)
            args = body.arguments
            with InsertionPoint(body):
                ranked_op = test.InferShapedTypeComponentsOp(result_type, args[0])
                unranked_op = test.InferShapedTypeComponentsOp(result_type, args[1])

        # CHECK: has rank: True
        # CHECK: rank: 4
        # CHECK: element type: i32
        # CHECK: shape: [1, 3, 10, 10]
        dump_components(ranked_op)

        # CHECK: has rank: False
        # CHECK: rank: None
        # CHECK: element type: i32
        # CHECK: shape: None
        dump_components(unranked_op)


# CHECK-LABEL: TEST: testCustomTypeTypeCaster
@run
def testCustomTypeTypeCaster():
    """Exercise custom Python type classes and type-caster registration.

    Parsed and constructed types are returned as the custom classes
    (TestType, TestIntegerRankedTensorType). Registering a second caster for
    the same TypeID without ``replace=True`` must raise RuntimeError. With
    ``replace=True`` the new caster overrides the active one. The
    TestIntegerRankedTensorType caster is re-installed at the end so the
    original behaviour is restored.
    """
    with Context(), Location.unknown():
        a = TestType.get()
        assert a.typeid is not None

        b = Type.parse("!python_test.test_type")
        # CHECK: !python_test.test_type
        print(b)
        # CHECK: TestType(!python_test.test_type)
        print(repr(b))

        c = TestIntegerRankedTensorType.get([10, 10], 5)
        # CHECK: tensor<10x10xi5>
        print(c)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(c))

        # CHECK: Type caster is already registered
        try:

            @register_type_caster(c.typeid)
            def type_caster(pytype):
                return TestIntegerRankedTensorType(pytype)

        except RuntimeError as e:
            print(e)
        else:
            # Fail explicitly. Otherwise a missing error would only surface
            # indirectly, as an absent line in the printed output.
            raise AssertionError(
                "registering a second type caster without replace=True "
                "should have raised RuntimeError"
            )

        # The python_test dialect's native extension module registers a caster
        # for RankedTensorType, so this registration replaces it (replace=True).
        # The original caster is restored further below.
        @register_type_caster(c.typeid, replace=True)
        def type_caster(pytype):
            return RankedTensorType(pytype)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>)
        print("ranked tensor type", repr(d.type))

        # Restore the caster that produces TestIntegerRankedTensorType.
        @register_type_caster(c.typeid, replace=True)
        def type_caster(pytype):
            return TestIntegerRankedTensorType(pytype)

        d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result
        # CHECK: tensor<10x10xi5>
        print(d.type)
        # CHECK: TestIntegerRankedTensorType(tensor<10x10xi5>)
        print(repr(d.type))


# CHECK-LABEL: TEST: testInferTypeOpInterface
@run
def testInferTypeOpInterface():
    """Check result-type inference of InferResultsVariadicInputsOp.

    The inferred result type differs depending on whether the optional
    ``doubled`` operand is supplied. The generated builder function
    ``infer_results_variadic_inputs_op`` is annotated to return, and does
    return, an OpResult.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            zero = arith.ConstantOp(IntegerType.get_signless(64), 0)

            single_only = test.InferResultsVariadicInputsOp(single=zero, doubled=None)
            # CHECK: i32
            print(single_only.result.type)

            with_doubled = test.InferResultsVariadicInputsOp(single=zero, doubled=zero)
            # CHECK: f32
            print(with_doubled.result.type)

            builder = test.infer_results_variadic_inputs_op
            return_hint = typing.get_type_hints(builder)["return"]
            assert return_hint is OpResult
            built = builder(single=zero, doubled=zero)
            assert type(built) is OpResult


# CHECK-LABEL: TEST: testVariadicOperandAccess
@run
def testVariadicOperandAccess():
    """Check operand accessors of SameVariadicOperandSizeOp.

    Verifies, for the non-variadic and both variadic operand accessors:
    - the printed values;
    - the annotated return types;
    - the runtime types.
    Also checks the annotation and result type of the generated
    ``same_variadic_operand`` builder.
    """

    def values(lst):
        return list(map(str, lst))

    def return_hint(fn):
        return typing.get_type_hints(fn)["return"]

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            i32 = IntegerType.get_signless(32)
            # Constants 0..4, created in ascending order.
            zero, one, two, three, four = (arith.ConstantOp(i32, k) for k in range(5))

            op_cls = test.SameVariadicOperandSizeOp
            op = op_cls([zero, one], two, [three, four])

            # CHECK: OpResult(%{{.*}} = arith.constant 2 : i32)
            print(op.non_variadic)
            assert return_hint(op_cls.non_variadic.fget) is Value
            assert type(op.non_variadic) is OpResult

            # CHECK: ['OpResult(%{{.*}} = arith.constant 0 : i32)', 'OpResult(%{{.*}} = arith.constant 1 : i32)']
            print(values(op.variadic1))
            assert return_hint(op_cls.variadic1.fget) is OpOperandList
            assert type(op.variadic1) is OpOperandList

            # CHECK: ['OpResult(%{{.*}} = arith.constant 3 : i32)', 'OpResult(%{{.*}} = arith.constant 4 : i32)']
            print(values(op.variadic2))
            assert type(op.variadic2) is OpOperandList

            builder = test.same_variadic_operand
            assert return_hint(builder) is op_cls
            assert type(builder([zero, one], two, [three, four])) is op_cls


# CHECK-LABEL: TEST: testVariadicResultAccess
@run
def testVariadicResultAccess():
    """Check result accessors of ops with several variadic result groups.

    Each SameVariadicResultSizeOp* variant is built from signless integer
    types of distinct widths (i0, i1, ...). The printed types therefore show
    which results each accessor returns. Variadic group sizes of zero and one
    are covered too. Return-type annotations and runtime types of selected
    accessors and generated builder functions are also verified.
    """

    def types(lst):
        return [value.type for value in lst]

    def return_hint(fn):
        return typing.get_type_hints(fn)["return"]

    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            # Signless integer types of width 0 through 6.
            ity = [IntegerType.get_signless(width) for width in range(7)]

            # Variadic-Fixed-Variadic.
            op = test.SameVariadicResultSizeOpVFV(ity[0:2], ity[2], ity[3:5])
            # CHECK: i2
            print(op.non_variadic.type)
            # CHECK: [IntegerType(i0), IntegerType(i1)]
            print(types(op.variadic1))
            # CHECK: [IntegerType(i3), IntegerType(i4)]
            print(types(op.variadic2))

            assert (
                return_hint(test.same_variadic_result_vfv)
                == Union[OpResult, OpResultList, test.SameVariadicResultSizeOpVFV]
            )
            built = test.same_variadic_result_vfv(ity[0:2], ity[2], ity[3:5])
            assert type(built) is OpResultList

            # Variadic-Variadic-Variadic.
            op = test.SameVariadicResultSizeOpVVV(ity[0:2], ity[2:4], ity[4:6])
            # CHECK: [IntegerType(i0), IntegerType(i1)]
            print(types(op.variadic1))
            # CHECK: [IntegerType(i2), IntegerType(i3)]
            print(types(op.variadic2))
            # CHECK: [IntegerType(i4), IntegerType(i5)]
            print(types(op.variadic3))

            # Fixed-Fixed-Variadic.
            op = test.SameVariadicResultSizeOpFFV(ity[0], ity[1], ity[2:5])
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: i1
            print(op.non_variadic2.type)
            # CHECK: [IntegerType(i2), IntegerType(i3), IntegerType(i4)]
            print(types(op.variadic))
            variadic_getter = test.SameVariadicResultSizeOpFFV.variadic.fget
            assert return_hint(variadic_getter) is OpResultList
            assert type(op.variadic) is OpResultList

            # Variadic-Variadic-Fixed.
            op = test.SameVariadicResultSizeOpVVF(ity[0:3], ity[3:6], ity[6])
            # CHECK: [IntegerType(i0), IntegerType(i1), IntegerType(i2)]
            print(types(op.variadic1))
            # CHECK: [IntegerType(i3), IntegerType(i4), IntegerType(i5)]
            print(types(op.variadic2))
            # CHECK: i6
            print(op.non_variadic.type)

            # Fixed-Variadic-Fixed-Variadic-Fixed.
            op = test.SameVariadicResultSizeOpFVFVF(
                ity[0], ity[1:3], ity[3], ity[4:6], ity[6]
            )
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: [IntegerType(i1), IntegerType(i2)]
            print(types(op.variadic1))
            # CHECK: i3
            print(op.non_variadic2.type)
            # CHECK: [IntegerType(i4), IntegerType(i5)]
            print(types(op.variadic2))
            # CHECK: i6
            print(op.non_variadic3.type)

            # Fixed-Variadic-Fixed-Variadic-Fixed with empty variadic groups.
            op = test.SameVariadicResultSizeOpFVFVF(ity[0], [], ity[1], [], ity[2])
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: []
            print(types(op.variadic1))
            # CHECK: i1
            print(op.non_variadic2.type)
            # CHECK: []
            print(types(op.variadic2))
            # CHECK: i2
            print(op.non_variadic3.type)

            # Fixed-Variadic-Fixed-Variadic-Fixed with one-element variadic groups.
            op = test.SameVariadicResultSizeOpFVFVF(
                ity[0], ity[1:2], ity[2], ity[3:4], ity[4]
            )
            # CHECK: i0
            print(op.non_variadic1.type)
            # CHECK: [IntegerType(i1)]
            print(types(op.variadic1))
            # CHECK: i2
            print(op.non_variadic2.type)
            # CHECK: [IntegerType(i3)]
            print(types(op.variadic2))
            # CHECK: i4
            print(op.non_variadic3.type)

            assert (
                return_hint(test.results_variadic)
                == Union[OpResult, OpResultList, test.ResultsVariadicOp]
            )
            assert type(test.results_variadic(ity[:1])) is OpResult
            results_op = test.ResultsVariadicOp(ity[:1])
            assert return_hint(test.ResultsVariadicOp.res.fget) is OpResultList
            assert type(results_op.res) is OpResultList


# CHECK-LABEL: TEST: testVariadicAndNormalRegionOp
@run
def testVariadicAndNormalRegionOp():
    """Check region accessors of VariadicAndNormalRegionOp.

    ``region`` is annotated as, and returns, a Region. ``variadic`` is
    annotated as, and returns, a RegionSequence. The op's ``opview``, reached
    directly or through ``.operation``, is an OpView.
    """
    with Context() as ctx, Location.unknown(ctx):
        module = Module.create()
        with InsertionPoint(module.body):
            op = test.VariadicAndNormalRegionOp(2)
            op_cls = test.VariadicAndNormalRegionOp

            # Annotation and runtime type are checked per accessor, in order.
            for accessor, expected in (("region", Region), ("variadic", RegionSequence)):
                getter = getattr(op_cls, accessor).fget
                assert typing.get_type_hints(getter)["return"] is expected
                assert type(getattr(op, accessor)) is expected

            assert isinstance(op.opview, OpView)
            assert isinstance(op.operation.opview, OpView)
