[Bugfix] Fix the fp8_mqa_logits dim mismatch (#32652)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -686,7 +686,7 @@ def sparse_attn_indexer(
         fp8_mqa_logits_func = rocm_fp8_mqa_logits
     logits = fp8_mqa_logits_func(
         q_fp8[chunk.token_start : chunk.token_end],
-        (k_fp8, k_scale.view(torch.float32)),
+        (k_fp8, k_scale.view(torch.float32).flatten()),
         weights[chunk.token_start : chunk.token_end],
         chunk.cu_seqlen_ks,
         chunk.cu_seqlen_ke,
Reference in New Issue
Block a user