diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 9f638dda..681b9cdd 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -423,6 +423,8 @@ class FmhaV3StageCMulti: def test(): + import os + os.environ['CUDA_LAUNCH_BLOCKING'] = '1' torch.manual_seed(42) for n in [128]: torch.manual_seed(42)