[Deepseek v3.2] Optimize top_k_per_row (#26763)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
Daniel Cámpora
2025-10-21 10:30:07 +02:00
committed by GitHub
parent c3a2c6ac5f
commit 80e9452984
5 changed files with 13 additions and 49 deletions

View File

@@ -185,7 +185,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Optimized top-k per row operation
ops.def(
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
"Tensor! indices, Tensor! values, int numRows, int stride0, "
"Tensor! indices, int numRows, int stride0, "
"int stride1) -> ()");
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);