diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index db0ccd695..13bb3cbd0 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -717,13 +717,20 @@ def sparse_attn_indexer( # decode_threshold since we unstrictly split # prefill and decode by decode_threshold # (currently set to 1 + speculative tokens) + + # [num_decode_tokens, n_head, head_dim] -> [bs, 1+next_n, n_head, head_dim] padded_q_fp8_decode_tokens = pack_seq_triton( q_fp8[:num_decode_tokens], decode_lens ) + # [num_decode_tokens, n_head] -> [bs, 1+next_n, n_head] + padded_weights = pack_seq_triton(weights[:num_decode_tokens], decode_lens) + # [bs, 1+next_n, n_head] -> [bs * next_n, n_head] + padded_weights = padded_weights.flatten(0, 1) else: padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( decode_lens.shape[0], -1, *q_fp8.shape[1:] ) + padded_weights = weights # TODO: move and optimize below logic with triton kernels batch_size = padded_q_fp8_decode_tokens.shape[0] next_n = padded_q_fp8_decode_tokens.shape[1] @@ -739,14 +746,14 @@ def sparse_attn_indexer( logits = fp8_paged_mqa_logits_func( padded_q_fp8_decode_tokens, kv_cache, - weights[:num_padded_tokens], + padded_weights[:num_padded_tokens], decode_metadata.seq_lens, decode_metadata.block_table, decode_metadata.schedule_metadata, max_model_len=max_model_len, ) num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] + topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] torch.ops._C.top_k_per_row_decode( logits,