[ModelBash][DSV3] Add TRTLLM DSV3 Router GEMM kernel (6% B1 Speedup) (#34302)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw
2026-02-23 09:02:26 -05:00
committed by GitHub
parent b1b5e045df
commit 8435b2e049
9 changed files with 915 additions and 3 deletions

View File

@@ -124,6 +124,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"routed_scaling_factor, Tensor bias, int scoring_func) -> (Tensor, "
"Tensor)");
m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
// DeepSeek V3 optimized router GEMM for SM90+
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_router_gemm", torch::kCUDA, &dsv3_router_gemm);
#endif
}