[Performance] Cublas Bf16 Gate with Fp32 Output (#35121)

Signed-off-by: Roi Koren <roik@nvidia.com>
Author: roikoren755
Date: 2026-02-27 02:51:28 +02:00
Committed by: GitHub
parent 56a6371706
commit 38c498b8e3
9 changed files with 206 additions and 80 deletions


@@ -58,6 +58,10 @@ void shuffle_rows(const torch::Tensor& input_tensor,
                  torch::Tensor& output_tensor);
#ifndef USE_ROCM
// cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
                                    torch::Tensor const& weight);
// DeepSeek V3 optimized router GEMM kernel for SM90+
// Computes output = mat_a @ mat_b.T where:
// mat_a: [num_tokens, hidden_dim] in bf16
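The header comment names the two paths: an SM90-optimized kernel for small batches and the cuBLAS GEMM everywhere else. A minimal dispatch sketch of what that implies (hypothetical helper, not part of the commit; the batch cutoff of 16 comes from the fallback comment above, and the out-parameter form of dsv3_router_gemm is assumed from its "Tensor! output" schema registered below):

// Hypothetical dispatch sketch, not committed code: pick the SM90-optimized
// kernel for small batches, fall back to cuBLAS otherwise.
inline torch::Tensor route_tokens(torch::Tensor const& input,   // [num_tokens, hidden_dim] bf16
                                  torch::Tensor const& weight,  // [num_experts, hidden_dim] bf16
                                  bool is_sm90) {
  if (is_sm90 && input.size(0) <= 16) {
    auto out = torch::empty({input.size(0), weight.size(0)},
                            input.options().dtype(torch::kFloat32));
    dsv3_router_gemm(out, input, weight);  // SM90+ path (assumed signature)
    return out;
  }
  return router_gemm_bf16_fp32(input, weight);  // cuBLAS fallback
}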

csrc/moe/router_gemm.cu (new file, 52 additions)

@@ -0,0 +1,52 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// bf16 x bf16 -> fp32 router GEMM via cuBLAS.
// Uses CUBLAS_COMPUTE_32F so bf16 operands accumulate into fp32,
// matching TRT-LLM's cuBLAS fallback behaviour in dsv3RouterGemmOp.
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>
// cuBLAS column-major math for row-major PyTorch tensors:
//   weight[N,K]_row, lda=K -> cuBLAS sees (K,N) col-major; CUBLAS_OP_T -> (N,K)
//   input[M,K]_row,  ldb=K -> cuBLAS sees (K,M) col-major; CUBLAS_OP_N -> (K,M)
//   out[M,N]_row,    ldc=N -> cuBLAS sees (N,M) col-major (written as output^T)
// cuBLAS computes C(N,M) = weight(N,K) @ input(K,M) => C^T = output[M,N]
// params: m=N, n=M, k=K, lda=K (weight), ldb=K (input), ldc=N (output)
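// Worked example (hypothetical shapes, for illustration only): with M=4
// tokens, N=256 experts, K=7168, the call below becomes
//   cublasGemmEx(m=256, n=4, k=7168, lda=7168, ldb=7168, ldc=256)
// and cuBLAS writes a (256,4) col-major C, which PyTorch reads back through
// the row-major out tensor as out[4,256].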
torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
                                    torch::Tensor const& weight) {
  TORCH_CHECK(input.dtype() == torch::kBFloat16,
              "router_gemm_bf16_fp32: input must be bfloat16");
  TORCH_CHECK(weight.dtype() == torch::kBFloat16,
              "router_gemm_bf16_fp32: weight must be bfloat16");
  TORCH_CHECK(input.dim() == 2 && weight.dim() == 2,
              "router_gemm_bf16_fp32: input and weight must be 2-D");
  TORCH_CHECK(input.size(1) == weight.size(1),
              "router_gemm_bf16_fp32: inner dimensions must match");
  // The lda/ldb/ldc values below assume dense row-major storage.
  TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(),
              "router_gemm_bf16_fp32: input and weight must be contiguous");

  int64_t const M = input.size(0);   // num_tokens
  int64_t const N = weight.size(0);  // num_experts (router width)
  int64_t const K = input.size(1);   // hidden_dim

  auto out = torch::empty({M, N}, input.options().dtype(torch::kFloat32));

  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(
      cublasSetStream(handle, at::cuda::getCurrentCUDAStream()));

  float const alpha = 1.0f;
  float const beta = 0.0f;
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, static_cast<int>(N),
      static_cast<int>(M), static_cast<int>(K), &alpha, weight.data_ptr(),
      CUDA_R_16BF, static_cast<int>(K), input.data_ptr(), CUDA_R_16BF,
      static_cast<int>(K), &beta, out.data_ptr(), CUDA_R_32F,
      static_cast<int>(N), CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
  return out;
}
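A quick way to sanity-check the lda/ldb/ldc reasoning above is to compare against a plain fp32 matmul. A minimal sketch, assuming a libtorch build with this file linked in; the shapes are illustrative (DeepSeek-V3-like), not taken from the commit:

#include <iostream>
#include <torch/torch.h>

// Declared in the moe ops header added by this commit.
torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
                                    torch::Tensor const& weight);

int main() {
  auto opts = torch::dtype(torch::kBFloat16).device(torch::kCUDA);
  torch::Tensor input = torch::randn({32, 7168}, opts);    // [M, K]
  torch::Tensor weight = torch::randn({256, 7168}, opts);  // [N, K]
  torch::Tensor out = router_gemm_bf16_fp32(input, weight);  // [M, N] fp32
  torch::Tensor ref = torch::matmul(input.to(torch::kFloat32),
                                    weight.to(torch::kFloat32).t());
  // bf16 operands with fp32 accumulation: expect a small, nonzero gap vs. fp32.
  std::cout << (out - ref).abs().max().item<float>() << std::endl;
  return 0;
}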


@@ -125,6 +125,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
              "Tensor)");
  m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
  // cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
  m.def("router_gemm_bf16_fp32(Tensor input, Tensor weight) -> Tensor");
  m.impl("router_gemm_bf16_fp32", torch::kCUDA, &router_gemm_bf16_fp32);
  // DeepSeek V3 optimized router GEMM for SM90+
  m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
  // conditionally compiled so impl registration is in source file
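Once registered, the op is reachable through the PyTorch dispatcher. A minimal sketch, assuming the extension registers under vLLM's _moe_C namespace (the namespace itself is not shown in this hunk):

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/torch.h>

// Hypothetical caller; "_moe_C" is an assumed extension namespace.
torch::Tensor call_router_gemm(torch::Tensor const& input,
                               torch::Tensor const& weight) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("_moe_C::router_gemm_bf16_fp32", "")
          .typed<torch::Tensor(torch::Tensor const&, torch::Tensor const&)>();
  return op.call(input, weight);
}

From Python, the same op would then surface as torch.ops._moe_C.router_gemm_bf16_fp32(input, weight), again assuming that namespace.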