blob: 80c965b2d0eb2cc4f89abecb75ad16554d42dbd9 [file] [log] [blame]
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster, register_value_caster
from ._mlir_libs import get_dialect_registry
# Convenience decorator for registering user-friendly Attribute builders.
def register_attribute_builder(kind, replace=False):
    """Return a decorator that registers a function as the builder for *kind*.

    The decorated function is handed to ``AttrBuilder.insert`` under *kind*
    (forwarding *replace*) and is then returned unchanged, so it stays usable
    as an ordinary module-level function.
    """

    def _register(builder_fn):
        AttrBuilder.insert(kind, builder_fn, replace=replace)
        return builder_fn

    return _register
@register_attribute_builder("AffineMapAttr")
def _affineMapAttr(x, context):
    """Wrap *x* in an ``AffineMapAttr`` (registered as ``"AffineMapAttr"``).

    *context* is not forwarded to ``AffineMapAttr.get``.
    NOTE(review): presumably the attribute takes its context from the map
    itself — confirm.
    """
    map_attr = AffineMapAttr.get(x)
    return map_attr
@register_attribute_builder("BoolAttr")
def _boolAttr(x, context):
    """Build a ``BoolAttr`` holding *x* in *context* (kind ``"BoolAttr"``)."""
    bool_attr = BoolAttr.get(x, context=context)
    return bool_attr
@register_attribute_builder("DictionaryAttr")
def _dictAttr(x, context):
    """Build a ``DictAttr`` from *x* in *context* (kind ``"DictionaryAttr"``)."""
    dict_attr = DictAttr.get(x, context=context)
    return dict_attr
@register_attribute_builder("IndexAttr")
def _indexAttr(x, context):
    """Build an ``IntegerAttr`` of ``index`` type holding *x* (kind ``"IndexAttr"``)."""
    index_type = IndexType.get(context=context)
    return IntegerAttr.get(index_type, x)
def _signless_int_attr(width, x, context):
    """Build an ``IntegerAttr`` of signless *width*-bit integer type holding *x*."""
    int_type = IntegerType.get_signless(width, context=context)
    return IntegerAttr.get(int_type, x)


@register_attribute_builder("I1Attr")
def _i1Attr(x, context):
    """Signless 1-bit integer attribute builder (kind ``"I1Attr"``)."""
    return _signless_int_attr(1, x, context)


@register_attribute_builder("I8Attr")
def _i8Attr(x, context):
    """Signless 8-bit integer attribute builder (kind ``"I8Attr"``)."""
    return _signless_int_attr(8, x, context)


@register_attribute_builder("I16Attr")
def _i16Attr(x, context):
    """Signless 16-bit integer attribute builder (kind ``"I16Attr"``)."""
    return _signless_int_attr(16, x, context)


@register_attribute_builder("I32Attr")
def _i32Attr(x, context):
    """Signless 32-bit integer attribute builder (kind ``"I32Attr"``)."""
    return _signless_int_attr(32, x, context)


@register_attribute_builder("I64Attr")
def _i64Attr(x, context):
    """Signless 64-bit integer attribute builder (kind ``"I64Attr"``)."""
    return _signless_int_attr(64, x, context)
@register_attribute_builder("SI1Attr")
def _si1Attr(x, context):
    """Build a signed 1-bit ``IntegerAttr`` holding *x* (kind ``"SI1Attr"``)."""
    return IntegerAttr.get(IntegerType.get_signed(1, context=context), x)


# Named ``_si8Attr`` (previously ``_i8Attr``) so that defining it does not
# rebind the module-level name of the signless ``I8Attr`` builder.
@register_attribute_builder("SI8Attr")
def _si8Attr(x, context):
    """Build a signed 8-bit ``IntegerAttr`` holding *x* (kind ``"SI8Attr"``)."""
    return IntegerAttr.get(IntegerType.get_signed(8, context=context), x)


@register_attribute_builder("SI16Attr")
def _si16Attr(x, context):
    """Build a signed 16-bit ``IntegerAttr`` holding *x* (kind ``"SI16Attr"``)."""
    return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)


@register_attribute_builder("SI32Attr")
def _si32Attr(x, context):
    """Build a signed 32-bit ``IntegerAttr`` holding *x* (kind ``"SI32Attr"``)."""
    return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)


@register_attribute_builder("SI64Attr")
def _si64Attr(x, context):
    """Build a signed 64-bit ``IntegerAttr`` holding *x* (kind ``"SI64Attr"``)."""
    return IntegerAttr.get(IntegerType.get_signed(64, context=context), x)


@register_attribute_builder("UI1Attr")
def _ui1Attr(x, context):
    """Build an unsigned 1-bit ``IntegerAttr`` holding *x* (kind ``"UI1Attr"``)."""
    return IntegerAttr.get(IntegerType.get_unsigned(1, context=context), x)


# Named ``_ui8Attr`` (previously ``_i8Attr``) for the same reason as
# ``_si8Attr``: it must not shadow the signless ``I8Attr`` builder.
@register_attribute_builder("UI8Attr")
def _ui8Attr(x, context):
    """Build an unsigned 8-bit ``IntegerAttr`` holding *x* (kind ``"UI8Attr"``)."""
    return IntegerAttr.get(IntegerType.get_unsigned(8, context=context), x)


@register_attribute_builder("UI16Attr")
def _ui16Attr(x, context):
    """Build an unsigned 16-bit ``IntegerAttr`` holding *x* (kind ``"UI16Attr"``)."""
    return IntegerAttr.get(IntegerType.get_unsigned(16, context=context), x)


@register_attribute_builder("UI32Attr")
def _ui32Attr(x, context):
    """Build an unsigned 32-bit ``IntegerAttr`` holding *x* (kind ``"UI32Attr"``)."""
    return IntegerAttr.get(IntegerType.get_unsigned(32, context=context), x)


@register_attribute_builder("UI64Attr")
def _ui64Attr(x, context):
    """Build an unsigned 64-bit ``IntegerAttr`` holding *x* (kind ``"UI64Attr"``)."""
    return IntegerAttr.get(IntegerType.get_unsigned(64, context=context), x)
@register_attribute_builder("F32Attr")
def _f32Attr(x, context):
    """Build a 32-bit ``FloatAttr`` holding *x* (kind ``"F32Attr"``)."""
    float_attr = FloatAttr.get_f32(x, context=context)
    return float_attr


@register_attribute_builder("F64Attr")
def _f64Attr(x, context):
    """Build a 64-bit ``FloatAttr`` holding *x* (kind ``"F64Attr"``)."""
    float_attr = FloatAttr.get_f64(x, context=context)
    return float_attr
@register_attribute_builder("StrAttr")
def _stringAttr(x, context):
    """Build a ``StringAttr`` from *x* (kind ``"StrAttr"``)."""
    str_attr = StringAttr.get(x, context=context)
    return str_attr


@register_attribute_builder("SymbolNameAttr")
def _symbolNameAttr(x, context):
    """Build a ``StringAttr`` from *x* (kind ``"SymbolNameAttr"``).

    Symbol names are represented with the same ``StringAttr`` as ``StrAttr``.
    """
    name_attr = StringAttr.get(x, context=context)
    return name_attr
@register_attribute_builder("SymbolRefAttr")
def _symbolRefAttr(x, context):
    """Build a symbol reference attribute from *x* (kind ``"SymbolRefAttr"``).

    A ``list`` produces a ``SymbolRefAttr``; any other value produces a
    ``FlatSymbolRefAttr``.
    """
    make_ref = SymbolRefAttr.get if isinstance(x, list) else FlatSymbolRefAttr.get
    return make_ref(x, context=context)
@register_attribute_builder("FlatSymbolRefAttr")
def _flatSymbolRefAttr(x, context):
    """Build a ``FlatSymbolRefAttr`` from *x* (kind ``"FlatSymbolRefAttr"``)."""
    ref_attr = FlatSymbolRefAttr.get(x, context=context)
    return ref_attr
@register_attribute_builder("UnitAttr")
def _unitAttr(x, context):
    """Return a ``UnitAttr`` when *x* is truthy, otherwise ``None`` (kind ``"UnitAttr"``)."""
    return UnitAttr.get(context=context) if x else None
@register_attribute_builder("ArrayAttr")
def _arrayAttr(x, context):
    """Build an ``ArrayAttr`` whose elements are *x* (kind ``"ArrayAttr"``)."""
    array_attr = ArrayAttr.get(x, context=context)
    return array_attr
# Each builder below forwards *context* to ``ArrayAttr.get`` (as ``_arrayAttr``
# does) instead of leaving the array's context to ``ArrayAttr.get``'s default.


@register_attribute_builder("AffineMapArrayAttr")
def _affineMapArrayAttr(x, context):
    """Build an ``ArrayAttr`` with one ``AffineMapAttr`` per element of *x*."""
    return ArrayAttr.get([_affineMapAttr(v, context) for v in x], context=context)


@register_attribute_builder("BoolArrayAttr")
def _boolArrayAttr(x, context):
    """Build an ``ArrayAttr`` with one ``BoolAttr`` per element of *x*."""
    return ArrayAttr.get([_boolAttr(v, context) for v in x], context=context)


@register_attribute_builder("DictArrayAttr")
def _dictArrayAttr(x, context):
    """Build an ``ArrayAttr`` with one ``DictAttr`` per element of *x*."""
    return ArrayAttr.get([_dictAttr(v, context) for v in x], context=context)


@register_attribute_builder("FlatSymbolRefArrayAttr")
def _flatSymbolRefArrayAttr(x, context):
    """Build an ``ArrayAttr`` with one ``FlatSymbolRefAttr`` per element of *x*."""
    return ArrayAttr.get(
        [_flatSymbolRefAttr(v, context) for v in x], context=context
    )
@register_attribute_builder("I32ArrayAttr")
def _i32ArrayAttr(x, context):
    """Build an ``ArrayAttr`` of signless 32-bit ``IntegerAttr``s from *x*."""
    # Forward *context* explicitly, matching ``_arrayAttr``.
    return ArrayAttr.get([_i32Attr(v, context) for v in x], context=context)


@register_attribute_builder("I64ArrayAttr")
def _i64ArrayAttr(x, context):
    """Build an ``ArrayAttr`` of signless 64-bit ``IntegerAttr``s from *x*."""
    # Forward *context* explicitly, matching ``_arrayAttr``.
    return ArrayAttr.get([_i64Attr(v, context) for v in x], context=context)


@register_attribute_builder("I64SmallVectorArrayAttr")
def _i64SmallVectorArrayAttr(x, context):
    """Build the same attribute as ``I64ArrayAttr`` (delegates to ``_i64ArrayAttr``)."""
    return _i64ArrayAttr(x, context=context)


@register_attribute_builder("IndexListArrayAttr")
def _indexListArrayAttr(x, context):
    """Build an ``ArrayAttr`` whose elements are ``I64ArrayAttr``s, one per list in *x*."""
    # Forward *context* explicitly, matching ``_arrayAttr``.
    return ArrayAttr.get([_i64ArrayAttr(v, context) for v in x], context=context)
@register_attribute_builder("F32ArrayAttr")
def _f32ArrayAttr(x, context):
    """Build an ``ArrayAttr`` of 32-bit ``FloatAttr``s from *x*."""
    # Forward *context* explicitly, matching ``_arrayAttr``.
    return ArrayAttr.get([_f32Attr(v, context) for v in x], context=context)


@register_attribute_builder("F64ArrayAttr")
def _f64ArrayAttr(x, context):
    """Build an ``ArrayAttr`` of 64-bit ``FloatAttr``s from *x*."""
    # Forward *context* explicitly, matching ``_arrayAttr``.
    return ArrayAttr.get([_f64Attr(v, context) for v in x], context=context)


@register_attribute_builder("StrArrayAttr")
def _strArrayAttr(x, context):
    """Build an ``ArrayAttr`` of ``StringAttr``s from *x*."""
    # Forward *context* explicitly, matching ``_arrayAttr``.
    return ArrayAttr.get([_stringAttr(v, context) for v in x], context=context)


@register_attribute_builder("SymbolRefArrayAttr")
def _symbolRefArrayAttr(x, context):
    """Build an ``ArrayAttr`` of symbol references, one per element of *x*."""
    # Forward *context* explicitly, matching ``_arrayAttr``.
    return ArrayAttr.get([_symbolRefAttr(v, context) for v in x], context=context)
@register_attribute_builder("DenseF32ArrayAttr")
def _denseF32ArrayAttr(x, context):
    """Build a ``DenseF32ArrayAttr`` from *x* (kind ``"DenseF32ArrayAttr"``)."""
    dense = DenseF32ArrayAttr.get(x, context=context)
    return dense


@register_attribute_builder("DenseF64ArrayAttr")
def _denseF64ArrayAttr(x, context):
    """Build a ``DenseF64ArrayAttr`` from *x* (kind ``"DenseF64ArrayAttr"``)."""
    dense = DenseF64ArrayAttr.get(x, context=context)
    return dense


@register_attribute_builder("DenseI8ArrayAttr")
def _denseI8ArrayAttr(x, context):
    """Build a ``DenseI8ArrayAttr`` from *x* (kind ``"DenseI8ArrayAttr"``)."""
    dense = DenseI8ArrayAttr.get(x, context=context)
    return dense


@register_attribute_builder("DenseI16ArrayAttr")
def _denseI16ArrayAttr(x, context):
    """Build a ``DenseI16ArrayAttr`` from *x* (kind ``"DenseI16ArrayAttr"``)."""
    dense = DenseI16ArrayAttr.get(x, context=context)
    return dense


@register_attribute_builder("DenseI32ArrayAttr")
def _denseI32ArrayAttr(x, context):
    """Build a ``DenseI32ArrayAttr`` from *x* (kind ``"DenseI32ArrayAttr"``)."""
    dense = DenseI32ArrayAttr.get(x, context=context)
    return dense


@register_attribute_builder("DenseI64ArrayAttr")
def _denseI64ArrayAttr(x, context):
    """Build a ``DenseI64ArrayAttr`` from *x* (kind ``"DenseI64ArrayAttr"``)."""
    dense = DenseI64ArrayAttr.get(x, context=context)
    return dense


@register_attribute_builder("DenseBoolArrayAttr")
def _denseBoolArrayAttr(x, context):
    """Build a ``DenseBoolArrayAttr`` from *x* (kind ``"DenseBoolArrayAttr"``)."""
    dense = DenseBoolArrayAttr.get(x, context=context)
    return dense
@register_attribute_builder("TypeAttr")
def _typeAttr(x, context):
    """Wrap the type *x* in a ``TypeAttr`` (kind ``"TypeAttr"``)."""
    type_attr = TypeAttr.get(x, context=context)
    return type_attr


@register_attribute_builder("TypeArrayAttr")
def _typeArrayAttr(x, context):
    """Build an ``ArrayAttr`` holding one ``TypeAttr`` per type in *x*."""
    wrapped_types = [TypeAttr.get(ty, context=context) for ty in x]
    return _arrayAttr(wrapped_types, context)


@register_attribute_builder("MemRefTypeAttr")
def _memref_type_attr(x, context):
    """Wrap the memref type *x* in a plain ``TypeAttr`` (kind ``"MemRefTypeAttr"``)."""
    return TypeAttr.get(x, context=context)
# The ElementsAttr builders need numpy. Only the import is guarded, so an
# ImportError raised while registering a builder is not silently swallowed.
try:
    import numpy as np
except ImportError:
    # numpy is optional: without it, the builders in the ``else`` branch are
    # simply not registered.
    pass
else:

    @register_attribute_builder("F64ElementsAttr")
    def _f64ElementsAttr(x, context):
        """Build a ``DenseElementsAttr`` from *x* as a float64 array with f64 element type."""
        return DenseElementsAttr.get(
            np.array(x, dtype=np.float64),
            type=F64Type.get(context=context),
            context=context,
        )

    @register_attribute_builder("I32ElementsAttr")
    def _i32ElementsAttr(x, context):
        """Build a ``DenseElementsAttr`` from *x* as an int32 array with signless i32 element type."""
        return DenseElementsAttr.get(
            np.array(x, dtype=np.int32),
            type=IntegerType.get_signless(32, context=context),
            context=context,
        )

    @register_attribute_builder("I64ElementsAttr")
    def _i64ElementsAttr(x, context):
        """Build a ``DenseElementsAttr`` from *x* as an int64 array with signless i64 element type."""
        return DenseElementsAttr.get(
            np.array(x, dtype=np.int64),
            type=IntegerType.get_signless(64, context=context),
            context=context,
        )

    @register_attribute_builder("IndexElementsAttr")
    def _indexElementsAttr(x, context):
        """Build a ``DenseElementsAttr`` from *x* as an int64 array with ``index`` element type."""
        return DenseElementsAttr.get(
            np.array(x, dtype=np.int64),
            type=IndexType.get(context=context),
            context=context,
        )