| # RUN: %PYTHON %s |
| """ |
| This script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN. |
| Tests can be run using pytest: |
| ```bash |
| python3.13t -mpytest -vvv multithreaded_tests.py |
| ``` |
| |
| IMPORTANT. Running tests are not checking the correctness, but just the execution of the tests in multi-threaded context |
| and passing if no warnings reported by TSAN and failing otherwise. |
| |
| |
| Details on the generated tests and execution: |
| 1) Multi-threaded execution: all generated tests are executed independently by |
| a pool of threads, running each test multiple times, see @multi_threaded for details |
| |
| 2) Tests generation: we use existing tests: test/python/ir/*.py, |
| test/python/dialects/*.py, etc to generate multi-threaded tests. |
| In details, we perform the following: |
| a) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`. |
| b) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method. |
| c) for each test file, we copy and modify it: test/python/ir/affine_expr.py -> /tmp/ir/affine_expr.py. |
| In order to import the test file as python module, we remove all executing functions, like |
| `@run` or `run(testMethod)`. See `copy_and_update` and `add_existing_tests` methods for details. |
| |
| |
| Observed warnings reported by TSAN. |
| |
| CPython and free-threading known data-races: |
| 1) ctypes related races: https://github.com/python/cpython/issues/127945 |
| 2) LLVM related data-races, llvm::raw_ostream is not thread-safe |
| - mlir pass manager |
| - dialects/transform_interpreter.py |
| - ir/diagnostic_handler.py |
| - ir/module.py |
| 3) Dialect gpu module-to-binary method is unsafe |
| """ |
| import concurrent.futures |
| import gc |
| import importlib.util |
| import os |
| import sys |
| import threading |
| import tempfile |
| import unittest |
| |
| from contextlib import contextmanager |
| from functools import partial |
| from pathlib import Path |
| from typing import Optional, List |
| |
| import mlir.dialects.arith as arith |
| from mlir.dialects import transform |
| from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint |
| |
| |
| def import_from_path(module_name: str, file_path: Path): |
| spec = importlib.util.spec_from_file_location(module_name, file_path) |
| module = importlib.util.module_from_spec(spec) |
| sys.modules[module_name] = module |
| spec.loader.exec_module(module) |
| return module |
| |
| |
| def copy_and_update(src_filepath: Path, dst_filepath: Path): |
| # We should remove all calls like `run(testMethod)` |
| with open(src_filepath, "r") as reader, open(dst_filepath, "w") as writer: |
| while True: |
| src_line = reader.readline() |
| if len(src_line) == 0: |
| break |
| skip_lines = [ |
| "run(", |
| "@run", |
| "@constructAndPrintInModule", |
| "run_apply_patterns(", |
| "@run_apply_patterns", |
| "@test_in_context", |
| "@construct_and_print_in_module", |
| ] |
| if any(src_line.startswith(line) for line in skip_lines): |
| continue |
| writer.write(src_line) |
| |
| |
| # Helper run functions |
| def run(f): |
| f() |
| |
| |
| def run_with_context_and_location(f): |
| print("\nTEST:", f.__name__) |
| with Context(), Location.unknown(): |
| f() |
| return f |
| |
| |
| def run_with_insertion_point(f): |
| print("\nTEST:", f.__name__) |
| with Context() as ctx, Location.unknown(): |
| module = Module.create() |
| with InsertionPoint(module.body): |
| f(ctx) |
| print(module) |
| |
| |
| def run_with_insertion_point_v2(f): |
| print("\nTEST:", f.__name__) |
| with Context(), Location.unknown(): |
| module = Module.create() |
| with InsertionPoint(module.body): |
| f() |
| print(module) |
| return f |
| |
| |
| def run_with_insertion_point_v3(f): |
| with Context(), Location.unknown(): |
| module = Module.create() |
| with InsertionPoint(module.body): |
| print("\nTEST:", f.__name__) |
| f(module) |
| print(module) |
| return f |
| |
| |
| def run_with_insertion_point_v4(f): |
| print("\nTEST:", f.__name__) |
| with Context() as ctx, Location.unknown(): |
| ctx.allow_unregistered_dialects = True |
| module = Module.create() |
| with InsertionPoint(module.body): |
| f() |
| return f |
| |
| |
| def run_apply_patterns(f): |
| with Context(), Location.unknown(): |
| module = Module.create() |
| with InsertionPoint(module.body): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.AnyOpType.get(), |
| ) |
| with InsertionPoint(sequence.body): |
| apply = transform.ApplyPatternsOp(sequence.bodyTarget) |
| with InsertionPoint(apply.patterns): |
| f() |
| transform.YieldOp() |
| print("\nTEST:", f.__name__) |
| print(module) |
| return f |
| |
| |
| def run_transform_tensor_ext(f): |
| print("\nTEST:", f.__name__) |
| with Context(), Location.unknown(): |
| module = Module.create() |
| with InsertionPoint(module.body): |
| sequence = transform.SequenceOp( |
| transform.FailurePropagationMode.Propagate, |
| [], |
| transform.AnyOpType.get(), |
| ) |
| with InsertionPoint(sequence.body): |
| f(sequence.bodyTarget) |
| transform.YieldOp() |
| print(module) |
| return f |
| |
| |
| def run_transform_structured_ext(f): |
| with Context(), Location.unknown(): |
| module = Module.create() |
| with InsertionPoint(module.body): |
| print("\nTEST:", f.__name__) |
| f() |
| module.operation.verify() |
| print(module) |
| return f |
| |
| |
| def run_construct_and_print_in_module(f): |
| print("\nTEST:", f.__name__) |
| with Context(), Location.unknown(): |
| module = Module.create() |
| with InsertionPoint(module.body): |
| module = f(module) |
| if module is not None: |
| print(module) |
| return f |
| |
| |
| TEST_MODULES = [ |
| ("execution_engine", run), |
| ("pass_manager", run), |
| ("dialects/affine", run_with_insertion_point_v2), |
| ("dialects/func", run_with_insertion_point_v2), |
| ("dialects/arith_dialect", run), |
| ("dialects/arith_llvm", run), |
| ("dialects/async_dialect", run), |
| ("dialects/builtin", run), |
| ("dialects/cf", run_with_insertion_point_v4), |
| ("dialects/complex_dialect", run), |
| ("dialects/func", run_with_insertion_point_v2), |
| ("dialects/index_dialect", run_with_insertion_point), |
| ("dialects/llvm", run_with_insertion_point_v2), |
| ("dialects/math_dialect", run), |
| ("dialects/memref", run), |
| ("dialects/ml_program", run_with_insertion_point_v2), |
| ("dialects/nvgpu", run_with_insertion_point_v2), |
| ("dialects/nvvm", run_with_insertion_point_v2), |
| ("dialects/ods_helpers", run), |
| ("dialects/openmp_ops", run_with_insertion_point_v2), |
| ("dialects/pdl_ops", run_with_insertion_point_v2), |
| # ("dialects/python_test", run), # TODO: Need to pass pybind11 or nanobind argv |
| ("dialects/quant", run), |
| ("dialects/rocdl", run_with_insertion_point_v2), |
| ("dialects/scf", run_with_insertion_point_v2), |
| ("dialects/shape", run), |
| ("dialects/spirv_dialect", run), |
| ("dialects/tensor", run), |
| # ("dialects/tosa", ), # Nothing to test |
| ("dialects/transform_bufferization_ext", run_with_insertion_point_v2), |
| # ("dialects/transform_extras", ), # Needs a more complicated execution schema |
| ("dialects/transform_gpu_ext", run_transform_tensor_ext), |
| ( |
| "dialects/transform_interpreter", |
| run_with_context_and_location, |
| ["print_", "transform_options", "failed", "include"], |
| ), |
| ( |
| "dialects/transform_loop_ext", |
| run_with_insertion_point_v2, |
| ["loopOutline"], |
| ), |
| ("dialects/transform_memref_ext", run_with_insertion_point_v2), |
| ("dialects/transform_nvgpu_ext", run_with_insertion_point_v2), |
| ("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext), |
| ("dialects/transform_structured_ext", run_transform_structured_ext), |
| ("dialects/transform_tensor_ext", run_transform_tensor_ext), |
| ( |
| "dialects/transform_vector_ext", |
| run_apply_patterns, |
| ["configurable_patterns"], |
| ), |
| ("dialects/transform", run_with_insertion_point_v3), |
| ("dialects/vector", run_with_context_and_location), |
| ("dialects/gpu/dialect", run_with_context_and_location), |
| ("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location), |
| ("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location), |
| ("dialects/linalg/ops", run), |
| # TO ADD: No proper tests in this dialects/linalg/opsdsl/* |
| # ("dialects/linalg/opsdsl/*", ...), |
| ("dialects/sparse_tensor/dialect", run), |
| ("dialects/sparse_tensor/passes", run), |
| ("integration/dialects/pdl", run_construct_and_print_in_module), |
| ("integration/dialects/transform", run_construct_and_print_in_module), |
| ("integration/dialects/linalg/opsrun", run), |
| ("ir/affine_expr", run), |
| ("ir/affine_map", run), |
| ("ir/array_attributes", run), |
| ("ir/attributes", run), |
| ("ir/blocks", run), |
| ("ir/builtin_types", run), |
| ("ir/context_managers", run), |
| ("ir/debug", run), |
| ("ir/diagnostic_handler", run), |
| ("ir/dialects", run), |
| ("ir/exception", run), |
| ("ir/insertion_point", run), |
| ("ir/integer_set", run), |
| ("ir/location", run), |
| ("ir/module", run), |
| ("ir/operation", run), |
| ("ir/symbol_table", run), |
| ("ir/value", run), |
| ] |
| |
| TESTS_TO_SKIP = [ |
| "test_execution_engine__testNanoTime_multi_threaded", # testNanoTime can't run in multiple threads, even with GIL |
| "test_execution_engine__testSharedLibLoad_multi_threaded", # testSharedLibLoad can't run in multiple threads, even with GIL |
| "test_dialects_arith_dialect__testArithValue_multi_threaded", # RuntimeError: Value caster is already registered: <class 'dialects/arith_dialect.testArithValue.<locals>.ArithValue'>, even with GIL |
| "test_ir_dialects__testAppendPrefixSearchPath_multi_threaded", # PyGlobals::setDialectSearchPrefixes is not thread-safe, even with GIL. Strange usage of static PyGlobals vs python exposed _cext.globals |
| "test_ir_value__testValueCasters_multi_threaded", # RuntimeError: Value caster is already registered: <function testValueCasters.<locals>.dont_cast_int, even with GIL |
| # tests indirectly calling thread-unsafe llvm::raw_ostream |
| "test_execution_engine__testInvalidModule_multi_threaded", # mlirExecutionEngineCreate calls thread-unsafe llvm::raw_ostream |
| "test_pass_manager__testPrintIrAfterAll_multi_threaded", # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream |
| "test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded", # IRPrinterInstrumentation::runBeforePass calls thread-unsafe llvm::raw_ostream |
| "test_pass_manager__testPrintIrLargeLimitElements_multi_threaded", # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream |
| "test_pass_manager__testPrintIrTree_multi_threaded", # IRPrinterInstrumentation::runAfterPass calls thread-unsafe llvm::raw_ostream |
| "test_pass_manager__testRunPipeline_multi_threaded", # PrintOpStatsPass::printSummary calls thread-unsafe llvm::raw_ostream |
| "test_dialects_transform_interpreter__include_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream |
| "test_dialects_transform_interpreter__transform_options_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) calls thread-unsafe llvm::raw_ostream |
| "test_dialects_transform_interpreter__print_self_multi_threaded", # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) call thread-unsafe llvm::raw_ostream |
| "test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded", # mlirEmitError calls thread-unsafe llvm::raw_ostream |
| "test_ir_module__testParseSuccess_multi_threaded", # mlirOperationDump calls thread-unsafe llvm::raw_ostream |
| # False-positive TSAN detected race in llvm::RuntimeDyldELF::registerEHFrames() |
| # Details: https://github.com/llvm/llvm-project/pull/107103/files#r1905726947 |
| "test_execution_engine__testCapsule_multi_threaded", |
| "test_execution_engine__testDumpToObjectFile_multi_threaded", |
| ] |
| |
| TESTS_TO_XFAIL = [ |
| # execution_engine tests: |
| # - ctypes related data-races: https://github.com/python/cpython/issues/127945 |
| "test_execution_engine__testBF16Memref_multi_threaded", |
| "test_execution_engine__testBasicCallback_multi_threaded", |
| "test_execution_engine__testComplexMemrefAdd_multi_threaded", |
| "test_execution_engine__testComplexUnrankedMemrefAdd_multi_threaded", |
| "test_execution_engine__testDynamicMemrefAdd2D_multi_threaded", |
| "test_execution_engine__testF16MemrefAdd_multi_threaded", |
| "test_execution_engine__testF8E5M2Memref_multi_threaded", |
| "test_execution_engine__testInvokeFloatAdd_multi_threaded", |
| "test_execution_engine__testInvokeVoid_multi_threaded", # a ctypes race |
| "test_execution_engine__testMemrefAdd_multi_threaded", |
| "test_execution_engine__testRankedMemRefCallback_multi_threaded", |
| "test_execution_engine__testRankedMemRefWithOffsetCallback_multi_threaded", |
| "test_execution_engine__testUnrankedMemRefCallback_multi_threaded", |
| "test_execution_engine__testUnrankedMemRefWithOffsetCallback_multi_threaded", |
| # dialects tests |
| "test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded", # Related to ctypes data races |
| "test_dialects_transform_interpreter__print_other_multi_threaded", # Fatal Python error: Aborted or mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe |
| "test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded", # Due to global llvm-project/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp::GCNTrackers variable mutation |
| "test_dialects_gpu_module-to-binary-nvvm__testGPUToASMBin_multi_threaded", |
| "test_dialects_gpu_module-to-binary-nvvm__testGPUToLLVMBin_multi_threaded", |
| "test_dialects_gpu_module-to-binary-rocdl__testGPUToLLVMBin_multi_threaded", |
| # integration tests |
| "test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded", # Related to ctypes data races |
| "test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded", # Related to ctypes data races |
| "test_integration_dialects_linalg_opsrun__test_fill_builtin_multi_threaded", # ctypes |
| "test_integration_dialects_linalg_opsrun__test_fill_generic_multi_threaded", # ctypes |
| "test_integration_dialects_linalg_opsrun__test_fill_rng_builtin_multi_threaded", # ctypes |
| "test_integration_dialects_linalg_opsrun__test_fill_rng_generic_multi_threaded", # ctypes |
| "test_integration_dialects_linalg_opsrun__test_max_pooling_builtin_multi_threaded", # ctypes |
| "test_integration_dialects_linalg_opsrun__test_max_pooling_generic_multi_threaded", # ctypes |
| "test_integration_dialects_linalg_opsrun__test_min_pooling_builtin_multi_threaded", # ctypes |
| "test_integration_dialects_linalg_opsrun__test_min_pooling_generic_multi_threaded", # ctypes |
| ] |
| |
| |
| def add_existing_tests(test_modules, test_prefix: str = "_original_test"): |
| def decorator(test_cls): |
| this_folder = Path(__file__).parent.absolute() |
| test_cls.output_folder = tempfile.TemporaryDirectory() |
| output_folder = Path(test_cls.output_folder.name) |
| |
| for test_mod_info in test_modules: |
| assert isinstance(test_mod_info, tuple) and len(test_mod_info) in (2, 3) |
| if len(test_mod_info) == 2: |
| test_module_name, exec_fn = test_mod_info |
| test_pattern = None |
| else: |
| test_module_name, exec_fn, test_pattern = test_mod_info |
| |
| src_filepath = this_folder / f"{test_module_name}.py" |
| dst_filepath = (output_folder / f"{test_module_name}.py").absolute() |
| if not dst_filepath.parent.exists(): |
| dst_filepath.parent.mkdir(parents=True) |
| copy_and_update(src_filepath, dst_filepath) |
| test_mod = import_from_path(test_module_name, dst_filepath) |
| for attr_name in dir(test_mod): |
| is_test_fn = test_pattern is None and attr_name.startswith("test") |
| is_test_fn |= test_pattern is not None and any( |
| [p in attr_name for p in test_pattern] |
| ) |
| if is_test_fn: |
| obj = getattr(test_mod, attr_name) |
| if callable(obj): |
| test_name = f"{test_prefix}_{test_module_name.replace('/', '_')}__{attr_name}" |
| |
| def wrapped_test_fn( |
| self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs |
| ): |
| __exec_fn__(__test_fn__) |
| |
| setattr(test_cls, test_name, wrapped_test_fn) |
| return test_cls |
| |
| return decorator |
| |
| |
| @contextmanager |
| def _capture_output(fp): |
| # Inspired from jax test_utils.py capture_stderr method |
| # ``None`` means nothing has not been captured yet. |
| captured = None |
| |
| def get_output() -> str: |
| if captured is None: |
| raise ValueError("get_output() called while the context is active.") |
| return captured |
| |
| with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as f: |
| original_fd = os.dup(fp.fileno()) |
| os.dup2(f.fileno(), fp.fileno()) |
| try: |
| yield get_output |
| finally: |
| # Python also has its own buffers, make sure everything is flushed. |
| fp.flush() |
| os.fsync(fp.fileno()) |
| f.seek(0) |
| captured = f.read() |
| os.dup2(original_fd, fp.fileno()) |
| |
| |
| capture_stdout = partial(_capture_output, sys.stdout) |
| capture_stderr = partial(_capture_output, sys.stderr) |
| |
| |
| def multi_threaded( |
| num_workers: int, |
| num_runs: int = 5, |
| skip_tests: Optional[List[str]] = None, |
| xfail_tests: Optional[List[str]] = None, |
| test_prefix: str = "_original_test", |
| multithreaded_test_postfix: str = "_multi_threaded", |
| ): |
| """Decorator that runs a test in a multi-threaded environment.""" |
| |
| def decorator(test_cls): |
| for name, test_fn in test_cls.__dict__.copy().items(): |
| if not (name.startswith(test_prefix) and callable(test_fn)): |
| continue |
| |
| name = f"test{name[len(test_prefix):]}" |
| if skip_tests is not None: |
| if any( |
| test_name.replace(multithreaded_test_postfix, "") in name |
| for test_name in skip_tests |
| ): |
| continue |
| |
| def multi_threaded_test_fn(self, *args, __test_fn__=test_fn, **kwargs): |
| with capture_stdout(), capture_stderr() as get_output: |
| barrier = threading.Barrier(num_workers) |
| |
| def closure(): |
| barrier.wait() |
| for _ in range(num_runs): |
| __test_fn__(self, *args, **kwargs) |
| |
| with concurrent.futures.ThreadPoolExecutor( |
| max_workers=num_workers |
| ) as executor: |
| futures = [] |
| for _ in range(num_workers): |
| futures.append(executor.submit(closure)) |
| # We should call future.result() to re-raise an exception if test has |
| # failed |
| assert len(list(f.result() for f in futures)) == num_workers |
| |
| gc.collect() |
| assert Context._get_live_count() == 0 |
| |
| captured = get_output() |
| if len(captured) > 0 and "ThreadSanitizer" in captured: |
| raise RuntimeError( |
| f"ThreadSanitizer reported warnings:\n{captured}" |
| ) |
| |
| test_new_name = f"{name}{multithreaded_test_postfix}" |
| if xfail_tests is not None and test_new_name in xfail_tests: |
| multi_threaded_test_fn = unittest.expectedFailure( |
| multi_threaded_test_fn |
| ) |
| |
| setattr(test_cls, test_new_name, multi_threaded_test_fn) |
| |
| return test_cls |
| |
| return decorator |
| |
| |
| @multi_threaded( |
| num_workers=10, |
| num_runs=20, |
| skip_tests=TESTS_TO_SKIP, |
| xfail_tests=TESTS_TO_XFAIL, |
| ) |
| @add_existing_tests(test_modules=TEST_MODULES, test_prefix="_original_test") |
| class TestAllMultiThreaded(unittest.TestCase): |
| @classmethod |
| def tearDownClass(cls): |
| if hasattr(cls, "output_folder"): |
| cls.output_folder.cleanup() |
| |
| def _original_test_create_context(self): |
| with Context() as ctx: |
| print(ctx._get_live_count()) |
| print(ctx._get_live_module_count()) |
| print(ctx._get_live_operation_count()) |
| print(ctx._get_live_operation_objects()) |
| print(ctx._get_context_again() is ctx) |
| print(ctx._clear_live_operations()) |
| |
| def _original_test_create_module_with_consts(self): |
| py_values = [123, 234, 345] |
| with Context() as ctx: |
| module = Module.create(loc=Location.file("foo.txt", 0, 0)) |
| |
| dtype = IntegerType.get_signless(64) |
| with InsertionPoint(module.body), Location.name("a"): |
| arith.constant(dtype, py_values[0]) |
| |
| with InsertionPoint(module.body), Location.name("b"): |
| arith.constant(dtype, py_values[1]) |
| |
| with InsertionPoint(module.body), Location.name("c"): |
| arith.constant(dtype, py_values[2]) |
| |
| |
| if __name__ == "__main__": |
| # Do not run the tests on CPython with GIL |
| if hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled(): |
| unittest.main() |