[Performance] Cublas Bf16 Gate with Fp32 Output (#35121)

Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
roikoren755
2026-02-27 02:51:28 +02:00
committed by GitHub
parent 56a6371706
commit 38c498b8e3
9 changed files with 206 additions and 80 deletions

View File

@@ -971,7 +971,8 @@ set(VLLM_MOE_EXT_SRC
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu")
"csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/router_gemm.cu")
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")

View File

@@ -58,6 +58,10 @@ void shuffle_rows(const torch::Tensor& input_tensor,
torch::Tensor& output_tensor);
#ifndef USE_ROCM
// cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
torch::Tensor const& weight);
// DeepSeek V3 optimized router GEMM kernel for SM90+
// Computes output = mat_a @ mat_b.T where:
// mat_a: [num_tokens, hidden_dim] in bf16

52
csrc/moe/router_gemm.cu Normal file
View File

@@ -0,0 +1,52 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
// bf16 x bf16 -> fp32 router GEMM via cuBLAS.
// Uses CUBLAS_COMPUTE_32F so bf16 operands accumulate into fp32,
// matching TRT-LLM's cuBLAS fallback behaviour in dsv3RouterGemmOp.
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>
// cuBLAS column-major math for row-major PyTorch tensors:
// weight[N,K]_row lda=K -> cuBLAS sees (K,N) col-major; CUBLAS_OP_T ->
// (N,K) input[M,K]_row ldb=K -> cuBLAS sees (K,M) col-major; CUBLAS_OP_N
// -> (K,M) out[M,N]_row ldc=N -> cuBLAS sees (N,M) col-major (written as
// output^T)
// cuBLAS: C(N,M) = weight(N,K) @ input(K,M) => C^T = output[M,N]
// params: m=N, n=M, k=K, lda=K (weight), ldb=K (input), ldc=N (output)
// Compute out[M, N] = input[M, K] @ weight[N, K]^T with fp32 output.
//
// Both operands are bf16; CUBLAS_COMPUTE_32F makes cuBLAS accumulate in
// fp32 and the result is written directly as fp32 (CUDA_R_32F), so no
// separate cast kernel is needed.
//
// Args:
//   input:  [M, K] bf16, row-major contiguous (token activations).
//   weight: [N, K] bf16, row-major contiguous (router weight).
// Returns: new [M, N] fp32 tensor of router logits on input's device.
torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
                                    torch::Tensor const& weight) {
  TORCH_CHECK(input.dtype() == torch::kBFloat16,
              "router_gemm_bf16_fp32: input must be bfloat16");
  TORCH_CHECK(weight.dtype() == torch::kBFloat16,
              "router_gemm_bf16_fp32: weight must be bfloat16");
  TORCH_CHECK(input.dim() == 2 && weight.dim() == 2,
              "router_gemm_bf16_fp32: input and weight must be 2-D");
  TORCH_CHECK(input.size(1) == weight.size(1),
              "router_gemm_bf16_fp32: inner dimensions must match");
  // The leading dimensions passed to cublasGemmEx below (lda=ldb=K, ldc=N)
  // assume dense row-major storage; reject strided/transposed views instead
  // of silently computing garbage.
  TORCH_CHECK(input.is_contiguous(),
              "router_gemm_bf16_fp32: input must be contiguous");
  TORCH_CHECK(weight.is_contiguous(),
              "router_gemm_bf16_fp32: weight must be contiguous");

  int64_t const M = input.size(0);
  int64_t const N = weight.size(0);
  int64_t const K = input.size(1);
  auto out = torch::empty({M, N}, input.options().dtype(torch::kFloat32));

  // Run on PyTorch's current stream so the GEMM is ordered with the rest
  // of the model's kernels.
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  TORCH_CUDABLAS_CHECK(
      cublasSetStream(handle, at::cuda::getCurrentCUDAStream()));

  float const alpha = 1.0f;
  float const beta = 0.0f;
  // cuBLAS is column-major; row-major tensors are reinterpreted as their
  // transposes:
  //   weight [N,K] row-major == (K,N) col-major; CUBLAS_OP_T -> (N,K)
  //   input  [M,K] row-major == (K,M) col-major; CUBLAS_OP_N -> (K,M)
  //   C(N,M) = weight(N,K) @ input(K,M), written col-major with ldc=N,
  //   which is exactly out[M,N] row-major.
  TORCH_CUDABLAS_CHECK(cublasGemmEx(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, static_cast<int>(N),
      static_cast<int>(M), static_cast<int>(K), &alpha, weight.data_ptr(),
      CUDA_R_16BF, static_cast<int>(K), input.data_ptr(), CUDA_R_16BF,
      static_cast<int>(K), &beta, out.data_ptr(), CUDA_R_32F,
      static_cast<int>(N), CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
  return out;
}

View File

@@ -125,6 +125,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"Tensor)");
m.impl("grouped_topk", torch::kCUDA, &grouped_topk);
// cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
m.def("router_gemm_bf16_fp32(Tensor input, Tensor weight) -> Tensor");
m.impl("router_gemm_bf16_fp32", torch::kCUDA, &router_gemm_bf16_fp32);
// DeepSeek V3 optimized router GEMM for SM90+
m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
// conditionally compiled so impl registration is in source file

View File

@@ -2190,6 +2190,23 @@ def moe_wna16_gemm(
)
def router_gemm_bf16_fp32(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """bf16 x bf16 -> fp32 GEMM via cuBLAS. weight shape: (N, K).

    Thin Python wrapper over the `_moe_C` extension op; returns a new
    (M, N) float32 tensor of router logits for (M, K) `input`.
    """
    return torch.ops._moe_C.router_gemm_bf16_fp32(input, weight)


# Register a fake (meta) implementation so torch.compile / fake-tensor
# tracing can infer the output shape and dtype without running cuBLAS.
# Guarded because the `_moe_C` extension is only built for CUDA.
if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "router_gemm_bf16_fp32"):

    @register_fake("_moe_C::router_gemm_bf16_fp32")
    def router_gemm_bf16_fp32_fake(
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        # Mirrors the real op: out is (M, N) fp32 on the input's device.
        return torch.empty(
            input.shape[0], weight.shape[0], dtype=torch.float32, device=input.device
        )
def dsv3_router_gemm(
hidden_states: torch.Tensor,
router_weight: torch.Tensor,

View File

@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
@@ -64,6 +65,7 @@ __all__ = [
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"GateLinear",
"RoutingMethodType",
"SharedFusedMoE",
"ZeroExpertFusedMoE",

View File

@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.platforms import current_platform
@PluggableLayer.register("gate_linear")
class GateLinear(ReplicatedLinear):
    """MoE gate linear layer with three-tier GEMM dispatch:

    1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
    2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
    3. F.linear via ReplicatedLinear (ultimate fallback)

    The ``out_dtype`` attribute is mutable and can be set after init
    (e.g. when the required dtype depends on the expert quantization
    method which is only known later).
    """

    # Dimensions supported by the DSV3 specialized kernel
    DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
    DSV3_SUPPORTED_HIDDEN_SIZES = [7168]

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
        out_dtype: torch.dtype | None = None,
        params_dtype: torch.dtype | None = None,
        force_fp32_compute: bool = False,
        prefix: str = "",
    ):
        # Specialized kernels (Tiers 1 and 2) require CUDA on Hopper (SM90)
        # or the Blackwell family (SM100), and neither supports a bias term.
        is_hopper_or_blackwell = current_platform.is_device_capability(
            (9, 0)
        ) or current_platform.is_device_capability_family(100)
        can_use_specialized_kernels = (
            current_platform.is_cuda() and is_hopper_or_blackwell and not bias
        )
        # If fp32 compute is required and no specialized kernel is available,
        # store weights in fp32 so Tier 3 computes in fp32 natively.
        if force_fp32_compute and not can_use_specialized_kernels:
            params_dtype = torch.float32
        super().__init__(
            input_size,
            output_size,
            bias=bias,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=prefix,
        )
        # May be None here and filled in later via set_out_dtype().
        self.out_dtype = out_dtype
        # DSV3 specialized kernel eligibility (SM90+, exact dims)
        self.allow_specialized_router_gemm = can_use_specialized_kernels
        self.allow_dsv3_router_gemm = (
            self.allow_specialized_router_gemm
            and output_size in self.DSV3_SUPPORTED_NUM_EXPERTS
            and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
        )
        # cuBLAS bf16→fp32 eligibility; re-evaluated in set_out_dtype()
        # if out_dtype only becomes known after construction.
        self.allow_cublas_router_gemm = (
            self.allow_specialized_router_gemm
            and self.weight.dtype == torch.bfloat16
            and self.out_dtype == torch.float32
        )

    def set_out_dtype(self, out_dtype: torch.dtype) -> None:
        """Set output dtype for the router logits after init.

        Useful when the required dtype depends on the expert quantization
        method which is only known after the gate is constructed.
        Raises ValueError if out_dtype was already set (either here or in
        the constructor).
        """
        if self.out_dtype is not None:
            raise ValueError("out_dtype has already been set")
        self.out_dtype = out_dtype
        # out_dtype was unknown at __init__ time, so cuBLAS eligibility may
        # have been computed as False; upgrade it now that fp32 is requested.
        if (
            not self.allow_cublas_router_gemm
            and self.allow_specialized_router_gemm
            and out_dtype == torch.float32
        ):
            self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16

    def forward(
        self, x: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
        # Imported lazily; presumably avoids a circular import at module
        # load time — TODO confirm.
        import vllm._custom_ops as ops

        # Tier 1: DSV3 specialized kernel — only profitable/supported for
        # small batches (<= 16 tokens).
        if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
            output = ops.dsv3_router_gemm(
                hidden_states=x,
                router_weight=self.weight,
                output_dtype=self.out_dtype,
            )
            return output, None
        # Tier 2: cuBLAS bf16→fp32 (any batch size, bf16 activations only)
        if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
            output = ops.router_gemm_bf16_fp32(x, self.weight)
            return output, None
        # Tier 3: F.linear (ReplicatedLinear)
        # Cast activations to the weight dtype (e.g. fp32 weights from
        # force_fp32_compute), then cast the result to out_dtype if needed.
        if self.out_dtype is not None and x.dtype != self.weight.dtype:
            x = x.to(self.weight.dtype)
        output, output_bias = super().forward(x)
        if self.out_dtype is not None and output.dtype != self.out_dtype:
            output = output.to(self.out_dtype)
        return output, output_bias

View File

@@ -47,7 +47,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -221,73 +221,6 @@ class DeepseekV2MLP(nn.Module):
return x
class DeepSeekV2Gate(ReplicatedLinear):
    """Router gate for DeepSeek V2/V3: a replicated linear layer that can
    dispatch small batches to the specialized DSV3 router GEMM kernel."""

    def __init__(
        self,
        hidden_size: int,
        n_experts: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        # The gate is never quantized; the parameter exists only so the
        # call-site signature matches other linear layers.
        assert quant_config is None
        super().__init__(
            hidden_size,
            n_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )
        # Unquantized only, will be called "weight".
        assert hasattr(self, "weight")
        # Specialized kernel requires Hopper (SM90) or Blackwell-family
        # (SM100) CUDA devices and exact model dimensions.
        is_hopper_or_blackwell = current_platform.is_device_capability(
            (9, 0)
        ) or current_platform.is_device_capability_family(100)
        SUPPORTED_NUM_EXPERTS = [256, 384]
        SUPPORTED_HIDDEN_SIZES = [7168]
        self.allow_dsv3_router_gemm = (
            current_platform.is_cuda()
            and is_hopper_or_blackwell
            and n_experts in SUPPORTED_NUM_EXPERTS
            and hidden_size in SUPPORTED_HIDDEN_SIZES
        )
        # Filled in later via set_out_dtype(); reading out_dtype before
        # then raises.
        self._out_dtype: torch.dtype | None = None

    def set_out_dtype(self, out_dtype: torch.dtype) -> None:
        """
        Set out dtype for the router logits. This is needed after
        __init__, b/c we need to check if the trtllm kernel is
        selected before we decide between bf16 and fp32.
        """
        if self._out_dtype is not None:
            raise ValueError("out_dtype has already been set")
        else:
            self._out_dtype = out_dtype

    @property
    def out_dtype(self) -> torch.dtype:
        # Guard against use before set_out_dtype() has been called.
        if self._out_dtype is None:
            raise ValueError("out_dtype has not been set yet")
        return self._out_dtype

    def forward(
        self,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """
        Use specialized GEMM for low batch size for DSV3 and KIMI.
        """
        if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
            return ops.dsv3_router_gemm(
                hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype
            ), None
        else:
            # Falls back to ReplicatedLinear.forward; presumably returns
            # (logits, bias) with bias None since bias=False — verify.
            return super().forward(x)
class DeepseekV2MoE(nn.Module):
def __init__(
self,
@@ -316,10 +249,9 @@ class DeepseekV2MoE(nn.Module):
"Only silu is supported for now."
)
self.gate = DeepSeekV2Gate(
self.gate = GateLinear(
config.hidden_size,
config.n_routed_experts,
quant_config=None,
prefix=f"{prefix}.gate",
)
if getattr(config, "topk_method", None) == "noaux_tc":

View File

@@ -34,7 +34,7 @@ from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
GateLinear,
SharedFusedMoE,
activation_without_mul,
)
@@ -148,13 +148,11 @@ class NemotronHMoE(nn.Module):
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
router_logits_dtype = torch.float32
self.gate = ReplicatedLinear(
self.gate = GateLinear(
config.hidden_size,
config.n_routed_experts,
bias=False,
params_dtype=router_logits_dtype,
quant_config=None,
out_dtype=torch.float32,
force_fp32_compute=True,
prefix=f"{prefix}.gate",
)
@@ -232,7 +230,6 @@ class NemotronHMoE(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
router_logits_dtype=router_logits_dtype,
routed_input_transform=self.fc1_latent_proj,
)
@@ -244,7 +241,7 @@ class NemotronHMoE(nn.Module):
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
router_logits, _ = self.gate(hidden_states)
# SharedFusedMoE handles:
# - shared experts (with original hidden_states)
@@ -675,7 +672,7 @@ class NemotronHModel(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
if self.has_moe:
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
# - FusedMoe.w1 (aka gate_proj) should be up_proj since that's
# what the activation is applied to
# - FusedMoe.w3 (aka up_proj) should be ignored since we're