[NemotronH] Use ReplicatedLinear for fc1_latent_proj (#31807)
Signed-off-by: Roi Koren <roik@nvidia.com>
This commit is contained in:
@@ -210,16 +210,12 @@ class NemotronHMoE(nn.Module):
         )
 
         if self.use_latent_moe:
-            # TODO: check if using ReplicatedLinear is better than
-            # ColumnParallelLinear + all_gather
-            self.fc1_latent_proj = ColumnParallelLinear(
+            self.fc1_latent_proj = ReplicatedLinear(
                 input_size=config.hidden_size,
                 output_size=self.moe_hidden_size,
                 bias=config.mlp_bias,
                 quant_config=quant_config,
                 disable_tp=self.is_sequence_parallel,
-                # We need to gather the output to prepare input for moe
-                gather_output=True,
                 prefix=f"{prefix}.fc1_latent_proj",
             )
             self.fc2_latent_proj = ReplicatedLinear(
Reference in New Issue
Block a user