[mlir][python] Add utils for more pythonic context creation and registration management

Co-authored-by: Fabian Mora <fmora.dev@gmail.com>
Co-authored-by: Oleksandr "Alex" Zinenko <git@ozinenko.com>
Co-authored-by: Tres <tpopp@users.noreply.github.com>
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 81299c7..877aa73 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -143,6 +143,10 @@
 MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context,
                                                         bool enable);
 
+/// Retrieve the current threading mode, as controlled by
+/// mlirContextEnableMultithreading.
+MLIR_CAPI_EXPORTED bool mlirContextIsMultithreadingEnabled(MlirContext context);
+
 /// Eagerly loads all available dialects registered with a context, making
 /// them available for use for IR construction.
 MLIR_CAPI_EXPORTED void
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index d961482..002923b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2939,6 +2939,12 @@
              ss << pool.ptr;
              return ss.str();
            })
+      .def_prop_ro(
+          "is_multithreading_enabled",
+          [](PyMlirContext &self) {
+            return mlirContextIsMultithreadingEnabled(self.get());
+          },
+          "Returns true if multithreading is enabled for this context.")
       .def(
           "is_registered_operation",
           [](PyMlirContext &self, std::string &name) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index fbc66bc..1cc555a 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -101,6 +101,10 @@
   return unwrap(context)->isOperationRegistered(unwrap(name));
 }
 
+bool mlirContextIsMultithreadingEnabled(MlirContext context) {
+  return unwrap(context)->isMultithreadingEnabled();
+}
+
 void mlirContextEnableMultithreading(MlirContext context, bool enable) {
   return unwrap(context)->enableMultithreading(enable);
 }
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index b2daabb..b4e0ab2 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -48,6 +48,13 @@
     runtime/*.py
 )
 
+declare_mlir_python_sources(MLIRPythonSources.Utils
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  ADD_TO_PARENT MLIRPythonSources
+  SOURCES
+    utils.py
+)
+
 declare_mlir_python_sources(MLIRPythonCAPI.HeaderSources
   ROOT_DIR "${MLIR_SOURCE_DIR}/include"
   SOURCES_GLOB "mlir-c/*.h"
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 70bca3c..56b9f17 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -986,6 +986,7 @@
 class Context:
     current: ClassVar[Context] = ...  # read-only
     allow_unregistered_dialects: bool
+    is_multithreading_enabled: bool
     @staticmethod
     def _get_live_count() -> int: ...
     def _CAPICreate(self) -> object: ...
diff --git a/mlir/python/mlir/utils.py b/mlir/python/mlir/utils.py
new file mode 100644
index 0000000..c6e9b57
--- /dev/null
+++ b/mlir/python/mlir/utils.py
@@ -0,0 +1,211 @@
+from contextlib import contextmanager, nullcontext
+from functools import wraps
+from typing import (
+    Any,
+    Callable,
+    Concatenate,
+    Iterator,
+    Optional,
+    ParamSpec,
+    Sequence,
+    TypeVar,
+)
+
+from mlir import ir
+from mlir._mlir_libs import get_dialect_registry
+from mlir.dialects import func
+from mlir.dialects.transform import interpreter
+from mlir.passmanager import PassManager
+
+RT = TypeVar("RT")
+Param = ParamSpec("Param")
+
+
+@contextmanager
+def using_mlir_context(
+    *,
+    required_dialects: Optional[Sequence[str]] = None,
+    required_extension_operations: Optional[Sequence[str]] = None,
+    registration_funcs: Optional[Sequence[Callable[[ir.DialectRegistry], None]]] = None,
+) -> Iterator[None]:
+    """Ensure a valid context exists by creating one if necessary.
+
+    NOTE: If values that are attached to a Context should outlive this
+          contextmanager, use caller_mlir_context!
+
+    This can be used as a function decorator or managed context in a with statement.
+    The context will throw an error if the required dialects have not been registered,
+    and a context is guaranteed to exist in this scope.
+
+    This only works on dialects and not dialect extensions currently.
+
+    Parameters
+    ------------
+        required_dialects:
+            Dialects that need to be registered in the context
+        required_extension_operations:
+            Required operations by their fully specified name. These are a proxy for detecting needed dialect extensions.
+        registration_funcs:
+            Functions that should be called to register all missing dialects/operations if they have not been registered.
+    """
+    dialects = required_dialects or []
+    extension_operations = required_extension_operations or []
+    registrations = registration_funcs or []
+    new_context = nullcontext if ir.Context.current else ir.Context
+    with new_context(), ir.Location.unknown():
+        context = ir.Context.current
+        # Attempt to disable multithreading. This could fail if currently being
+        # used in multiple threads. This must be done before checking for
+        # dialects or registering dialects as both will assert fail in a
+        # multithreaded situation.
+        multithreading = context.is_multithreading_enabled
+        if multithreading:
+            context.enable_multithreading(False)
+
+        def attempt_registration():
+            """Register everything from registration_funcs."""
+            nonlocal context, registrations
+
+            # Gather dialects and extensions then add them to the context.
+            registry = ir.DialectRegistry()
+            for rf in registrations:
+                rf(registry)
+
+            context.append_dialect_registry(registry)
+
+        # See if any dialects are missing, register if they are, and then assert they are all registered.
+        try:
+            for dialect in dialects:
+                # If the dialect is registered, continue checking
+                context.get_dialect_descriptor(dialect)
+        except Exception:
+            attempt_registration()
+
+        for dialect in dialects:
+            # If the dialect is registered, continue checking
+            assert context.get_dialect_descriptor(
+                dialect
+            ), f"required dialect {dialect} not registered by registration_funcs"
+
+        # See if any operations are missing and register if they are. We cannot
+        # assert the operations exist in the registry after for some reason.
+        #
+        # TODO: Make this work for dialect extensions specifically
+        for operation in extension_operations:
+            # If the operation is not registered, attempt registration once; the loop below asserts it was added
+            if not context.is_registered_operation(operation):
+                attempt_registration()
+                break
+        for operation in extension_operations:
+            # First get the dialect descriptor, which loads the dialect as a side effect
+            dialect = operation.split(".")[0]
+            assert context.get_dialect_descriptor(dialect), f"Never loaded {dialect}"
+            assert context.is_registered_operation(
+                operation
+            ), f"expected {operation} to be registered in its dialect"
+        context.enable_multithreading(multithreading)
+
+        # Context manager related yield
+        try:
+            yield
+        finally:
+            pass
+
+
+@contextmanager
+def caller_mlir_context(
+    *,
+    required_dialects: Optional[Sequence[str]] = None,
+    required_extension_operations: Optional[Sequence[str]] = None,
+    registration_funcs: Optional[Sequence[Callable[[ir.DialectRegistry], None]]] = None,
+) -> Iterator[None]:
+    """Requires an enclosing context from the caller and ensures relevant operations are loaded.
+
+    NOTE: If the Context is only needed inside of this contextmanager and returned values
+          don't need the Context, use using_mlir_context!
+
+    A context must already exist before this frame is executed to ensure that any values
+    continue to live on exit. Conceptually, this prevents use-after-free issues and
+    makes the intention clear when one intends to return values tied to a Context.
+    """
+    assert (
+        ir.Context.current
+    ), "Caller must have a context so it outlives this function call."
+    with using_mlir_context(
+        required_dialects=required_dialects,
+        required_extension_operations=required_extension_operations,
+        registration_funcs=registration_funcs,
+    ):
+        # Context manager related yield
+        try:
+            yield
+        finally:
+            pass
+
+
+def with_toplevel_context(f: Callable[Param, RT]) -> Callable[Param, RT]:
+    """Decorate the function to be executed with a fresh MLIR context.
+
+    This decorator will ensure the function is executed inside a context manager for a
+    new MLIR context with upstream and IREE dialects registered. Note that each call to
+    such a function has a new context, meaning that context-owned objects from these
+    functions will not be equal to each other. All arguments and keyword arguments are
+    forwarded.
+
+    The context is destroyed before the function exits so any result from the function
+    must not depend on the context.
+    """
+
+    @wraps(f)
+    def decorator(*args: Param.args, **kwargs: Param.kwargs) -> RT:
+        # Appending dialect registry and loading all available dialects occur on
+        # context creation because of the "_site_initialize" call.
+        with ir.Context(), ir.Location.unknown():
+            results = f(*args, **kwargs)
+        return results
+
+    return decorator
+
+
+def with_toplevel_context_create_module(
+    f: Callable[Concatenate[ir.Module, Param], RT],
+) -> Callable[Param, RT]:
+    """Decorate function to be executed in a fresh MLIR context and give it a module.
+
+    The decorated function will receive, as its leading argument, a fresh MLIR module.
+    The context manager is set up to insert operations into this module. All other
+    arguments and keyword arguments are forwarded.
+
+    The module and context are destroyed before the function exits so any result from
+    the function must not depend on either.
+    """
+
+    @with_toplevel_context
+    @wraps(f)
+    def internal(*args: Param.args, **kwargs: Param.kwargs) -> RT:
+        module = ir.Module.create()
+        with ir.InsertionPoint(module.body):
+            results = f(module, *args, **kwargs)
+        return results
+
+    return internal
+
+
+def call_with_toplevel_context(f: Callable[[], RT]) -> Callable[[], RT]:
+    """Immediately call the function in a fresh MLIR context."""
+    decorated = with_toplevel_context(f)
+    decorated()
+    return decorated
+
+
+def call_with_toplevel_context_create_module(
+    f: Callable[[ir.Module], RT],
+) -> Callable[[], RT]:
+    """Immediately call the function in a fresh MLIR context and give it a module.
+
+    The decorated function will receive, as its only argument, a fresh MLIR module. The
+    context manager is set up to insert operations into this module.
+    """
+    decorated = with_toplevel_context_create_module(f)
+    decorated()
+    return decorated
diff --git a/mlir/test/python/utils.py b/mlir/test/python/utils.py
new file mode 100644
index 0000000..8435fdd
--- /dev/null
+++ b/mlir/test/python/utils.py
@@ -0,0 +1,58 @@
+# RUN: %python %s | FileCheck %s
+
+import unittest
+
+from mlir import ir
+from mlir.dialects import arith, builtin
+from mlir.extras import types as T
+from mlir.utils import (
+    call_with_toplevel_context_create_module,
+    caller_mlir_context,
+    using_mlir_context,
+)
+
+
+class TestRequiredContext(unittest.TestCase):
+    def test_shared_context(self):
+        """Test that the context is reused, so values can be passed/returned between functions."""
+
+        @using_mlir_context()
+        def create_add(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
+            return arith.AddFOp(
+                lhs, rhs, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
+            ).result
+
+        @using_mlir_context()
+        def multiple_adds(lhs: ir.Value, rhs: ir.Value) -> ir.Value:
+            return create_add(create_add(lhs, rhs), create_add(lhs, rhs))
+
+        @call_with_toplevel_context_create_module
+        def _(module) -> None:
+            c = arith.ConstantOp(value=42.42, result=ir.F32Type.get()).result
+            multiple_adds(c, c)
+
+            # CHECK: constant
+            # CHECK-NEXT: arith.addf
+            # CHECK-NEXT: arith.addf
+            # CHECK-NEXT: arith.addf
+            print(module)
+
+    def test_unregistered_op_asserts(self):
+        """Confirm that using_mlir_context fails if an operation is still not registered."""
+        with self.assertRaises(AssertionError), using_mlir_context(
+            required_extension_operations=["func.fake_extension_op"],
+            registration_funcs=[],
+        ):
+            pass
+
+    def test_required_op_asserts(self):
+        """Confirm that caller_mlir_context fails if an operation is still not registered."""
+        with self.assertRaises(AssertionError), caller_mlir_context(
+            required_extension_operations=["func.fake_extension_op"],
+            registration_funcs=[],
+        ):
+            pass
+
+
+if __name__ == "__main__":
+    unittest.main()