Diag: test n=384 (3 tiles) to find crash boundary

This commit is contained in:
2026-05-22 18:07:07 +00:00
parent 1aa4a91d01
commit 1e0805ad60

View File

@@ -266,7 +266,7 @@ class FmhaV3Diag:
def test():
-for n in [128, 256, 512, 1024]:
+for n in [128, 256, 384]:
torch.manual_seed(42)
m, hd = 128, HEAD_DIM
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')