blob: 6e1a66834687291e94eb7c164fe187c66511e929 [file] [log] [blame]
# RUN: %PYTHON %s
"""
This script generates multi-threaded tests to check free-threading mode using CPython compiled with TSAN.
Tests can be run using pytest:
```bash
python3.13t -mpytest -vvv multithreaded_tests.py
```
IMPORTANT: these tests do not check functional correctness. They only run the existing tests in a
multi-threaded context; a test passes if TSAN reports no warnings and fails otherwise.
Details on the generated tests and execution:
1) Multi-threaded execution: all generated tests are executed independently by
a pool of threads, running each test multiple times, see @multi_threaded for details
2) Tests generation: we use existing tests: test/python/ir/*.py,
test/python/dialects/*.py, etc to generate multi-threaded tests.
In detail, we perform the following:
a) we define a list of source tests to be used to generate multi-threaded tests, see `TEST_MODULES`.
b) we define `TestAllMultiThreaded` class and add existing tests to the class. See `add_existing_tests` method.
c) for each test file, we copy and modify it: test/python/ir/affine_expr.py -> /tmp/ir/affine_expr.py.
In order to import the test file as python module, we remove all executing functions, like
`@run` or `run(testMethod)`. See `copy_and_update` and `add_existing_tests` methods for details.
Observed warnings reported by TSAN.
CPython and free-threading known data-races:
1) ctypes related races: https://github.com/python/cpython/issues/127945
2) LLVM related data-races, llvm::raw_ostream is not thread-safe
- mlir pass manager
- dialects/transform_interpreter.py
- ir/diagnostic_handler.py
- ir/module.py
3) Dialect gpu module-to-binary method is unsafe
"""
import concurrent.futures
import gc
import importlib.util
import os
import sys
import threading
import tempfile
import unittest
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Optional, List
import mlir.dialects.arith as arith
from mlir.dialects import transform
from mlir.ir import Context, Location, Module, IntegerType, InsertionPoint
def import_from_path(module_name: str, file_path: Path):
    """Import the Python source file at ``file_path`` as module ``module_name``.

    The module object is registered in ``sys.modules`` under ``module_name``
    before its code is executed, then returned once execution succeeded.

    Args:
        module_name: Key to use in ``sys.modules`` (and the module's name).
        file_path: Location of the source file to execute.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for ``file_path``.
        Exception: Whatever the module's top-level code raises. In that case
            the partially executed module is removed from ``sys.modules``.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # spec_from_file_location() returns None when it cannot pick a loader for
    # the path; fail with a clear message instead of an AttributeError later.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot create an import spec for module {module_name!r} "
            f"from {file_path}",
            name=module_name,
            path=str(file_path),
        )
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialized module visible to later imports.
        sys.modules.pop(module_name, None)
        raise
    return module
def copy_and_update(src_filepath: Path, dst_filepath: Path):
    """Copy a test source file, dropping the lines that execute tests.

    Every line of ``src_filepath`` is written to ``dst_filepath`` unchanged,
    except lines that *start* (at column 0) with one of the runner prefixes
    below, e.g. ``@run`` or ``run(testMethod)``. Indented lines are never
    dropped because the prefix check is done on the raw line.

    Args:
        src_filepath: Existing test file to read.
        dst_filepath: File to create or overwrite with the filtered content.
    """
    # Prefixes of top-level lines that would run a test while the copied file
    # is being imported as a module.
    skip_prefixes = (
        "run(",
        "@run",
        "@constructAndPrintInModule",
        "run_apply_patterns(",
        "@run_apply_patterns",
        "@test_in_context",
        "@construct_and_print_in_module",
    )
    # Python sources are UTF-8 by default; do not depend on the locale encoding.
    with open(src_filepath, "r", encoding="utf-8") as reader, open(
        dst_filepath, "w", encoding="utf-8"
    ) as writer:
        writer.writelines(
            src_line for src_line in reader if not src_line.startswith(skip_prefixes)
        )
# Helper run functions
def run(f):
    """Call ``f()`` with no extra setup and return ``f``.

    Returning ``f`` keeps this runner consistent with the other ``run_*``
    helpers in this file; callers that ignore the result are unaffected.
    """
    f()
    return f
def run_with_context_and_location(f):
    """Print a ``TEST: <name>`` banner, then call ``f()`` inside a new
    ``Context`` with an unknown ``Location`` active. Returns ``f``.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            f()
    return f
def run_with_insertion_point(f):
    """Run ``f(ctx)`` with the insertion point set to a new module's body.

    Prints a ``TEST: <name>`` banner, creates a ``Context`` (passed to ``f``)
    with an unknown ``Location``, creates a module, calls ``f(ctx)`` while an
    ``InsertionPoint`` on the module body is active, and finally prints the
    module.

    Returns:
        ``f`` itself, consistent with the other ``run_*`` helpers in this file;
        callers that ignore the result are unaffected.
    """
    print("\nTEST:", f.__name__)
    with Context() as ctx, Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f(ctx)
        print(module)
    return f
def run_with_insertion_point_v2(f):
    """Like ``run_with_insertion_point`` but calls ``f()`` without arguments.

    Prints the banner, builds a module in a new ``Context`` with an unknown
    ``Location``, calls ``f()`` with the insertion point on the module body,
    prints the module and returns ``f``.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                f()
            print(mod)
    return f
def run_with_insertion_point_v3(f):
    """Variant that hands the freshly created module to ``f``.

    Here the ``TEST: <name>`` banner is printed only after the context, the
    module and the insertion point have been set up, right before ``f(module)``
    is called. The module is printed afterwards and ``f`` is returned.
    """
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                print("\nTEST:", f.__name__)
                f(mod)
            print(mod)
    return f
def run_with_insertion_point_v4(f):
    """Variant that allows unregistered dialects and does not print the module.

    Prints the banner, sets ``allow_unregistered_dialects`` on a new
    ``Context``, and calls ``f()`` with the insertion point on the body of a
    new module. Returns ``f``.
    """
    print("\nTEST:", f.__name__)
    with Context() as context:
        with Location.unknown():
            context.allow_unregistered_dialects = True
            mod = Module.create()
            with InsertionPoint(mod.body):
                f()
    return f
def run_apply_patterns(f):
    """Call ``f()`` inside the pattern region of a ``transform.ApplyPatternsOp``.

    Builds, in a new ``Context``/unknown ``Location``, a module holding a
    ``transform.SequenceOp`` whose body contains an ``ApplyPatternsOp`` on the
    sequence's ``bodyTarget`` followed by a ``transform.YieldOp``. ``f()`` runs
    with the insertion point on ``apply.patterns``. The banner and the module
    are printed after construction. Returns ``f``.
    """
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                seq_op = transform.SequenceOp(
                    transform.FailurePropagationMode.Propagate,
                    [],
                    transform.AnyOpType.get(),
                )
                with InsertionPoint(seq_op.body):
                    patterns_op = transform.ApplyPatternsOp(seq_op.bodyTarget)
                    with InsertionPoint(patterns_op.patterns):
                        f()
                    # Terminates the sequence body, after the apply op.
                    transform.YieldOp()
            print("\nTEST:", f.__name__)
            print(mod)
    return f
def run_transform_tensor_ext(f):
    """Call ``f(target)`` inside the body of a ``transform.SequenceOp``.

    Prints the banner, then builds a module containing a sequence op in a new
    ``Context``/unknown ``Location``. ``f`` receives the sequence's
    ``bodyTarget`` and runs with the insertion point on the sequence body; a
    ``transform.YieldOp`` is appended right after. Prints the module and
    returns ``f``.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                seq_op = transform.SequenceOp(
                    transform.FailurePropagationMode.Propagate,
                    [],
                    transform.AnyOpType.get(),
                )
                with InsertionPoint(seq_op.body):
                    f(seq_op.bodyTarget)
                    transform.YieldOp()
            print(mod)
    return f
def run_transform_structured_ext(f):
    """Call ``f()`` in a new module, then verify and print that module.

    The banner is printed once the insertion point on the module body is
    active, immediately before ``f()``. After ``f`` returns,
    ``module.operation.verify()`` is called and the module is printed.
    Returns ``f``.
    """
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                print("\nTEST:", f.__name__)
                f()
            mod.operation.verify()
            print(mod)
    return f
def run_construct_and_print_in_module(f):
    """Call ``f(module)`` and print whatever module it hands back.

    Prints the banner, creates a module in a new ``Context``/unknown
    ``Location`` and calls ``f(module)`` with the insertion point on the module
    body. The value returned by ``f`` replaces the module; it is printed unless
    it is ``None``. Returns ``f``.
    """
    print("\nTEST:", f.__name__)
    with Context():
        with Location.unknown():
            mod = Module.create()
            with InsertionPoint(mod.body):
                # ``f`` may return a different module, or None to suppress printing.
                mod = f(mod)
            if mod is not None:
                print(mod)
    return f
# Source test modules used to generate the multi-threaded tests.
# Each entry is a tuple consumed by `add_existing_tests`:
#   (module path relative to this file, without ".py",
#    runner function used to execute each test function,
#    optional list of substrings selecting the test functions by name)
# Without the third element, functions whose name starts with "test" are used.
# Each module is listed once: a repeated entry would only copy and import the
# same file again and re-register the same test names.
TEST_MODULES = [
    ("execution_engine", run),
    ("pass_manager", run),
    ("dialects/affine", run_with_insertion_point_v2),
    ("dialects/func", run_with_insertion_point_v2),
    ("dialects/arith_dialect", run),
    ("dialects/arith_llvm", run),
    ("dialects/async_dialect", run),
    ("dialects/builtin", run),
    ("dialects/cf", run_with_insertion_point_v4),
    ("dialects/complex_dialect", run),
    ("dialects/index_dialect", run_with_insertion_point),
    ("dialects/llvm", run_with_insertion_point_v2),
    ("dialects/math_dialect", run),
    ("dialects/memref", run),
    ("dialects/ml_program", run_with_insertion_point_v2),
    ("dialects/nvgpu", run_with_insertion_point_v2),
    ("dialects/nvvm", run_with_insertion_point_v2),
    ("dialects/ods_helpers", run),
    ("dialects/openmp_ops", run_with_insertion_point_v2),
    ("dialects/pdl_ops", run_with_insertion_point_v2),
    # ("dialects/python_test", run),  # TODO: Need to pass pybind11 or nanobind argv
    ("dialects/quant", run),
    ("dialects/rocdl", run_with_insertion_point_v2),
    ("dialects/scf", run_with_insertion_point_v2),
    ("dialects/shape", run),
    ("dialects/spirv_dialect", run),
    ("dialects/tensor", run),
    # ("dialects/tosa", ),  # Nothing to test
    ("dialects/transform_bufferization_ext", run_with_insertion_point_v2),
    # ("dialects/transform_extras", ),  # Needs a more complicated execution schema
    ("dialects/transform_gpu_ext", run_transform_tensor_ext),
    (
        "dialects/transform_interpreter",
        run_with_context_and_location,
        ["print_", "transform_options", "failed", "include"],
    ),
    (
        "dialects/transform_loop_ext",
        run_with_insertion_point_v2,
        ["loopOutline"],
    ),
    ("dialects/transform_memref_ext", run_with_insertion_point_v2),
    ("dialects/transform_nvgpu_ext", run_with_insertion_point_v2),
    ("dialects/transform_sparse_tensor_ext", run_transform_tensor_ext),
    ("dialects/transform_structured_ext", run_transform_structured_ext),
    ("dialects/transform_tensor_ext", run_transform_tensor_ext),
    (
        "dialects/transform_vector_ext",
        run_apply_patterns,
        ["configurable_patterns"],
    ),
    ("dialects/transform", run_with_insertion_point_v3),
    ("dialects/vector", run_with_context_and_location),
    ("dialects/gpu/dialect", run_with_context_and_location),
    ("dialects/gpu/module-to-binary-nvvm", run_with_context_and_location),
    ("dialects/gpu/module-to-binary-rocdl", run_with_context_and_location),
    ("dialects/linalg/ops", run),
    # TO ADD: No proper tests in this dialects/linalg/opsdsl/*
    # ("dialects/linalg/opsdsl/*", ...),
    ("dialects/sparse_tensor/dialect", run),
    ("dialects/sparse_tensor/passes", run),
    ("integration/dialects/pdl", run_construct_and_print_in_module),
    ("integration/dialects/transform", run_construct_and_print_in_module),
    ("integration/dialects/linalg/opsrun", run),
    ("ir/affine_expr", run),
    ("ir/affine_map", run),
    ("ir/array_attributes", run),
    ("ir/attributes", run),
    ("ir/blocks", run),
    ("ir/builtin_types", run),
    ("ir/context_managers", run),
    ("ir/debug", run),
    ("ir/diagnostic_handler", run),
    ("ir/dialects", run),
    ("ir/exception", run),
    ("ir/insertion_point", run),
    ("ir/integer_set", run),
    ("ir/location", run),
    ("ir/module", run),
    ("ir/operation", run),
    ("ir/symbol_table", run),
    ("ir/value", run),
]
# Generated tests that `multi_threaded` does not create at all. An entry
# matches when its name, minus the "_multi_threaded" postfix, is a substring of
# the generated test name.
TESTS_TO_SKIP = [
    # testNanoTime can't run in multiple threads, even with GIL
    "test_execution_engine__testNanoTime_multi_threaded",
    # testSharedLibLoad can't run in multiple threads, even with GIL
    "test_execution_engine__testSharedLibLoad_multi_threaded",
    # RuntimeError: Value caster is already registered:
    # <class 'dialects/arith_dialect.testArithValue.<locals>.ArithValue'>, even with GIL
    "test_dialects_arith_dialect__testArithValue_multi_threaded",
    # PyGlobals::setDialectSearchPrefixes is not thread-safe, even with GIL.
    # Strange usage of static PyGlobals vs python exposed _cext.globals
    "test_ir_dialects__testAppendPrefixSearchPath_multi_threaded",
    # RuntimeError: Value caster is already registered:
    # <function testValueCasters.<locals>.dont_cast_int, even with GIL
    "test_ir_value__testValueCasters_multi_threaded",
    #
    # Tests indirectly calling thread-unsafe llvm::raw_ostream; the comment
    # above each entry names the function that calls it.
    #
    # mlirExecutionEngineCreate
    "test_execution_engine__testInvalidModule_multi_threaded",
    # IRPrinterInstrumentation::runAfterPass
    "test_pass_manager__testPrintIrAfterAll_multi_threaded",
    # IRPrinterInstrumentation::runBeforePass
    "test_pass_manager__testPrintIrBeforeAndAfterAll_multi_threaded",
    # IRPrinterInstrumentation::runAfterPass
    "test_pass_manager__testPrintIrLargeLimitElements_multi_threaded",
    # IRPrinterInstrumentation::runAfterPass
    "test_pass_manager__testPrintIrTree_multi_threaded",
    # PrintOpStatsPass::printSummary
    "test_pass_manager__testRunPipeline_multi_threaded",
    # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...)
    "test_dialects_transform_interpreter__include_multi_threaded",
    # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...)
    "test_dialects_transform_interpreter__transform_options_multi_threaded",
    # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...)
    "test_dialects_transform_interpreter__print_self_multi_threaded",
    # mlirEmitError
    "test_ir_diagnostic_handler__testDiagnosticCallbackException_multi_threaded",
    # mlirOperationDump
    "test_ir_module__testParseSuccess_multi_threaded",
    #
    # False-positive TSAN detected race in llvm::RuntimeDyldELF::registerEHFrames()
    # Details: https://github.com/llvm/llvm-project/pull/107103/files#r1905726947
    #
    "test_execution_engine__testCapsule_multi_threaded",
    "test_execution_engine__testDumpToObjectFile_multi_threaded",
]
# Generated tests that `multi_threaded` wraps with `unittest.expectedFailure`
# (matched by exact generated test name).
TESTS_TO_XFAIL = [
    #
    # execution_engine tests:
    # - ctypes related data-races: https://github.com/python/cpython/issues/127945
    #
    "test_execution_engine__testBF16Memref_multi_threaded",
    "test_execution_engine__testBasicCallback_multi_threaded",
    "test_execution_engine__testComplexMemrefAdd_multi_threaded",
    "test_execution_engine__testComplexUnrankedMemrefAdd_multi_threaded",
    "test_execution_engine__testDynamicMemrefAdd2D_multi_threaded",
    "test_execution_engine__testF16MemrefAdd_multi_threaded",
    "test_execution_engine__testF8E5M2Memref_multi_threaded",
    "test_execution_engine__testInvokeFloatAdd_multi_threaded",
    # a ctypes race
    "test_execution_engine__testInvokeVoid_multi_threaded",
    "test_execution_engine__testMemrefAdd_multi_threaded",
    "test_execution_engine__testRankedMemRefCallback_multi_threaded",
    "test_execution_engine__testRankedMemRefWithOffsetCallback_multi_threaded",
    "test_execution_engine__testUnrankedMemRefCallback_multi_threaded",
    "test_execution_engine__testUnrankedMemRefWithOffsetCallback_multi_threaded",
    #
    # dialects tests
    #
    # Related to ctypes data races
    "test_dialects_memref__testSubViewOpInferReturnTypeExtensiveSlicing_multi_threaded",
    # Fatal Python error: Aborted or
    # mlir::transform::PrintOp::apply(mlir::transform::TransformRewriter...) is not thread-safe
    "test_dialects_transform_interpreter__print_other_multi_threaded",
    # Due to global llvm-project/llvm/lib/Target/AMDGPU/GCNSchedStrategy.cpp::GCNTrackers
    # variable mutation
    "test_dialects_gpu_module-to-binary-rocdl__testGPUToASMBin_multi_threaded",
    "test_dialects_gpu_module-to-binary-nvvm__testGPUToASMBin_multi_threaded",
    "test_dialects_gpu_module-to-binary-nvvm__testGPUToLLVMBin_multi_threaded",
    "test_dialects_gpu_module-to-binary-rocdl__testGPUToLLVMBin_multi_threaded",
    #
    # integration tests (all related to ctypes data races)
    #
    "test_integration_dialects_linalg_opsrun__test_elemwise_builtin_multi_threaded",
    "test_integration_dialects_linalg_opsrun__test_elemwise_generic_multi_threaded",
    "test_integration_dialects_linalg_opsrun__test_fill_builtin_multi_threaded",
    "test_integration_dialects_linalg_opsrun__test_fill_generic_multi_threaded",
    "test_integration_dialects_linalg_opsrun__test_fill_rng_builtin_multi_threaded",
    "test_integration_dialects_linalg_opsrun__test_fill_rng_generic_multi_threaded",
    "test_integration_dialects_linalg_opsrun__test_max_pooling_builtin_multi_threaded",
    "test_integration_dialects_linalg_opsrun__test_max_pooling_generic_multi_threaded",
    "test_integration_dialects_linalg_opsrun__test_min_pooling_builtin_multi_threaded",
    "test_integration_dialects_linalg_opsrun__test_min_pooling_generic_multi_threaded",
]
def add_existing_tests(test_modules, test_prefix: str = "_original_test"):
    """Build a class decorator that attaches existing test functions to a class.

    For every entry of ``test_modules`` the decorator copies
    ``<this folder>/<module>.py`` into a temporary directory with
    ``copy_and_update``, imports the copy with ``import_from_path`` and adds one
    method per selected callable to the decorated class. Each method is named
    ``{test_prefix}_{module with '/' replaced by '_'}__{function name}`` and,
    when called, runs ``exec_fn(test_function)``.

    Args:
        test_modules: Iterable of ``(module_name, exec_fn)`` or
            ``(module_name, exec_fn, test_pattern)`` tuples. Without a pattern
            (or with ``None``), attributes whose name starts with ``"test"``
            are selected; otherwise attributes whose name contains any of the
            pattern substrings.
        test_prefix: Prefix of the generated method names.

    Returns:
        The class decorator. It stores the ``tempfile.TemporaryDirectory``
        holding the copies on the class as ``output_folder`` and returns the
        same class object.
    """

    def decorator(test_cls):
        this_folder = Path(__file__).parent.absolute()
        # Keep the TemporaryDirectory object on the class so its owner can
        # clean it up explicitly.
        test_cls.output_folder = tempfile.TemporaryDirectory()
        output_folder = Path(test_cls.output_folder.name)

        for test_mod_info in test_modules:
            assert isinstance(test_mod_info, tuple) and len(test_mod_info) in (
                2,
                3,
            ), f"Expected a tuple of 2 or 3 items, got: {test_mod_info!r}"
            test_module_name, exec_fn, *optional = test_mod_info
            test_pattern = optional[0] if optional else None

            src_filepath = this_folder / f"{test_module_name}.py"
            dst_filepath = (output_folder / f"{test_module_name}.py").absolute()
            # exist_ok avoids the check-then-create race of a separate exists().
            dst_filepath.parent.mkdir(parents=True, exist_ok=True)
            copy_and_update(src_filepath, dst_filepath)
            test_mod = import_from_path(test_module_name, dst_filepath)

            flat_module_name = test_module_name.replace("/", "_")
            for attr_name in dir(test_mod):
                if test_pattern is None:
                    is_test_fn = attr_name.startswith("test")
                else:
                    is_test_fn = any(p in attr_name for p in test_pattern)
                if not is_test_fn:
                    continue
                obj = getattr(test_mod, attr_name)
                if not callable(obj):
                    continue

                test_name = f"{test_prefix}_{flat_module_name}__{attr_name}"

                # Default arguments bind the current ``obj``/``exec_fn``; a plain
                # closure would see the values of the last loop iteration.
                def wrapped_test_fn(
                    self, *args, __test_fn__=obj, __exec_fn__=exec_fn, **kwargs
                ):
                    __exec_fn__(__test_fn__)

                setattr(test_cls, test_name, wrapped_test_fn)
        return test_cls

    return decorator
@contextmanager
def _capture_output(fp):
# Inspired from jax test_utils.py capture_stderr method
# ``None`` means nothing has not been captured yet.
captured = None
def get_output() -> str:
if captured is None:
raise ValueError("get_output() called while the context is active.")
return captured
with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8") as f:
original_fd = os.dup(fp.fileno())
os.dup2(f.fileno(), fp.fileno())
try:
yield get_output
finally:
# Python also has its own buffers, make sure everything is flushed.
fp.flush()
os.fsync(fp.fileno())
f.seek(0)
captured = f.read()
os.dup2(original_fd, fp.fileno())
capture_stdout = partial(_capture_output, sys.stdout)
capture_stderr = partial(_capture_output, sys.stderr)
def multi_threaded(
    num_workers: int,
    num_runs: int = 5,
    skip_tests: Optional[List[str]] = None,
    xfail_tests: Optional[List[str]] = None,
    test_prefix: str = "_original_test",
    multithreaded_test_postfix: str = "_multi_threaded",
):
    """Decorator that runs a test in a multi-threaded environment.

    The returned class decorator looks at every callable attribute of the class
    whose name starts with ``test_prefix`` and adds a method named
    ``test<rest of the name><multithreaded_test_postfix>``. That method starts
    ``num_workers`` threads which wait on a common barrier and then each call
    the original function ``num_runs`` times. It fails if a worker raises, if
    ``Context._get_live_count()`` is not 0 afterwards, or if the captured
    stderr contains "ThreadSanitizer".

    Args:
        num_workers: Number of threads running the test concurrently.
        num_runs: Number of consecutive calls made by each thread.
        skip_tests: Generated test names to omit. An entry matches when its
            name without ``multithreaded_test_postfix`` is a *substring* of the
            generated name (before the postfix is appended).
        xfail_tests: Generated test names (exact match) to wrap with
            ``unittest.expectedFailure``.
        test_prefix: Prefix identifying the source methods on the class.
        multithreaded_test_postfix: Suffix appended to generated test names.

    Returns:
        A class decorator that adds the generated methods and returns the class.
    """

    def decorator(test_cls):
        # Strip the postfix from the skip entries once instead of once per test.
        skip_fragments = (
            []
            if skip_tests is None
            else [
                test_name.replace(multithreaded_test_postfix, "")
                for test_name in skip_tests
            ]
        )

        # Iterate over a copy: attributes are added to the class in the loop.
        for name, test_fn in test_cls.__dict__.copy().items():
            if not (name.startswith(test_prefix) and callable(test_fn)):
                continue

            name = f"test{name[len(test_prefix):]}"
            if any(fragment in name for fragment in skip_fragments):
                continue

            # No docstring on purpose: unittest would display it instead of the
            # generated test name. ``__test_fn__`` is bound as a default
            # argument to capture the current ``test_fn`` of this iteration.
            def multi_threaded_test_fn(self, *args, __test_fn__=test_fn, **kwargs):
                # stdout is captured only to keep it quiet; stderr is inspected.
                with capture_stdout(), capture_stderr() as get_output:
                    barrier = threading.Barrier(num_workers)

                    def closure():
                        # Start all workers at the same time.
                        barrier.wait()
                        for _ in range(num_runs):
                            __test_fn__(self, *args, **kwargs)

                    with concurrent.futures.ThreadPoolExecutor(
                        max_workers=num_workers
                    ) as executor:
                        futures = [
                            executor.submit(closure) for _ in range(num_workers)
                        ]
                        # Call future.result() to re-raise an exception if the
                        # test has failed. This is deliberately not inside an
                        # ``assert``: asserts are removed under ``python -O``
                        # and the failures would then be silently dropped.
                        for future in futures:
                            future.result()

                    gc.collect()
                    if Context._get_live_count() != 0:
                        raise AssertionError(
                            "Live MLIR contexts remain after the test run"
                        )

                captured = get_output()
                if "ThreadSanitizer" in captured:
                    raise RuntimeError(
                        f"ThreadSanitizer reported warnings:\n{captured}"
                    )

            test_new_name = f"{name}{multithreaded_test_postfix}"
            if xfail_tests is not None and test_new_name in xfail_tests:
                multi_threaded_test_fn = unittest.expectedFailure(
                    multi_threaded_test_fn
                )

            setattr(test_cls, test_new_name, multi_threaded_test_fn)

        return test_cls

    return decorator
@multi_threaded(
    num_workers=10,
    num_runs=20,
    skip_tests=TESTS_TO_SKIP,
    xfail_tests=TESTS_TO_XFAIL,
)
@add_existing_tests(test_modules=TEST_MODULES, test_prefix="_original_test")
class TestAllMultiThreaded(unittest.TestCase):
    """Container for the generated multi-threaded tests.

    ``add_existing_tests`` first attaches the tests of ``TEST_MODULES`` as
    ``_original_test*`` methods; ``multi_threaded`` then derives the
    ``test*_multi_threaded`` methods from those and from the two
    ``_original_test_*`` methods defined below.
    """

    @classmethod
    def tearDownClass(cls):
        # ``output_folder`` is the TemporaryDirectory set by add_existing_tests.
        if not hasattr(cls, "output_folder"):
            return
        cls.output_folder.cleanup()

    def _original_test_create_context(self):
        # Exercise the context introspection helpers from several threads.
        with Context() as ctx:
            print(ctx._get_live_count())
            print(ctx._get_live_module_count())
            print(ctx._get_live_operation_count())
            print(ctx._get_live_operation_objects())
            print(ctx._get_context_again() is ctx)
            print(ctx._clear_live_operations())

    def _original_test_create_module_with_consts(self):
        # Create one arith constant per value, each under its own named location.
        py_values = [123, 234, 345]
        with Context():
            module = Module.create(loc=Location.file("foo.txt", 0, 0))
            dtype = IntegerType.get_signless(64)
            for loc_name, py_value in zip(("a", "b", "c"), py_values):
                with InsertionPoint(module.body), Location.name(loc_name):
                    arith.constant(dtype, py_value)
if __name__ == "__main__":
    # Do not run the tests on CPython with GIL: the suite only starts when the
    # interpreter exposes sys._is_gil_enabled() and reports the GIL as disabled.
    gil_disabled = hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled()
    if gil_disabled:
        unittest.main()