[Performance] Cublas Bf16 Gate with Fp32 Output (#35121)

Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
roikoren755
2026-02-27 02:51:28 +02:00
committed by GitHub
parent 56a6371706
commit 38c498b8e3
9 changed files with 206 additions and 80 deletions

View File

@@ -125,6 +125,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"Tensor)");
m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
// cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
m.def("router_gemm_bf16_fp32(Tensor input, Tensor weight) -> Tensor");
m.impl("router_gemm_bf16_fp32", torch::kCUDA, &router_gemm_bf16_fp32);
// DeepSeek V3 optimized router GEMM for SM90+
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
// dsv3_router_gemm is conditionally compiled, so its impl registration is in the source file