[DeepSeek v3.2] Make top-k work for any logit values. (#27568)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Daniel Cámpora
2025-12-08 15:55:58 +01:00
committed by GitHub
parent eb1051fb95
commit 184076c3fe
5 changed files with 643 additions and 224 deletions

View File

@@ -684,11 +684,10 @@ def sparse_attn_indexer(
chunk.cu_seqlen_ke,
)
num_rows = logits.shape[0]
-assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens
]
-torch.ops._C.top_k_per_row(
+torch.ops._C.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
@@ -696,6 +695,7 @@ def sparse_attn_indexer(
num_rows,
logits.stride(0),
logits.stride(1),
+topk_tokens,
)
if has_decode:
@@ -738,7 +738,6 @@ def sparse_attn_indexer(
max_model_len=max_model_len,
)
num_rows = logits.shape[0]
-assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
torch.ops._C.top_k_per_row_decode(
@@ -749,6 +748,7 @@ def sparse_attn_indexer(
num_rows,
logits.stride(0),
logits.stride(1),
+topk_tokens,
)
if decode_metadata.requires_padding:
# if padded, we need to unpack