blob: 6afd0815de2adc3e82a3f5a3d329efbbc2254943 [file] [log] [blame]
//===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Dialects.h"
10#include "mlir-c/Dialect/SparseTensor.h"
11#include "mlir-c/IR.h"
12#include "mlir/Bindings/Python/PybindAdaptors.h"
13
14namespace py = pybind11;
15using namespace llvm;
16using namespace mlir;
17using namespace mlir::python::adaptors;
18
19void mlir::python::populateDialectSparseTensorSubmodule(
20 py::module m, const py::module &irModule) {
21 auto attributeClass = irModule.attr("Attribute");
22
Sean Silva8dca9532021-09-14 21:55:54 +000023 py::enum_<MlirSparseTensorDimLevelType>(m, "DimLevelType", py::module_local())
Stella Laurenzof13893f2021-05-09 18:09:09 -070024 .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE)
25 .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED)
26 .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON);
27
28 mlir_attribute_subclass(m, "EncodingAttr",
29 mlirAttributeIsASparseTensorEncodingAttr,
30 attributeClass)
31 .def_classmethod(
32 "get",
33 [](py::object cls,
34 std::vector<MlirSparseTensorDimLevelType> dimLevelTypes,
35 llvm::Optional<MlirAffineMap> dimOrdering, int pointerBitWidth,
36 int indexBitWidth, MlirContext context) {
37 return cls(mlirSparseTensorEncodingAttrGet(
38 context, dimLevelTypes.size(), dimLevelTypes.data(),
39 dimOrdering ? *dimOrdering : MlirAffineMap{nullptr},
40 pointerBitWidth, indexBitWidth));
41 },
42 py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"),
43 py::arg("pointer_bit_width"), py::arg("index_bit_width"),
44 py::arg("context") = py::none(),
45 "Gets a sparse_tensor.encoding from parameters.")
46 .def_property_readonly(
47 "dim_level_types",
48 [](MlirAttribute self) {
49 std::vector<MlirSparseTensorDimLevelType> ret;
50 for (int i = 0,
51 e = mlirSparseTensorEncodingGetNumDimLevelTypes(self);
52 i < e; ++i)
53 ret.push_back(
54 mlirSparseTensorEncodingAttrGetDimLevelType(self, i));
55 return ret;
56 })
57 .def_property_readonly(
58 "dim_ordering",
59 [](MlirAttribute self) -> llvm::Optional<MlirAffineMap> {
60 MlirAffineMap ret =
61 mlirSparseTensorEncodingAttrGetDimOrdering(self);
62 if (mlirAffineMapIsNull(ret))
63 return {};
64 return ret;
65 })
66 .def_property_readonly(
67 "pointer_bit_width",
68 [](MlirAttribute self) {
69 return mlirSparseTensorEncodingAttrGetPointerBitWidth(self);
70 })
71 .def_property_readonly("index_bit_width", [](MlirAttribute self) {
72 return mlirSparseTensorEncodingAttrGetIndexBitWidth(self);
73 });
74}