[MLIR][Python] enable ptr dialect bindings
diff --git a/mlir/include/mlir-c/Dialect/PtrDialect.h b/mlir/include/mlir-c/Dialect/PtrDialect.h new file mode 100644 index 0000000..3df7094 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/PtrDialect.h
@@ -0,0 +1,41 @@ +//===- PtrDialect.h - C interface for the Ptr dialect -------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_PTR_H +#define MLIR_C_DIALECT_PTR_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//===----------------------------------------------------------------------===// +// Dialect API. +//===----------------------------------------------------------------------===// + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Ptr, ptr); + +//===----------------------------------------------------------------------===// +// MemorySpaceAttrInterface API. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Type API. +//===----------------------------------------------------------------------===// + +/// Checks if the given type is a Ptr type. +MLIR_CAPI_EXPORTED bool mlirPtrTypeIsAPtrType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirPtrGetPtrType(MlirAttribute memorySpace); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_PTR_H
diff --git a/mlir/lib/Bindings/Python/DialectPtr.cpp b/mlir/lib/Bindings/Python/DialectPtr.cpp new file mode 100644 index 0000000..8dc99c1 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectPtr.cpp
@@ -0,0 +1,41 @@ +//===- DialectPtr.cpp - Pybind module for Ptr dialect API support ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "NanobindUtils.h" + +#include "mlir-c/Dialect/PtrDialect.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Diagnostics.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +namespace nb = nanobind; + +using namespace nanobind::literals; + +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +static void populateDialectPTRSubmodule(nanobind::module_ &m) { + mlir_type_subclass(m, "PtrType", mlirPtrTypeIsAPtrType) + .def_classmethod( + "get", + [](const nb::object &cls, MlirAttribute memorySpace) { + return cls(mlirPtrGetPtrType(memorySpace)); + }, + "Gets an instance of PtrType with memory_space in the same context", + nb::arg("cls"), nb::arg("memory_space")); +} + +NB_MODULE(_mlirDialectsPTR, m) { + m.doc() = "MLIR PTR Dialect"; + + populateDialectPTRSubmodule(m); +}
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index bb1fdf8..9462ad2 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -278,3 +278,12 @@ MLIRCAPIIR MLIRSMT ) + +add_mlir_upstream_c_api_library(MLIRCAPIPtrDialect + PtrDialect.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRPtrDialect +) \ No newline at end of file
diff --git a/mlir/lib/CAPI/Dialect/PtrDialect.cpp b/mlir/lib/CAPI/Dialect/PtrDialect.cpp new file mode 100644 index 0000000..a8f06f4 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/PtrDialect.cpp
@@ -0,0 +1,39 @@ +//===- PtrDialect.cpp - C interface for the Ptr dialect -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/PtrDialect.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Ptr/IR/PtrDialect.h" +#include "mlir/Dialect/Ptr/IR/PtrTypes.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "ptr-dialect-capi" + +using namespace mlir; +using namespace ptr; + +//===----------------------------------------------------------------------===// +// Dialect API. +//===----------------------------------------------------------------------===// + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Ptr, ptr, mlir::ptr::PtrDialect) + +bool mlirPtrTypeIsAPtrType(MlirType type) { + return llvm::isa<ptr::PtrType>(unwrap(type)); +} + +MlirType mlirPtrGetPtrType(MlirAttribute memorySpace) { + MemorySpaceAttrInterface memorySpaceAttr = + dyn_cast<MemorySpaceAttrInterface>(unwrap(memorySpace)); + if (!memorySpaceAttr) { + LLVM_DEBUG(llvm::dbgs() + << "expected memory-space to be MemorySpaceAttrInterface"); + return {nullptr}; + } + return wrap(ptr::PtrType::get(memorySpaceAttr)); +}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 51c7576..112c8e9 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt
@@ -516,6 +516,15 @@ GEN_ENUM_BINDINGS ) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/PtrOps.td + SOURCES dialects/ptr.py + DIALECT_NAME ptr + GEN_ENUM_BINDINGS +) + ################################################################################ # Python extensions. # The sources for these are all in lib/Bindings/Python, but since they have to @@ -579,7 +588,7 @@ MLIRCAPIRegisterEverything ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Nanobind MODULE_NAME _mlirDialectsLinalg ADD_TO_PARENT MLIRPythonSources.Dialects.linalg ROOT_DIR "${PYTHON_SOURCE_DIR}" @@ -593,7 +602,7 @@ MLIRCAPILinalg ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Nanobind MODULE_NAME _mlirDialectsGPU ADD_TO_PARENT MLIRPythonSources.Dialects.gpu ROOT_DIR "${PYTHON_SOURCE_DIR}" @@ -607,7 +616,7 @@ MLIRCAPIGPU ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Nanobind MODULE_NAME _mlirDialectsLLVM ADD_TO_PARENT MLIRPythonSources.Dialects.llvm ROOT_DIR "${PYTHON_SOURCE_DIR}" @@ -623,7 +632,7 @@ MLIRCAPITarget ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Nanobind MODULE_NAME _mlirDialectsQuant ADD_TO_PARENT MLIRPythonSources.Dialects.quant ROOT_DIR "${PYTHON_SOURCE_DIR}" @@ -637,7 +646,7 @@ MLIRCAPIQuant ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Nanobind MODULE_NAME _mlirDialectsNVGPU ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu ROOT_DIR "${PYTHON_SOURCE_DIR}" @@ -651,7 +660,7 @@ MLIRCAPINVGPU ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Nanobind MODULE_NAME _mlirDialectsPDL ADD_TO_PARENT MLIRPythonSources.Dialects.pdl ROOT_DIR "${PYTHON_SOURCE_DIR}" @@ -665,7 +674,7 @@ MLIRCAPIPDL ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Nanobind MODULE_NAME _mlirDialectsSparseTensor ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor ROOT_DIR "${PYTHON_SOURCE_DIR}" @@ -679,7 +688,7 @@ MLIRCAPISparseTensor ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Nanobind MODULE_NAME _mlirDialectsTransform ADD_TO_PARENT MLIRPythonSources.Dialects.transform ROOT_DIR "${PYTHON_SOURCE_DIR}" @@ -693,7 +702,7 @@ MLIRCAPITransformDialect ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Nanobind MODULE_NAME _mlirDialectsIRDL ADD_TO_PARENT MLIRPythonSources.Dialects.irdl ROOT_DIR "${PYTHON_SOURCE_DIR}" @@ -761,7 +770,7 @@ MLIRCAPILinalg ) -declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind +declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Nanobind MODULE_NAME _mlirDialectsSMT ADD_TO_PARENT MLIRPythonSources.Dialects.smt ROOT_DIR "${PYTHON_SOURCE_DIR}" @@ -778,6 +787,22 @@ MLIRCAPIExportSMTLIB ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Ptr.Nanobind + MODULE_NAME _mlirDialectsPtr + ADD_TO_PARENT MLIRPythonSources.Dialects.ptr + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectPtr.cpp + # Headers must be included explicitly so they are installed. + NanobindUtils.h + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIPtrDialect +) + declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses MODULE_NAME _mlirSparseTensorPasses ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
diff --git a/mlir/python/mlir/dialects/PtrOps.td b/mlir/python/mlir/dialects/PtrOps.td new file mode 100644 index 0000000..8bde942 --- /dev/null +++ b/mlir/python/mlir/dialects/PtrOps.td
@@ -0,0 +1,14 @@ +//===- PTROps.td - Entry point for PTR bindings ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef BINDINGS_PYTHON_PTR_OPS +#define BINDINGS_PYTHON_PTR_OPS + +include "mlir/Dialect/Ptr/IR/PtrOps.td" + +#endif // BINDINGS_PYTHON_PTR_OPS
diff --git a/mlir/python/mlir/dialects/ptr.py b/mlir/python/mlir/dialects/ptr.py new file mode 100644 index 0000000..a837b5b --- /dev/null +++ b/mlir/python/mlir/dialects/ptr.py
@@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._ptr_ops_gen import * +from ._ptr_enum_gen import *
diff --git a/mlir/test/python/dialects/ptr.py b/mlir/test/python/dialects/ptr.py new file mode 100644 index 0000000..8dd4178 --- /dev/null +++ b/mlir/test/python/dialects/ptr.py
@@ -0,0 +1,23 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.dialects import ptr +from mlir.ir import Context, Location, Module, InsertionPoint, Attribute + + +def run(f): + print("\nTEST:", f.__name__) + with Context(), Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + f(module) + print(module) + assert module.operation.verify() + + +# CHECK-LABEL: TEST: test_smoke +@run +def test_smoke(_module): + null_ptr = Attribute.parse("#ptr.null : !ptr.ptr<#llvm.address_space<1>>") + null = ptr.constant(null_ptr) + # CHECK: %0 = ptr.constant #ptr.null : !ptr.ptr<#llvm.address_space<1>> + print(null)