Update FlashMLA (#32491)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -43,7 +43,7 @@ def test_sparse_flashmla_decode_smoke():
|
||||
device = torch.device("cuda")
|
||||
batch_size = 1
|
||||
seqlen_q = 1
|
||||
num_heads_q = 1
|
||||
num_heads_q = 64
|
||||
head_dim_k = 576
|
||||
head_dim_v = 512
|
||||
num_heads_k = 1
|
||||
|
||||
Reference in New Issue
Block a user