Update FlashMLA (#32491)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-01-20 22:03:37 -07:00
committed by GitHub
parent 7ab80a8e37
commit b4f64e5b02
4 changed files with 169 additions and 42 deletions

View File

@@ -43,7 +43,7 @@ def test_sparse_flashmla_decode_smoke():
device = torch.device("cuda")
batch_size = 1
seqlen_q = 1
num_heads_q = 1
num_heads_q = 64
head_dim_k = 576
head_dim_v = 512
num_heads_k = 1