[Deepseek v3.2] Optimize top_k_per_row (#26763)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
@@ -577,15 +577,11 @@ def sparse_attn_indexer(
|
||||
topk_indices = torch.empty(
|
||||
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
|
||||
)
|
||||
topk_values = torch.empty(
|
||||
num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
|
||||
)
|
||||
torch.ops._C.top_k_per_row(
|
||||
logits,
|
||||
chunk.cu_seqlen_ks,
|
||||
chunk.cu_seqlen_ke,
|
||||
topk_indices,
|
||||
topk_values,
|
||||
num_rows,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
@@ -642,15 +638,11 @@ def sparse_attn_indexer(
|
||||
topk_indices = torch.empty(
|
||||
num_rows, topk_tokens, dtype=torch.int32, device=logits.device
|
||||
)
|
||||
topk_values = torch.empty(
|
||||
num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
|
||||
)
|
||||
torch.ops._C.top_k_per_row(
|
||||
logits,
|
||||
torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
|
||||
index_end_pos.to(dtype=torch.int32, device=logits.device),
|
||||
topk_indices,
|
||||
topk_values,
|
||||
num_rows,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
|
||||
Reference in New Issue
Block a user