[DeepSeek v3.2] Make top-k work for any logit values. (#27568)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -684,11 +684,10 @@ def sparse_attn_indexer(
|
||||
chunk.cu_seqlen_ke,
|
||||
)
|
||||
num_rows = logits.shape[0]
|
||||
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
|
||||
topk_indices = topk_indices_buffer[
|
||||
chunk.token_start : chunk.token_end, :topk_tokens
|
||||
]
|
||||
torch.ops._C.top_k_per_row(
|
||||
torch.ops._C.top_k_per_row_prefill(
|
||||
logits,
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
@@ -696,6 +695,7 @@ def sparse_attn_indexer(
|
||||
num_rows,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
topk_tokens,
|
||||
)
|
||||
|
||||
if has_decode:
|
||||
@@ -738,7 +738,6 @@ def sparse_attn_indexer(
|
||||
max_model_len=max_model_len,
|
||||
)
|
||||
num_rows = logits.shape[0]
|
||||
assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
|
||||
topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]
|
||||
|
||||
torch.ops._C.top_k_per_row_decode(
|
||||
@@ -749,6 +748,7 @@ def sparse_attn_indexer(
|
||||
num_rows,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
topk_tokens,
|
||||
)
|
||||
if decode_metadata.requires_padding:
|
||||
# if padded, we need to unpack
|
||||
|
||||
Reference in New Issue
Block a user