[Bugfix] Fix the fp8_mqa_logits dim mismatch (#32652)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
Chauncey
2026-01-20 18:48:07 +08:00
committed by GitHub
parent 7f1bcd18ff
commit c4e5bdf61b
2 changed files with 3 additions and 3 deletions

View File

@@ -686,7 +686,7 @@ def sparse_attn_indexer(
fp8_mqa_logits_func = rocm_fp8_mqa_logits
logits = fp8_mqa_logits_func(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)),
(k_fp8, k_scale.view(torch.float32).flatten()),
weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,

View File

@@ -249,8 +249,8 @@ def fp8_mqa_logits(
q: Query tensor of shape [M, H, D]. Casted to
`torch.float8_e4m3fn` by caller.
kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
[N, 1]) with dtype `torch.float32`.
dtype `torch.float8_e4m3fn` and `k_scales` has shape [N])
with dtype `torch.float32`.
weights: weights of shape [M, H], dtype `torch.float32`.
cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
shape [M], dtype int32.