[Bugfix] Fix the fp8_mqa_logits dim mismatch (#32652)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Chauncey
2026-01-20 18:48:07 +08:00
committed by GitHub
parent 7f1bcd18ff
commit c4e5bdf61b
2 changed files with 3 additions and 3 deletions


@@ -686,7 +686,7 @@ def sparse_attn_indexer(
         fp8_mqa_logits_func = rocm_fp8_mqa_logits
     logits = fp8_mqa_logits_func(
         q_fp8[chunk.token_start : chunk.token_end],
-        (k_fp8, k_scale.view(torch.float32)),
+        (k_fp8, k_scale.view(torch.float32).flatten()),
         weights[chunk.token_start : chunk.token_end],
         chunk.cu_seqlen_ks,
         chunk.cu_seqlen_ke,
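
The one-line fix flattens `k_scale` after the dtype reinterpretation: viewing the scale buffer as `float32` preserves its original shape, so a trailing singleton dimension survives the view and trips the kernel's dimension check. A minimal NumPy sketch of the shape behavior (NumPy's `view`/`flatten` are analogous to the torch calls in the diff; the `[num_blocks, 1]` shape is an illustrative assumption, not taken from the patch):

```python
import numpy as np

# Hypothetical layout: one fp8 scale per KV-cache block, stored with a
# trailing singleton dim, e.g. shape [num_blocks, 1].
k_scale = np.full((4, 1), 0.5, dtype=np.float32)

# Reinterpreting the dtype alone keeps the 2-D shape...
reinterpreted = k_scale.view(np.float32)
assert reinterpreted.shape == (4, 1)

# ...but a kernel expecting a 1-D vector of per-block scales needs the
# extra dim collapsed, hence the added .flatten() in the fix.
flattened = k_scale.view(np.float32).flatten()
assert flattened.shape == (4,)
```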