[Perf][Kernel] Add faster topKperRow decode kernel for DeepSeek-V3.2 sparse attention (#33680)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com> Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
82e11973cc
commit
afdce12c89
@@ -190,6 +190,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"int numRows, int stride0, int stride1, int topK) -> ()");
|
||||
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
|
||||
|
||||
ops.def(
|
||||
"large_context_topk(Tensor score, Tensor indices, Tensor lengths, "
|
||||
"Tensor? "
|
||||
"row_starts_opt) -> ()");
|
||||
ops.impl("large_context_topk", torch::kCUDA, &large_context_topk);
|
||||
|
||||
// Layernorm-quant
|
||||
// Apply Root Mean Square (RMS) Normalization to the input tensor.
|
||||
ops.def(
|
||||
|
||||
Reference in New Issue
Block a user