[NemotronH] Do not force router to run in fp32 (#34582)

Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
roikoren755
2026-02-16 20:15:32 +02:00
committed by GitHub
parent 824f9e8f3c
commit 3b30e61507
2 changed files with 5 additions and 4 deletions

View File

@@ -309,6 +309,10 @@ def fi_trtllm_fp8_per_tensor_moe(
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe
# The DeepSeekV3 routing method requires float32 router logits.
if routing_method_type == RoutingMethodType.DeepSeekV3:
routing_logits = routing_logits.to(torch.float32)
return flashinfer_trtllm_fp8_per_tensor_scale_moe(
routing_logits=routing_logits,
routing_bias=routing_bias,

View File

@@ -148,12 +148,10 @@ class NemotronHMoE(nn.Module):
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
router_logits_dtype = torch.float32
self.gate = ReplicatedLinear(
config.hidden_size,
config.n_routed_experts,
bias=False,
params_dtype=router_logits_dtype,
quant_config=None,
prefix=f"{prefix}.gate",
)
@@ -232,7 +230,6 @@ class NemotronHMoE(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
router_logits_dtype=router_logits_dtype,
routed_input_transform=self.fc1_latent_proj,
)
@@ -244,7 +241,7 @@ class NemotronHMoE(nn.Module):
hidden_states = sequence_parallel_chunk(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
router_logits, _ = self.gate(hidden_states)
# SharedFusedMoE handles:
# - shared experts (with original hidden_states)