diag: test n=128 and n=256 both with rescale disabled

This commit is contained in:
2026-05-23 01:12:00 +00:00
parent dc44fa187a
commit 0ef41266de

View File

@@ -450,7 +450,7 @@ class FmhaV3StageCMulti:
def test():
torch.manual_seed(42)
for n in [256]:
for n in [128, 256]:
torch.manual_seed(42)
m, hd = 128, HEAD_DIM
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')