[Deepseek v3.2] Optimize top_k_per_row (#26763)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
Daniel Cámpora
2025-10-21 10:30:07 +02:00
committed by GitHub
parent c3a2c6ac5f
commit 80e9452984
5 changed files with 13 additions and 49 deletions

View File

@@ -577,15 +577,11 @@ def sparse_attn_indexer(
 topk_indices = torch.empty(
 num_rows, topk_tokens, dtype=torch.int32, device=logits.device
 )
-topk_values = torch.empty(
-num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
-)
 torch.ops._C.top_k_per_row(
 logits,
 chunk.cu_seqlen_ks,
 chunk.cu_seqlen_ke,
 topk_indices,
-topk_values,
 num_rows,
 logits.stride(0),
 logits.stride(1),
@@ -642,15 +638,11 @@ def sparse_attn_indexer(
 topk_indices = torch.empty(
 num_rows, topk_tokens, dtype=torch.int32, device=logits.device
 )
-topk_values = torch.empty(
-num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
-)
 torch.ops._C.top_k_per_row(
 logits,
 torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
 index_end_pos.to(dtype=torch.int32, device=logits.device),
 topk_indices,
-topk_values,
 num_rows,
 logits.stride(0),
 logits.stride(1),