[Perf][Kernel] Add faster topKperRow decode kernel for DeepSeek-V3.2 sparse attention (#33680)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Roberto L. Castro
2026-02-10 16:29:52 +01:00
committed by GitHub
parent 82e11973cc
commit afdce12c89
8 changed files with 554 additions and 12 deletions

View File

@@ -190,6 +190,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int numRows, int stride0, int stride1, int topK) -> ()");
ops.impl("top_k_per_row_decode", torch::kCUDA, &top_k_per_row_decode);
ops.def(
"large_context_topk(Tensor score, Tensor indices, Tensor lengths, "
"Tensor? "
"row_starts_opt) -> ()");
ops.impl("large_context_topk", torch::kCUDA, &large_context_topk);
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(