[NemotronH] Use ReplicatedLinear for fc1_latent_proj (#31807)

Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
roikoren755
2026-01-06 18:00:40 +02:00
committed by GitHub
parent af8fd73051
commit 28c94770ad

View File

@@ -210,16 +210,12 @@ class NemotronHMoE(nn.Module):
)
if self.use_latent_moe:
# TODO: check if using ReplicatedLinear is better than
# ColumnParallelLinear + all_gather
self.fc1_latent_proj = ColumnParallelLinear(
self.fc1_latent_proj = ReplicatedLinear(
input_size=config.hidden_size,
output_size=self.moe_hidden_size,
bias=config.mlp_bias,
quant_config=quant_config,
disable_tp=self.is_sequence_parallel,
# We need to gather the output to prepare input for moe
gather_output=True,
prefix=f"{prefix}.fc1_latent_proj",
)
self.fc2_latent_proj = ReplicatedLinear(