Diag: test n=384 (3 tiles) to find crash boundary
This commit is contained in:
@@ -266,7 +266,7 @@ class FmhaV3Diag:
|
||||
|
||||
|
||||
def test():
|
||||
for n in [128, 256, 512, 1024]:
|
||||
for n in [128, 256, 384]:
|
||||
torch.manual_seed(42)
|
||||
m, hd = 128, HEAD_DIM
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
Reference in New Issue
Block a user