[Deepseek v3.2] Optimize top_k_per_row (#26763)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
@@ -185,7 +185,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// Optimized top-k per row operation
|
||||
ops.def(
|
||||
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
|
||||
"Tensor! indices, Tensor! values, int numRows, int stride0, "
|
||||
"Tensor! indices, int numRows, int stride0, "
|
||||
"int stride1) -> ()");
|
||||
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user