[NemotronH] Use ReplicatedLinear for fc1_latent_proj (#31807)
Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
@@ -210,16 +210,12 @@ class NemotronHMoE(nn.Module):
|
||||
)
|
||||
|
||||
if self.use_latent_moe:
|
||||
# TODO: check if using ReplicatedLinear is better than
|
||||
# ColumnParallelLinear + all_gather
|
||||
self.fc1_latent_proj = ColumnParallelLinear(
|
||||
self.fc1_latent_proj = ReplicatedLinear(
|
||||
input_size=config.hidden_size,
|
||||
output_size=self.moe_hidden_size,
|
||||
bias=config.mlp_bias,
|
||||
quant_config=quant_config,
|
||||
disable_tp=self.is_sequence_parallel,
|
||||
# We need to gather the output to prepare input for moe
|
||||
gather_output=True,
|
||||
prefix=f"{prefix}.fc1_latent_proj",
|
||||
)
|
||||
self.fc2_latent_proj = ReplicatedLinear(
|
||||
|
||||
Reference in New Issue
Block a user