[Performance] Cublas Bf16 Gate with Fp32 Output (#35121)
Signed-off-by: Roi Koren <roik@nvidia.com>
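This change swaps the model-local router gates (the DeepSeekV2Gate class removed below and the plain ReplicatedLinear gate in NemotronHMoE) for the shared GateLinear layer exported from vllm.model_executor.layers.fused_moe; per the title, the gate GEMM stays in bf16 (cuBLAS) while the router logits come out in fp32. A minimal usage sketch, assuming GateLinear can be built standalone; the sizes and prefix are illustrative, and the keyword arguments are the ones the call sites in this diff pass:

    import torch
    from vllm.model_executor.layers.fused_moe import GateLinear

    # Router gate with a bf16 weight and fp32 logits. 7168 / 256 are the
    # hidden size and expert count the removed DeepSeekV2Gate listed as
    # supported; the prefix is made up for illustration.
    gate = GateLinear(
        7168,                       # hidden_size
        256,                        # n_routed_experts
        quant_config=None,
        out_dtype=torch.float32,    # fp32 router logits
        force_fp32_compute=True,    # as passed by NemotronHMoE below
        prefix="model.layers.0.mlp.gate",
    )

    hidden_states = torch.randn(8, 7168, dtype=torch.bfloat16, device="cuda")
    router_logits, _ = gate(hidden_states)  # fp32 logits, no manual upcast needed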
@@ -47,7 +47,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE
 from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -221,73 +221,6 @@ class DeepseekV2MLP(nn.Module):
         return x


-class DeepSeekV2Gate(ReplicatedLinear):
-    def __init__(
-        self,
-        hidden_size: int,
-        n_experts: int,
-        quant_config: QuantizationConfig | None = None,
-        prefix: str = "",
-    ):
-        assert quant_config is None
-        super().__init__(
-            hidden_size,
-            n_experts,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate",
-        )
-
-        # Unquantized only, will be called "weight".
-        assert hasattr(self, "weight")
-        is_hopper_or_blackwell = current_platform.is_device_capability(
-            (9, 0)
-        ) or current_platform.is_device_capability_family(100)
-        SUPPORTED_NUM_EXPERTS = [256, 384]
-        SUPPORTED_HIDDEN_SIZES = [7168]
-
-        self.allow_dsv3_router_gemm = (
-            current_platform.is_cuda()
-            and is_hopper_or_blackwell
-            and n_experts in SUPPORTED_NUM_EXPERTS
-            and hidden_size in SUPPORTED_HIDDEN_SIZES
-        )
-
-        self._out_dtype: torch.dtype | None = None
-
-    def set_out_dtype(self, out_dtype: torch.dtype) -> None:
-        """
-        Set out dtype for the router logits. This is needed after
-        __init__, b/c we need to check if the trtllm kernel is
-        selected before we decide between bf16 and fp32.
-        """
-
-        if self._out_dtype is not None:
-            raise ValueError("out_dtype has already been set")
-        else:
-            self._out_dtype = out_dtype
-
-    @property
-    def out_dtype(self) -> torch.dtype:
-        if self._out_dtype is None:
-            raise ValueError("out_dtype has not been set yet")
-        return self._out_dtype
-
-    def forward(
-        self,
-        x: torch.Tensor,
-    ) -> tuple[torch.Tensor, None]:
-        """
-        Use specialized GEMM for low batch size for DSV3 and KIMI.
-        """
-        if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
-            return ops.dsv3_router_gemm(
-                hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype
-            ), None
-        else:
-            return super().forward(x)
-
-
 class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
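The class removed above did two things that presumably move into GateLinear: it deferred the choice of output dtype until the MoE backend was known (set_out_dtype, since the trtllm kernel selection decides between bf16 and fp32), and it sent small batches (at most 16 tokens) on Hopper/Blackwell through ops.dsv3_router_gemm with that output dtype. A rough pure-PyTorch stand-in for the dispatch, using illustrative names (route is not a vLLM API):

    import torch
    import torch.nn.functional as F

    def route(x: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
              allow_dsv3_router_gemm: bool) -> torch.Tensor:
        # Stand-in for DeepSeekV2Gate.forward: a specialized low-batch GEMM with
        # an explicit output dtype, otherwise a plain linear in the input dtype.
        if allow_dsv3_router_gemm and x.shape[0] <= 16:
            return F.linear(x.float(), weight.float()).to(out_dtype)
        return F.linear(x, weight)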
@@ -316,10 +249,9 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )

-        self.gate = DeepSeekV2Gate(
+        self.gate = GateLinear(
             config.hidden_size,
             config.n_routed_experts,
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
         if getattr(config, "topk_method", None) == "noaux_tc":
@@ -34,7 +34,7 @@ from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.fused_moe import (
-    FusedMoE,
+    GateLinear,
     SharedFusedMoE,
     activation_without_mul,
 )
@@ -148,13 +148,11 @@ class NemotronHMoE(nn.Module):

         self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

-        router_logits_dtype = torch.float32
-        self.gate = ReplicatedLinear(
+        self.gate = GateLinear(
             config.hidden_size,
             config.n_routed_experts,
-            bias=False,
-            params_dtype=router_logits_dtype,
             quant_config=None,
+            out_dtype=torch.float32,
+            force_fp32_compute=True,
             prefix=f"{prefix}.gate",
         )
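Previously the NemotronH gate stored its weight in fp32 (params_dtype=router_logits_dtype) and the caller upcast the hidden states before the matmul (see the forward hunk further down); with GateLinear the weight presumably stays in the model's bf16 dtype, and out_dtype / force_fp32_compute request fp32 accumulation and fp32 logits. A hedged reference for roughly what the new gate should compute (illustrative only, not the actual kernel):

    import torch

    def reference_router_logits(
        hidden_states: torch.Tensor,  # bf16 activations
        gate_weight: torch.Tensor,    # bf16 gate weight
    ) -> torch.Tensor:
        # fp32 compute, fp32 logits
        return torch.nn.functional.linear(hidden_states.float(), gate_weight.float())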
@@ -232,7 +230,6 @@ class NemotronHMoE(nn.Module):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts,
             is_sequence_parallel=self.is_sequence_parallel,
-            router_logits_dtype=router_logits_dtype,
             routed_input_transform=self.fc1_latent_proj,
         )

@@ -244,7 +241,7 @@ class NemotronHMoE(nn.Module):
             hidden_states = sequence_parallel_chunk(hidden_states)

         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
+        router_logits, _ = self.gate(hidden_states)

         # SharedFusedMoE handles:
         # - shared experts (with original hidden_states)
@@ -675,7 +672,7 @@ class NemotronHModel(nn.Module):
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         if self.has_moe:
             # (param_name, weight_name, expert_id, shard_id)
-            expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
                 # - FusedMoe.w1 (aka gate_proj) should be up_proj since that's
                 #   what the activation is applied to
                 # - FusedMoe.w3 (aka up_proj) should be ignored since we're