diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 51360a9e..30552aa8 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -450,7 +450,7 @@ class FmhaV3StageCMulti: def test(): torch.manual_seed(42) - for n in [256]: + for n in [128, 256]: torch.manual_seed(42) m, hd = 128, HEAD_DIM q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')