[DeepSeek v3.2] Make top-k work for any logit values. (#27568)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -179,15 +179,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Optimized top-k per row operation
|
||||
ops.def(
|
||||
"top_k_per_row(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
|
||||
"top_k_per_row_prefill(Tensor logits, Tensor rowStarts, Tensor rowEnds, "
|
||||
"Tensor! indices, int numRows, int stride0, "
|
||||
"int stride1) -> ()");
|
||||
ops.impl("top_k_per_row", torch::kCUDA, &top_k_per_row);
|
||||
"int stride1, int topK) -> ()");
|
||||
ops.impl("top_k_per_row_prefill", torch::kCUDA, &top_k_per_row_prefill);
|
||||
|
||||
ops.def(
|
||||
"top_k_per_row_decode(Tensor logits, int next_n, "
|
||||
"Tensor seq_lens, Tensor! indices, int numRows, "
|
||||
"int stride0, int stride1) -> ()");
|
||||
"Tensor seq_lens, Tensor! indices, "
|
||||
"int numRows, int stride0, int stride1, int topK) -> ()");
|
||||
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
|
||||
|
||||
// Layernorm-quant
|
||||
|
||||
Reference in New Issue
Block a user