[NemotronH] Use ReplicatedLinear for fc1_latent_proj (#31807)

Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
roikoren755
2026-01-06 18:00:40 +02:00
committed by GitHub
parent af8fd73051
commit 28c94770ad

View File

@@ -210,16 +210,12 @@ class NemotronHMoE(nn.Module):
         )
 
         if self.use_latent_moe:
-            # TODO: check if using ReplicatedLinear is better than
-            # ColumnParallelLinear + all_gather
-            self.fc1_latent_proj = ColumnParallelLinear(
+            self.fc1_latent_proj = ReplicatedLinear(
                 input_size=config.hidden_size,
                 output_size=self.moe_hidden_size,
                 bias=config.mlp_bias,
                 quant_config=quant_config,
                 disable_tp=self.is_sequence_parallel,
-                # We need to gather the output to prepare input for moe
-                gather_output=True,
                 prefix=f"{prefix}.fc1_latent_proj",
             )
             self.fc2_latent_proj = ReplicatedLinear(