blob: 380aa36d7782453e83b11a374bc50a95cb25cae6 [file] [log] [blame]
//===- IRTypes.cpp - Exports builtin and standard types -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "IRModule.h"
#include "PybindUtils.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;
using llvm::SmallVector;
using llvm::Twine;
namespace {
/// Returns non-zero when `type` is a builtin integer type or one of the
/// floating-point types bf16, f16, f32 or f64 (the element types accepted by
/// ComplexType.get below).
static int mlirTypeIsAIntegerOrFloat(MlirType type) {
  if (mlirTypeIsAInteger(type))
    return true;
  const bool isFloat = mlirTypeIsABF16(type) || mlirTypeIsAF16(type) ||
                       mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
  return isFloat;
}
/// Python binding for the builtin `IntegerType`: factory methods for the
/// signless, signed and unsigned flavours plus read-only width/signedness
/// queries.
class PyIntegerType : public PyConcreteType<PyIntegerType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
  static constexpr const char *pyClassName = "IntegerType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    // Factories: each builds the integer type in the resolved context and
    // wraps it together with a reference to that context.
    c.def_static(
         "get_signless",
         [](unsigned width, DefaultingPyMlirContext context) {
           return PyIntegerType(context->getRef(),
                                mlirIntegerTypeGet(context->get(), width));
         },
         py::arg("width"), py::arg("context") = py::none(),
         "Create a signless integer type")
        .def_static(
            "get_signed",
            [](unsigned width, DefaultingPyMlirContext context) {
              return PyIntegerType(
                  context->getRef(),
                  mlirIntegerTypeSignedGet(context->get(), width));
            },
            py::arg("width"), py::arg("context") = py::none(),
            "Create a signed integer type")
        .def_static(
            "get_unsigned",
            [](unsigned width, DefaultingPyMlirContext context) {
              return PyIntegerType(
                  context->getRef(),
                  mlirIntegerTypeUnsignedGet(context->get(), width));
            },
            py::arg("width"), py::arg("context") = py::none(),
            "Create an unsigned integer type");

    // Read-only properties forwarding to the C API accessors.
    c.def_property_readonly(
         "width",
         [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
         "Returns the width of the integer type")
        .def_property_readonly(
            "is_signless",
            [](PyIntegerType &self) -> bool {
              return mlirIntegerTypeIsSignless(self);
            },
            "Returns whether this is a signless integer")
        .def_property_readonly(
            "is_signed",
            [](PyIntegerType &self) -> bool {
              return mlirIntegerTypeIsSigned(self);
            },
            "Returns whether this is a signed integer")
        .def_property_readonly(
            "is_unsigned",
            [](PyIntegerType &self) -> bool {
              return mlirIntegerTypeIsUnsigned(self);
            },
            "Returns whether this is an unsigned integer");
  }
};
/// Index Type subclass - IndexType.
/// Exposes `IndexType.get(context=None)`, which wraps the value returned by
/// mlirIndexTypeGet for the resolved context.
class PyIndexType : public PyConcreteType<PyIndexType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
  static constexpr const char *pyClassName = "IndexType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    c.def_static(
        "get",
        [](DefaultingPyMlirContext context) {
          MlirType t = mlirIndexTypeGet(context->get());
          return PyIndexType(context->getRef(), t);
        },
        py::arg("context") = py::none(), "Create an index type.");
  }
};
/// Floating Point Type subclass - BF16Type.
/// `BF16Type.get(context=None)` wraps the value of mlirBF16TypeGet for the
/// resolved context.
class PyBF16Type : public PyConcreteType<PyBF16Type> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
  static constexpr const char *pyClassName = "BF16Type";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    auto getBF16 = [](DefaultingPyMlirContext context) {
      MlirContext rawContext = context->get();
      return PyBF16Type(context->getRef(), mlirBF16TypeGet(rawContext));
    };
    c.def_static("get", getBF16, py::arg("context") = py::none(),
                 "Create a bf16 type.");
  }
};
/// Floating Point Type subclass - F16Type.
/// `F16Type.get(context=None)` wraps the value of mlirF16TypeGet for the
/// resolved context.
class PyF16Type : public PyConcreteType<PyF16Type> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
  static constexpr const char *pyClassName = "F16Type";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    auto getF16 = [](DefaultingPyMlirContext context) {
      MlirContext rawContext = context->get();
      return PyF16Type(context->getRef(), mlirF16TypeGet(rawContext));
    };
    c.def_static("get", getF16, py::arg("context") = py::none(),
                 "Create a f16 type.");
  }
};
/// Floating Point Type subclass - F32Type.
/// `F32Type.get(context=None)` wraps the value of mlirF32TypeGet for the
/// resolved context.
class PyF32Type : public PyConcreteType<PyF32Type> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
  static constexpr const char *pyClassName = "F32Type";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    auto getF32 = [](DefaultingPyMlirContext context) {
      MlirContext rawContext = context->get();
      return PyF32Type(context->getRef(), mlirF32TypeGet(rawContext));
    };
    c.def_static("get", getF32, py::arg("context") = py::none(),
                 "Create a f32 type.");
  }
};
/// Floating Point Type subclass - F64Type.
/// `F64Type.get(context=None)` wraps the value of mlirF64TypeGet for the
/// resolved context.
class PyF64Type : public PyConcreteType<PyF64Type> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
  static constexpr const char *pyClassName = "F64Type";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    auto getF64 = [](DefaultingPyMlirContext context) {
      MlirContext rawContext = context->get();
      return PyF64Type(context->getRef(), mlirF64TypeGet(rawContext));
    };
    c.def_static("get", getF64, py::arg("context") = py::none(),
                 "Create a f64 type.");
  }
};
/// None Type subclass - NoneType.
/// `NoneType.get(context=None)` wraps the value of mlirNoneTypeGet for the
/// resolved context.
class PyNoneType : public PyConcreteType<PyNoneType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
  static constexpr const char *pyClassName = "NoneType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    auto getNone = [](DefaultingPyMlirContext context) {
      MlirContext rawContext = context->get();
      return PyNoneType(context->getRef(), mlirNoneTypeGet(rawContext));
    };
    c.def_static("get", getNone, py::arg("context") = py::none(),
                 "Create a none type.");
  }
};
/// Complex Type subclass - ComplexType.
/// `ComplexType.get(elementType)` accepts integer or bf16/f16/f32/f64 element
/// types (rejecting anything else with ValueError); `element_type` returns the
/// wrapped element type.
class PyComplexType : public PyConcreteType<PyComplexType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
  static constexpr const char *pyClassName = "ComplexType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    c.def_static(
        "get",
        [](PyType &elementType) {
          // Guard: only integer or floating-point scalars may be complex
          // element types.
          if (!mlirTypeIsAIntegerOrFloat(elementType)) {
            throw SetPyError(
                PyExc_ValueError,
                Twine("invalid '") +
                    py::repr(py::cast(elementType)).cast<std::string>() +
                    "' and expected floating point or integer type.");
          }
          return PyComplexType(elementType.getContext(),
                               mlirComplexTypeGet(elementType));
        },
        "Create a complex type");
    c.def_property_readonly(
        "element_type",
        [](PyComplexType &self) -> PyType {
          return PyType(self.getContext(),
                        mlirComplexTypeGetElementType(self));
        },
        "Returns element type.");
  }
};
/// Base binding shared by all builtin shaped types (vector, tensor, memref).
/// Rank-dependent queries raise ValueError on unranked types; per-dimension
/// queries additionally raise IndexError when `dim` is outside [0, rank), so
/// an out-of-range index never reaches the C API accessors.
class PyShapedType : public PyConcreteType<PyShapedType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
  static constexpr const char *pyClassName = "ShapedType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    c.def_property_readonly(
        "element_type",
        [](PyShapedType &self) {
          MlirType t = mlirShapedTypeGetElementType(self);
          return PyType(self.getContext(), t);
        },
        "Returns the element type of the shaped type.");
    c.def_property_readonly(
        "has_rank",
        [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
        "Returns whether the given shaped type is ranked.");
    c.def_property_readonly(
        "rank",
        [](PyShapedType &self) {
          self.requireHasRank();
          return mlirShapedTypeGetRank(self);
        },
        "Returns the rank of the given ranked shaped type.");
    c.def_property_readonly(
        "has_static_shape",
        [](PyShapedType &self) -> bool {
          return mlirShapedTypeHasStaticShape(self);
        },
        "Returns whether the given shaped type has a static shape.");
    c.def(
        "is_dynamic_dim",
        [](PyShapedType &self, intptr_t dim) -> bool {
          self.requireValidDim(dim);
          return mlirShapedTypeIsDynamicDim(self, dim);
        },
        py::arg("dim"),
        "Returns whether the dim-th dimension of the given shaped type is "
        "dynamic.");
    c.def(
        "get_dim_size",
        [](PyShapedType &self, intptr_t dim) {
          self.requireValidDim(dim);
          return mlirShapedTypeGetDimSize(self, dim);
        },
        py::arg("dim"),
        "Returns the dim-th dimension of the given ranked shaped type.");
    c.def_static(
        "is_dynamic_size",
        [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
        py::arg("dim_size"),
        "Returns whether the given dimension size indicates a dynamic "
        "dimension.");
    // The answer depends only on `val`, not on this type, so no rank is
    // required. Static methods remain callable through instances, so
    // existing `t.is_dynamic_stride_or_offset(v)` calls keep working.
    c.def_static(
        "is_dynamic_stride_or_offset",
        [](int64_t val) -> bool {
          return mlirShapedTypeIsDynamicStrideOrOffset(val);
        },
        py::arg("dim_size"),
        "Returns whether the given value is used as a placeholder for dynamic "
        "strides and offsets in shaped types.");
    c.def_property_readonly(
        "shape",
        [](PyShapedType &self) {
          self.requireHasRank();
          int64_t rank = mlirShapedTypeGetRank(self);
          std::vector<int64_t> shape;
          shape.reserve(rank);
          for (int64_t i = 0; i < rank; ++i)
            shape.push_back(mlirShapedTypeGetDimSize(self, i));
          return shape;
        },
        "Returns the shape of the ranked shaped type as a list of integers.");
  }

private:
  /// Raises ValueError unless the type is ranked.
  void requireHasRank() {
    if (!mlirShapedTypeHasRank(*this)) {
      throw SetPyError(
          PyExc_ValueError,
          "calling this method requires that the type has a rank.");
    }
  }

  /// Raises ValueError for unranked types and IndexError when `dim` is not
  /// a valid dimension index for this type's rank.
  void requireValidDim(intptr_t dim) {
    requireHasRank();
    int64_t rank = mlirShapedTypeGetRank(*this);
    if (dim < 0 || dim >= rank) {
      throw SetPyError(PyExc_IndexError,
                       Twine("dimension index ") + Twine(dim) +
                           " is out of range for shaped type of rank " +
                           Twine(rank));
    }
  }
};
/// Vector Type subclass - VectorType.
/// `VectorType.get(shape, elementType, loc=None)` calls the checked C API
/// constructor and raises ValueError if it yields a null type.
class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
  static constexpr const char *pyClassName = "VectorType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    c.def_static(
        "get",
        [](std::vector<int64_t> shape, PyType &elementType,
           DefaultingPyLocation loc) {
          MlirType vectorType = mlirVectorTypeGetChecked(
              loc, shape.size(), shape.data(), elementType);
          // TODO: Rework error reporting once diagnostic engine is exposed
          // in C API.
          if (!mlirTypeIsNull(vectorType))
            return PyVectorType(elementType.getContext(), vectorType);
          throw SetPyError(
              PyExc_ValueError,
              Twine("invalid '") +
                  py::repr(py::cast(elementType)).cast<std::string>() +
                  "' and expected floating point or integer type.");
        },
        py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
        "Create a vector type");
  }
};
/// Ranked Tensor Type subclass - RankedTensorType.
/// `RankedTensorType.get(shape, element_type, encoding=None, loc=None)` calls
/// the checked C API constructor (ValueError on a null result). `encoding`
/// returns the encoding attribute, or None when the tensor has none.
class PyRankedTensorType
    : public PyConcreteType<PyRankedTensorType, PyShapedType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
  static constexpr const char *pyClassName = "RankedTensorType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    c.def_static(
        "get",
        [](std::vector<int64_t> shape, PyType &elementType,
           llvm::Optional<PyAttribute> &encodingAttr,
           DefaultingPyLocation loc) {
          // An absent encoding is forwarded as the null attribute.
          MlirAttribute encoding = mlirAttributeGetNull();
          if (encodingAttr)
            encoding = encodingAttr->get();
          MlirType tensorType = mlirRankedTensorTypeGetChecked(
              loc, shape.size(), shape.data(), elementType, encoding);
          // TODO: Rework error reporting once diagnostic engine is exposed
          // in C API.
          if (!mlirTypeIsNull(tensorType))
            return PyRankedTensorType(elementType.getContext(), tensorType);
          throw SetPyError(
              PyExc_ValueError,
              Twine("invalid '") +
                  py::repr(py::cast(elementType)).cast<std::string>() +
                  "' and expected floating point, integer, vector or "
                  "complex "
                  "type.");
        },
        py::arg("shape"), py::arg("element_type"),
        py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
        "Create a ranked tensor type");
    c.def_property_readonly(
        "encoding",
        [](PyRankedTensorType &self) -> llvm::Optional<PyAttribute> {
          MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
          if (!mlirAttributeIsNull(encoding))
            return PyAttribute(self.getContext(), encoding);
          return llvm::None;
        });
  }
};
/// Unranked Tensor Type subclass - UnrankedTensorType.
/// `UnrankedTensorType.get(element_type, loc=None)` calls the checked C API
/// constructor and raises ValueError if it yields a null type.
class PyUnrankedTensorType
    : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
  static constexpr const char *pyClassName = "UnrankedTensorType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    c.def_static(
        "get",
        [](PyType &elementType, DefaultingPyLocation loc) {
          MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
          // TODO: Rework error reporting once diagnostic engine is exposed
          // in C API.
          if (mlirTypeIsNull(t)) {
            throw SetPyError(
                PyExc_ValueError,
                Twine("invalid '") +
                    py::repr(py::cast(elementType)).cast<std::string>() +
                    "' and expected floating point, integer, vector or "
                    "complex "
                    "type.");
          }
          return PyUnrankedTensorType(elementType.getContext(), t);
        },
        py::arg("element_type"), py::arg("loc") = py::none(),
        "Create an unranked tensor type");
  }
};
/// Ranked MemRef Type subclass - MemRefType.
/// Binds `MemRefType.get(shape, element_type, layout=None, memory_space=None,
/// loc=None)` plus read-only `layout`, `affine_map` and `memory_space`
/// accessors. `memory_space` is None when the type carries a null memory
/// space attribute (for example, one built with `memory_space=None`).
class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
  static constexpr const char *pyClassName = "MemRefType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    c.def_static(
         "get",
         [](std::vector<int64_t> shape, PyType &elementType,
            PyAttribute *layout, PyAttribute *memorySpace,
            DefaultingPyLocation loc) {
           // Python None for layout / memory space becomes a null attribute.
           MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
           MlirAttribute memSpaceAttr =
               memorySpace ? *memorySpace : mlirAttributeGetNull();
           MlirType t =
               mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
                                        shape.data(), layoutAttr, memSpaceAttr);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
             throw SetPyError(
                 PyExc_ValueError,
                 Twine("invalid '") +
                     py::repr(py::cast(elementType)).cast<std::string>() +
                     "' and expected floating point, integer, vector or "
                     "complex "
                     "type.");
           }
           return PyMemRefType(elementType.getContext(), t);
         },
         py::arg("shape"), py::arg("element_type"),
         py::arg("layout") = py::none(), py::arg("memory_space") = py::none(),
         py::arg("loc") = py::none(), "Create a memref type")
        .def_property_readonly(
            "layout",
            [](PyMemRefType &self) -> PyAttribute {
              MlirAttribute layout = mlirMemRefTypeGetLayout(self);
              return PyAttribute(self.getContext(), layout);
            },
            "The layout of the MemRef type.")
        .def_property_readonly(
            "affine_map",
            [](PyMemRefType &self) -> PyAffineMap {
              MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
              return PyAffineMap(self.getContext(), map);
            },
            "The layout of the MemRef type as an affine map.")
        .def_property_readonly(
            "memory_space",
            [](PyMemRefType &self) -> llvm::Optional<PyAttribute> {
              // A null attribute means no memory space was set. Return None
              // rather than wrapping an unusable null attribute.
              MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
              if (mlirAttributeIsNull(a))
                return llvm::None;
              return PyAttribute(self.getContext(), a);
            },
            "Returns the memory space of the given MemRef type.");
  }
};
/// Unranked MemRef Type subclass - UnrankedMemRefType.
/// Binds `UnrankedMemRefType.get(element_type, memory_space=None, loc=None)`
/// and a read-only `memory_space` accessor. The accessor uses the
/// unranked-specific C API getter and returns None for a null memory space.
class PyUnrankedMemRefType
    : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
  static constexpr const char *pyClassName = "UnrankedMemRefType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    c.def_static(
         "get",
         [](PyType &elementType, PyAttribute *memorySpace,
            DefaultingPyLocation loc) {
           // Python None for the memory space becomes a null attribute.
           MlirAttribute memSpaceAttr =
               memorySpace ? *memorySpace : mlirAttributeGetNull();
           MlirType t =
               mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
             throw SetPyError(
                 PyExc_ValueError,
                 Twine("invalid '") +
                     py::repr(py::cast(elementType)).cast<std::string>() +
                     "' and expected floating point, integer, vector or "
                     "complex "
                     "type.");
           }
           return PyUnrankedMemRefType(elementType.getContext(), t);
         },
         // memory_space defaults to None, consistent with MemRefType.get.
         py::arg("element_type"), py::arg("memory_space") = py::none(),
         py::arg("loc") = py::none(), "Create an unranked memref type")
        .def_property_readonly(
            "memory_space",
            [](PyUnrankedMemRefType &self) -> llvm::Optional<PyAttribute> {
              // Use the unranked accessor: the ranked
              // mlirMemRefTypeGetMemorySpace expects a ranked MemRef type.
              MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
              if (mlirAttributeIsNull(a))
                return llvm::None;
              return PyAttribute(self.getContext(), a);
            },
            "Returns the memory space of the given Unranked MemRef type.");
  }
};
/// Tuple Type subclass - TupleType.
/// `TupleType.get_tuple(elements, context=None)` builds a tuple from a list
/// of types. `get_type(pos)` returns one element and raises IndexError when
/// `pos` is outside [0, num_types). `num_types` reports the element count.
class PyTupleType : public PyConcreteType<PyTupleType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
  static constexpr const char *pyClassName = "TupleType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    c.def_static(
        "get_tuple",
        [](py::list elementList, DefaultingPyMlirContext context) {
          // Unwrap each Python type into its raw C API handle.
          SmallVector<MlirType, 4> elements;
          elements.reserve(py::len(elementList));
          for (py::handle element : elementList)
            elements.push_back(element.cast<PyType>());
          MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
                                        elements.data());
          return PyTupleType(context->getRef(), t);
        },
        py::arg("elements"), py::arg("context") = py::none(),
        "Create a tuple type");
    c.def(
        "get_type",
        [](PyTupleType &self, intptr_t pos) -> PyType {
          // Validate before indexing: the C API getter receives `pos` as-is.
          intptr_t numTypes = mlirTupleTypeGetNumTypes(self);
          if (pos < 0 || pos >= numTypes) {
            throw SetPyError(PyExc_IndexError,
                             Twine("tuple type index ") + Twine(pos) +
                                 " is out of range for tuple of " +
                                 Twine(numTypes) + " types");
          }
          MlirType t = mlirTupleTypeGetType(self, pos);
          return PyType(self.getContext(), t);
        },
        py::arg("pos"), "Returns the pos-th type in the tuple type.");
    c.def_property_readonly(
        "num_types",
        [](PyTupleType &self) -> intptr_t {
          return mlirTupleTypeGetNumTypes(self);
        },
        "Returns the number of types contained in a tuple.");
  }
};
/// Function type.
/// `FunctionType.get(inputs, results, context=None)` builds a function type
/// from lists of input and result types. The `inputs` and `results`
/// properties return those lists as Python lists of Type.
class PyFunctionType : public PyConcreteType<PyFunctionType> {
public:
  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
  static constexpr const char *pyClassName = "FunctionType";
  using PyConcreteType::PyConcreteType;

  static void bindDerived(ClassTy &c) {
    c.def_static(
        "get",
        [](std::vector<PyType> inputs, std::vector<PyType> results,
           DefaultingPyMlirContext context) {
          // Each PyType converts implicitly to its raw MlirType handle.
          SmallVector<MlirType, 4> rawInputs(inputs.begin(), inputs.end());
          SmallVector<MlirType, 4> rawResults(results.begin(), results.end());
          MlirType fnType =
              mlirFunctionTypeGet(context->get(), rawInputs.size(),
                                  rawInputs.data(), rawResults.size(),
                                  rawResults.data());
          return PyFunctionType(context->getRef(), fnType);
        },
        py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
        "Gets a FunctionType from a list of input and result types");
    c.def_property_readonly(
        "inputs",
        [](PyFunctionType &self) {
          auto contextRef = self.getContext();
          py::list inputTypes;
          const intptr_t count = mlirFunctionTypeGetNumInputs(self);
          for (intptr_t idx = 0; idx < count; ++idx)
            inputTypes.append(
                PyType(contextRef, mlirFunctionTypeGetInput(self, idx)));
          return inputTypes;
        },
        "Returns the list of input types in the FunctionType.");
    c.def_property_readonly(
        "results",
        [](PyFunctionType &self) {
          auto contextRef = self.getContext();
          py::list resultTypes;
          const intptr_t count = mlirFunctionTypeGetNumResults(self);
          for (intptr_t idx = 0; idx < count; ++idx)
            resultTypes.append(
                PyType(contextRef, mlirFunctionTypeGetResult(self, idx)));
          return resultTypes;
        },
        "Returns the list of result types in the FunctionType.");
  }
};
} // namespace
/// Registers every builtin type binding defined in this file on module `m`,
/// in the order listed below.
/// PyShapedType is bound before the classes that name it as their base in
/// PyConcreteType (vector, ranked/unranked tensor, ranked/unranked memref).
/// NOTE(review): this ordering presumably matters because pybind11 needs a
/// base class registered before its subclasses. Keep shaped subclasses after
/// PyShapedType when adding new ones.
void mlir::python::populateIRTypes(py::module &m) {
// Scalar types.
PyIntegerType::bind(m);
PyIndexType::bind(m);
PyBF16Type::bind(m);
PyF16Type::bind(m);
PyF32Type::bind(m);
PyF64Type::bind(m);
PyNoneType::bind(m);
PyComplexType::bind(m);
// Shaped types: the common base first, then its subclasses.
PyShapedType::bind(m);
PyVectorType::bind(m);
PyRankedTensorType::bind(m);
PyUnrankedTensorType::bind(m);
PyMemRefType::bind(m);
PyUnrankedMemRefType::bind(m);
// Aggregate and signature types.
PyTupleType::bind(m);
PyFunctionType::bind(m);
}