[Bugfix] [DeepSeek-V3.2] fix sparse_attn_indexer padding (#32175)

Signed-off-by: Kebe <mail@kebe7jun.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Kebe
2026-01-16 12:21:55 +09:00
committed by GitHub
parent 709502558c
commit 5de6dd0662

View File

@@ -717,13 +717,20 @@ def sparse_attn_indexer(
# decode_threshold since we unstrictly split
# prefill and decode by decode_threshold
# (currently set to 1 + speculative tokens)
# [num_decode_tokens, n_head, head_dim] -> [bs, 1+next_n, n_head, head_dim]
padded_q_fp8_decode_tokens = pack_seq_triton(
q_fp8[:num_decode_tokens], decode_lens
)
# [num_decode_tokens, n_head] -> [bs, 1+next_n, n_head]
padded_weights = pack_seq_triton(weights[:num_decode_tokens], decode_lens)
# [bs, 1+next_n, n_head] -> [bs * next_n, n_head]
padded_weights = padded_weights.flatten(0, 1)
else:
padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
decode_lens.shape[0], -1, *q_fp8.shape[1:]
)
padded_weights = weights
# TODO: move and optimize below logic with triton kernels
batch_size = padded_q_fp8_decode_tokens.shape[0]
next_n = padded_q_fp8_decode_tokens.shape[1]
@@ -739,14 +746,14 @@ def sparse_attn_indexer(
logits = fp8_paged_mqa_logits_func(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
padded_weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
decode_metadata.schedule_metadata,
max_model_len=max_model_len,
)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
torch.ops._C.top_k_per_row_decode(
logits,