[Bugfix] [DeepSeek-V3.2] fix sparse_attn_indexer padding (#32175)
Signed-off-by: Kebe <mail@kebe7jun.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
@@ -717,13 +717,20 @@ def sparse_attn_indexer(
             # decode_threshold since we unstrictly split
             # prefill and decode by decode_threshold
             # (currently set to 1 + speculative tokens)
+
+            # [num_decode_tokens, n_head, head_dim] -> [bs, 1+next_n, n_head, head_dim]
             padded_q_fp8_decode_tokens = pack_seq_triton(
                 q_fp8[:num_decode_tokens], decode_lens
             )
+            # [num_decode_tokens, n_head] -> [bs, 1+next_n, n_head]
+            padded_weights = pack_seq_triton(weights[:num_decode_tokens], decode_lens)
+            # [bs, 1+next_n, n_head] -> [bs * next_n, n_head]
+            padded_weights = padded_weights.flatten(0, 1)
         else:
             padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                 decode_lens.shape[0], -1, *q_fp8.shape[1:]
             )
+            padded_weights = weights
         # TODO: move and optimize below logic with triton kernels
         batch_size = padded_q_fp8_decode_tokens.shape[0]
         next_n = padded_q_fp8_decode_tokens.shape[1]
@@ -739,14 +746,14 @@ def sparse_attn_indexer(
         logits = fp8_paged_mqa_logits_func(
             padded_q_fp8_decode_tokens,
             kv_cache,
-            weights[:num_padded_tokens],
+            padded_weights[:num_padded_tokens],
             decode_metadata.seq_lens,
             decode_metadata.block_table,
             decode_metadata.schedule_metadata,
             max_model_len=max_model_len,
         )
         num_rows = logits.shape[0]
-        topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
+        topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]

         torch.ops._C.top_k_per_row_decode(
             logits,
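The shape comments in the first hunk are the crux of the fix: once the decode queries are packed into a padded [bs, 1+next_n, ...] layout, the per-token weights must be packed and flattened the same way before being sliced to num_padded_tokens, otherwise weight rows no longer line up with the padded logits rows. Below is a minimal sketch of that alignment using a plain-PyTorch stand-in; pack_seq, the tiny shapes, and the zero-padding behaviour are assumptions for illustration, not the pack_seq_triton kernel from the diff.

# Sketch only: plain-PyTorch stand-in for pack_seq_triton, to show why the
# decode weights must be packed/flattened like the queries. `pack_seq`, the
# shapes, and the zero-padding behaviour are assumptions for illustration.
import torch


def pack_seq(x: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # Pack a flat [num_tokens, ...] tensor into [bs, max_len, ...], zero-padded.
    bs = seq_lens.numel()
    max_len = int(seq_lens.max())
    out = x.new_zeros(bs, max_len, *x.shape[1:])
    start = 0
    for i, n in enumerate(seq_lens.tolist()):
        out[i, :n] = x[start : start + n]
        start += n
    return out


# Two decode requests with different token counts (e.g. one has a spec token).
decode_lens = torch.tensor([1, 2])
num_decode_tokens = int(decode_lens.sum())  # 3
q = torch.arange(num_decode_tokens, dtype=torch.float32).reshape(-1, 1, 1)     # [tokens, n_head, head_dim]
weights = torch.arange(num_decode_tokens, dtype=torch.float32).reshape(-1, 1)  # [tokens, n_head]

# Pack q and weights with the same per-request layout, then flatten weights so
# row i of padded_weights matches row i of the flattened padded queries.
padded_q = pack_seq(q, decode_lens)                            # [bs, 1+next_n, n_head, head_dim]
padded_weights = pack_seq(weights, decode_lens).flatten(0, 1)  # [bs * (1+next_n), n_head]

num_padded_tokens = padded_q.shape[0] * padded_q.shape[1]      # 4, not 3: pad rows included
assert padded_weights.shape[0] == num_padded_tokens

# The unpadded slice weights[:num_padded_tokens] only has num_decode_tokens rows
# and assigns them to the wrong (padded) positions; the packed version does not.
print(padded_weights.squeeze(-1))  # tensor([0., 0., 1., 2.]) -- request 0 padded to length 2

With padding, num_padded_tokens (batch_size * next_n) exceeds num_decode_tokens, which is also why the second hunk slices topk_indices_buffer with num_padded_tokens rather than num_decode_tokens, matching num_rows = logits.shape[0].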