diff --git a/CMakeLists.txt b/CMakeLists.txt
index 39714b846..479d6db1e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -971,7 +971,8 @@ set(VLLM_MOE_EXT_SRC
 if(VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_MOE_EXT_SRC
     "csrc/moe/moe_wna16.cu"
-    "csrc/moe/grouped_topk_kernels.cu")
+    "csrc/moe/grouped_topk_kernels.cu"
+    "csrc/moe/router_gemm.cu")
 endif()

 if(VLLM_GPU_LANG STREQUAL "CUDA")
diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h
index b71db3569..d8d962887 100644
--- a/csrc/moe/moe_ops.h
+++ b/csrc/moe/moe_ops.h
@@ -58,6 +58,10 @@ void shuffle_rows(const torch::Tensor& input_tensor,
                   torch::Tensor& output_tensor);

 #ifndef USE_ROCM
+// cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
+torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
+                                    torch::Tensor const& weight);
+
 // DeepSeek V3 optimized router GEMM kernel for SM90+
 // Computes output = mat_a @ mat_b.T where:
 //   mat_a: [num_tokens, hidden_dim] in bf16
diff --git a/csrc/moe/router_gemm.cu b/csrc/moe/router_gemm.cu
new file mode 100644
index 000000000..a939f8846
--- /dev/null
+++ b/csrc/moe/router_gemm.cu
@@ -0,0 +1,52 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+// bf16 x bf16 -> fp32 router GEMM via cuBLAS.
+// Uses CUBLAS_COMPUTE_32F so bf16 operands accumulate into fp32,
+// matching TRT-LLM's cuBLAS fallback behaviour in dsv3RouterGemmOp.
+
+#include <ATen/cuda/CUDAContext.h>
+#include <cublas_v2.h>
+#include <torch/all.h>
+
+// cuBLAS column-major math for row-major PyTorch tensors:
+//   weight[N,K]_row lda=K -> cuBLAS sees (K,N) col-major; CUBLAS_OP_T -> (N,K)
+//   input[M,K]_row  ldb=K -> cuBLAS sees (K,M) col-major; CUBLAS_OP_N -> (K,M)
+//   out[M,N]_row    ldc=N -> cuBLAS sees (N,M) col-major (written as output^T)
+// cuBLAS: C(N,M) = weight(N,K) @ input(K,M) => C^T = output[M,N]
+// params: m=N, n=M, k=K, lda=K (weight), ldb=K (input), ldc=N (output)
+
+torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input,
+                                    torch::Tensor const& weight) {
+  TORCH_CHECK(input.dtype() == torch::kBFloat16,
+              "router_gemm_bf16_fp32: input must be bfloat16");
+  TORCH_CHECK(weight.dtype() == torch::kBFloat16,
+              "router_gemm_bf16_fp32: weight must be bfloat16");
+  TORCH_CHECK(input.dim() == 2 && weight.dim() == 2,
+              "router_gemm_bf16_fp32: input and weight must be 2-D");
+  TORCH_CHECK(input.size(1) == weight.size(1),
+              "router_gemm_bf16_fp32: inner dimensions must match");
+
+  int64_t const M = input.size(0);
+  int64_t const N = weight.size(0);
+  int64_t const K = input.size(1);
+
+  auto out = torch::empty({M, N}, input.options().dtype(torch::kFloat32));
+
+  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+  TORCH_CUDABLAS_CHECK(
+      cublasSetStream(handle, at::cuda::getCurrentCUDAStream()));
+
+  float const alpha = 1.0f;
+  float const beta = 0.0f;
+
+  TORCH_CUDABLAS_CHECK(cublasGemmEx(
+      handle, CUBLAS_OP_T, CUBLAS_OP_N, static_cast<int>(N),
+      static_cast<int>(M), static_cast<int>(K), &alpha, weight.data_ptr(),
+      CUDA_R_16BF, static_cast<int>(K), input.data_ptr(), CUDA_R_16BF,
+      static_cast<int>(K), &beta, out.data_ptr(), CUDA_R_32F,
+      static_cast<int>(N), CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT));
+
+  return out;
+}
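
Not part of the diff: the column-major bookkeeping in router_gemm.cu is easy to misread, so here is a small illustrative Python check of what the op is expected to compute. The shapes, tolerances, and the plain fp32 matmul used as reference are assumptions for the sketch, not something the diff asserts.

    # Illustrative only: reference semantics of router_gemm_bf16_fp32
    # (assumes a CUDA build of vLLM so the _moe_C extension is loaded).
    import torch
    import vllm._custom_ops as ops

    M, N, K = 8, 256, 7168  # example: tokens, experts, hidden size
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")  # input  [M, K]
    w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")  # weight [N, K]

    # Row-major reference: output = input @ weight.T with fp32 accumulation.
    ref = x.float() @ w.float().t()                             # [M, N] fp32

    # The "row-major seen as column-major" trick from the comment block:
    # weight(N,K) @ input(K,M), read back transposed, is the same [M, N] output.
    ref_t = (w.float() @ x.float().t()).t()
    torch.testing.assert_close(ref_t, ref)

    out = ops.router_gemm_bf16_fp32(x, w)                       # [M, N] fp32
    assert out.dtype == torch.float32 and out.shape == (M, N)
    torch.testing.assert_close(out, ref, rtol=2e-3, atol=2e-3)  # illustrative tolerances
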
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index 438599451..7b627a6f8 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -125,6 +125,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "Tensor)");
   m.impl("grouped_topk", torch::kCUDA, &grouped_topk);

+  // cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16)
+  m.def("router_gemm_bf16_fp32(Tensor input, Tensor weight) -> Tensor");
+  m.impl("router_gemm_bf16_fp32", torch::kCUDA, &router_gemm_bf16_fp32);
+
   // DeepSeek V3 optimized router GEMM for SM90+
   m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
   // conditionally compiled so impl registration is in source file
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 37cf43620..69f080ae2 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2190,6 +2190,23 @@ def moe_wna16_gemm(
     )


+def router_gemm_bf16_fp32(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+    """bf16 x bf16 -> fp32 GEMM via cuBLAS. weight shape: (N, K)."""
+    return torch.ops._moe_C.router_gemm_bf16_fp32(input, weight)
+
+
+if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "router_gemm_bf16_fp32"):
+
+    @register_fake("_moe_C::router_gemm_bf16_fp32")
+    def router_gemm_bf16_fp32_fake(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+    ) -> torch.Tensor:
+        return torch.empty(
+            input.shape[0], weight.shape[0], dtype=torch.float32, device=input.device
+        )
+
+
 def dsv3_router_gemm(
     hidden_states: torch.Tensor,
     router_weight: torch.Tensor,
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index c6cb31b62..be901bd24 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
 from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
     FusedMoERouter,
 )
+from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
 from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
     UnquantizedFusedMoEMethod,
@@ -64,6 +65,7 @@ __all__ = [
     "FusedMoEPermuteExpertsUnpermute",
     "FusedMoEActivationFormat",
     "FusedMoEPrepareAndFinalize",
+    "GateLinear",
     "RoutingMethodType",
     "SharedFusedMoE",
     "ZeroExpertFusedMoE",
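
Not part of the diff: the register_fake registration above is what lets meta-tensor tracing (and hence torch.compile) infer the op's output shape and dtype without launching the CUDA kernel. A rough sketch of that contract follows; it assumes a CUDA build of vLLM so _moe_C is importable, and the FakeTensorMode usage is illustrative.

    # Illustrative only: the fake impl must mirror the real kernel's metadata.
    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    import vllm._custom_ops  # noqa: F401  (loads _moe_C ops and their fake impls)

    with FakeTensorMode():
        x = torch.empty(4, 7168, dtype=torch.bfloat16, device="cuda")
        w = torch.empty(256, 7168, dtype=torch.bfloat16, device="cuda")
        out = torch.ops._moe_C.router_gemm_bf16_fp32(x, w)
        # No kernel runs; shape/dtype come from router_gemm_bf16_fp32_fake.
        assert out.shape == (4, 256) and out.dtype == torch.float32
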
+ """ + + # Dimensions supported by the DSV3 specialized kernel + DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] + DSV3_SUPPORTED_HIDDEN_SIZES = [7168] + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = False, + out_dtype: torch.dtype | None = None, + params_dtype: torch.dtype | None = None, + force_fp32_compute: bool = False, + prefix: str = "", + ): + is_hopper_or_blackwell = current_platform.is_device_capability( + (9, 0) + ) or current_platform.is_device_capability_family(100) + can_use_specialized_kernels = ( + current_platform.is_cuda() and is_hopper_or_blackwell and not bias + ) + + # If fp32 compute is required and no specialized kernel is available, + # store weights in fp32 so Tier 3 computes in fp32 natively. + if force_fp32_compute and not can_use_specialized_kernels: + params_dtype = torch.float32 + + super().__init__( + input_size, + output_size, + bias=bias, + params_dtype=params_dtype, + quant_config=None, + prefix=prefix, + ) + self.out_dtype = out_dtype + + # DSV3 specialized kernel eligibility (SM90+, exact dims) + self.allow_specialized_router_gemm = can_use_specialized_kernels + self.allow_dsv3_router_gemm = ( + self.allow_specialized_router_gemm + and output_size in self.DSV3_SUPPORTED_NUM_EXPERTS + and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES + ) + + # cuBLAS bf16→fp32 eligibility + self.allow_cublas_router_gemm = ( + self.allow_specialized_router_gemm + and self.weight.dtype == torch.bfloat16 + and self.out_dtype == torch.float32 + ) + + def set_out_dtype(self, out_dtype: torch.dtype) -> None: + """Set output dtype for the router logits after init. + + Useful when the required dtype depends on the expert quantization + method which is only known after the gate is constructed. + """ + if self.out_dtype is not None: + raise ValueError("out_dtype has already been set") + self.out_dtype = out_dtype + + if ( + not self.allow_cublas_router_gemm + and self.allow_specialized_router_gemm + and out_dtype == torch.float32 + ): + self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16 + + def forward( + self, x: torch.Tensor + ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: + import vllm._custom_ops as ops + + # Tier 1: DSV3 specialized kernel + if self.allow_dsv3_router_gemm and x.shape[0] <= 16: + output = ops.dsv3_router_gemm( + hidden_states=x, + router_weight=self.weight, + output_dtype=self.out_dtype, + ) + return output, None + + # Tier 2: cuBLAS bf16→fp32 + if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16: + output = ops.router_gemm_bf16_fp32(x, self.weight) + return output, None + + # Tier 3: F.linear (ReplicatedLinear) + if self.out_dtype is not None and x.dtype != self.weight.dtype: + x = x.to(self.weight.dtype) + output, output_bias = super().forward(x) + if self.out_dtype is not None and output.dtype != self.out_dtype: + output = output.to(self.out_dtype) + return output, output_bias diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 768f4e20b..c3e1ddb7d 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -47,7 +47,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE from 
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 768f4e20b..c3e1ddb7d 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -47,7 +47,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE
 from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -221,73 +221,6 @@ class DeepseekV2MLP(nn.Module):
         return x


-class DeepSeekV2Gate(ReplicatedLinear):
-    def __init__(
-        self,
-        hidden_size: int,
-        n_experts: int,
-        quant_config: QuantizationConfig | None = None,
-        prefix: str = "",
-    ):
-        assert quant_config is None
-        super().__init__(
-            hidden_size,
-            n_experts,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate",
-        )
-
-        # Unquantized only, will be called "weight".
-        assert hasattr(self, "weight")
-        is_hopper_or_blackwell = current_platform.is_device_capability(
-            (9, 0)
-        ) or current_platform.is_device_capability_family(100)
-        SUPPORTED_NUM_EXPERTS = [256, 384]
-        SUPPORTED_HIDDEN_SIZES = [7168]
-
-        self.allow_dsv3_router_gemm = (
-            current_platform.is_cuda()
-            and is_hopper_or_blackwell
-            and n_experts in SUPPORTED_NUM_EXPERTS
-            and hidden_size in SUPPORTED_HIDDEN_SIZES
-        )
-
-        self._out_dtype: torch.dtype | None = None
-
-    def set_out_dtype(self, out_dtype: torch.dtype) -> None:
-        """
-        Set out dtype for the router logits. This is needed after
-        __init__, b/c we need to check if the trtllm kernel is
-        selected before we decide between bf16 and fp32.
-        """
-
-        if self._out_dtype is not None:
-            raise ValueError("out_dtype has already been set")
-        else:
-            self._out_dtype = out_dtype
-
-    @property
-    def out_dtype(self) -> torch.dtype:
-        if self._out_dtype is None:
-            raise ValueError("out_dtype has not been set yet")
-        return self._out_dtype
-
-    def forward(
-        self,
-        x: torch.Tensor,
-    ) -> tuple[torch.Tensor, None]:
-        """
-        Use specialized GEMM for low batch size for DSV3 and KIMI.
-        """
-        if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
-            return ops.dsv3_router_gemm(
-                hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype
-            ), None
-        else:
-            return super().forward(x)
-
-
 class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
@@ -316,10 +249,9 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )

-        self.gate = DeepSeekV2Gate(
+        self.gate = GateLinear(
             config.hidden_size,
             config.n_routed_experts,
-            quant_config=None,
             prefix=f"{prefix}.gate",
         )
         if getattr(config, "topk_method", None) == "noaux_tc":
diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index 446b01fe3..39ea0ea48 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -34,7 +34,7 @@ from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.fused_moe import (
-    FusedMoE,
+    GateLinear,
     SharedFusedMoE,
     activation_without_mul,
 )
@@ -148,13 +148,11 @@ class NemotronHMoE(nn.Module):

         self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

-        router_logits_dtype = torch.float32
-        self.gate = ReplicatedLinear(
+        self.gate = GateLinear(
             config.hidden_size,
             config.n_routed_experts,
-            bias=False,
-            params_dtype=router_logits_dtype,
-            quant_config=None,
+            out_dtype=torch.float32,
+            force_fp32_compute=True,
             prefix=f"{prefix}.gate",
         )

@@ -232,7 +230,6 @@ class NemotronHMoE(nn.Module):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
-            router_logits_dtype=router_logits_dtype,
             routed_input_transform=self.fc1_latent_proj,
         )

@@ -244,7 +241,7 @@ class NemotronHMoE(nn.Module):
             hidden_states = sequence_parallel_chunk(hidden_states)

         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
+        router_logits, _ = self.gate(hidden_states)

         # SharedFusedMoE handles:
         # - shared experts (with original hidden_states)
@@ -675,7 +672,7 @@ class NemotronHModel(nn.Module):
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         if self.has_moe:
             # (param_name, weight_name, expert_id, shard_id)
-            expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
                 # - FusedMoe.w1 (aka gate_proj) should be up_proj since that's
                 #   what the activation is applied to
                 # - FusedMoe.w3 (aka up_proj) should be ignored since we're
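
Not part of the diff: the NemotronH change above leans on GateLinear's force_fp32_compute path instead of keeping an fp32 ReplicatedLinear gate and casting activations at the call site. A rough sketch of the intended equivalence; the plain nn.Linear stand-in for the old gate and the sizes are assumptions.

    # Illustrative only: fp32 router logits before and after this diff.
    import torch
    from vllm.model_executor.layers.fused_moe import GateLinear

    x = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")

    # Old pattern: fp32 gate weights plus an explicit cast at the call site
    # (nn.Linear stands in for the fp32 ReplicatedLinear gate used before).
    old_gate = torch.nn.Linear(4096, 64, bias=False, dtype=torch.float32, device="cuda")
    old_logits = old_gate(x.to(torch.float32))

    # New pattern: GateLinear owns the dtype handling. With no specialized kernel
    # available it stores the weight in fp32 (force_fp32_compute=True) and upcasts
    # bf16 activations internally; on Hopper/Blackwell it can instead route to the
    # specialized kernels when their dtype/dim requirements are met.
    new_gate = GateLinear(
        4096, 64, out_dtype=torch.float32, force_fp32_compute=True
    ).cuda()
    new_logits, _ = new_gate(x)
    assert new_logits.dtype == old_logits.dtype == torch.float32
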