[Performance] Cublas Bf16 Gate with Fp32 Output (#35121)
Signed-off-by: Roi Koren <roik@nvidia.com>
@@ -125,6 +125,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "Tensor)");
   m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
 
+  // cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
+  m.def("router_gemm_bf16_fp32(Tensor input, Tensor weight) -> Tensor");
+  m.impl("router_gemm_bf16_fp32", torch::kCUDA, &router_gemm_bf16_fp32);
+
   // DeepSeek V3 optimized router GEMM for SM90+
   m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
   // conditionally compiled so impl registration is in source file
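For context, the newly registered `router_gemm_bf16_fp32` op takes bf16 activations and bf16 router weights and returns fp32 logits from a single GEMM, so the gate never goes through a lossy bf16 intermediate. The actual kernel is defined elsewhere in the tree; the following is only a minimal sketch of how such an op can be backed by cuBLAS, assuming row-major contiguous input [num_tokens, hidden] and weight [num_experts, hidden] tensors (the shapes, checks, and function body here are illustrative, not taken from this PR):

// Hypothetical sketch, not the PR's source: one cublasGemmEx call with
// BF16 operands, FP32 accumulation, and FP32 output.
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <cublas_v2.h>
#include <torch/extension.h>

torch::Tensor router_gemm_bf16_fp32(torch::Tensor input, torch::Tensor weight) {
  // input:  [num_tokens, hidden]   bf16, row-major contiguous
  // weight: [num_experts, hidden]  bf16, row-major contiguous
  TORCH_CHECK(input.is_cuda() && weight.is_cuda(), "expected CUDA tensors");
  TORCH_CHECK(input.scalar_type() == torch::kBFloat16 &&
              weight.scalar_type() == torch::kBFloat16, "expected bf16 inputs");
  TORCH_CHECK(input.is_contiguous() && weight.is_contiguous());
  TORCH_CHECK(input.size(1) == weight.size(1), "hidden dims must match");

  const int m = static_cast<int>(input.size(0));   // num_tokens
  const int n = static_cast<int>(weight.size(0));  // num_experts
  const int k = static_cast<int>(input.size(1));   // hidden size

  // Logits come out of the GEMM already in fp32; no separate cast kernel.
  auto output = torch::empty({m, n}, input.options().dtype(torch::kFloat32));

  const float alpha = 1.0f, beta = 0.0f;
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();

  // Row-major output = input @ weight^T, written for column-major cuBLAS as
  // output^T = weight @ input^T (hence CUBLAS_OP_T on the weight operand).
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_T, CUBLAS_OP_N,
      /*m=*/n, /*n=*/m, /*k=*/k,
      &alpha,
      weight.data_ptr(), CUDA_R_16BF, /*lda=*/k,
      input.data_ptr(),  CUDA_R_16BF, /*ldb=*/k,
      &beta,
      output.data_ptr(), CUDA_R_32F, /*ldc=*/n,
      CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));

  return output;
}

Producing fp32 directly from the GEMM avoids an extra cast kernel and keeps full precision in the router logits before the softmax/top-k over experts.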