diff --git a/tests/unit/test_fmha_v3_12w.py b/tests/unit/test_fmha_v3_12w.py index 69a7ab33..1698e2d8 100644 --- a/tests/unit/test_fmha_v3_12w.py +++ b/tests/unit/test_fmha_v3_12w.py @@ -295,7 +295,7 @@ class FmhaV3: def test(): torch.manual_seed(42) - for n in [128]: + for n in [128, 256]: m, hd = 128, HEAD_DIM q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')