[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)

Signed-off-by: MengqingCao <cmq0113@163.com>
Author: Mengqing Cao
Date: 2025-11-07 19:38:38 +08:00
Committed by: GitHub
Parent: 7bdb42b2f2
Commit: 1958bda9b4
26 changed files with 190 additions and 25 deletions
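
The change is mechanical but load-bearing: every parent module now extends its own `prefix` when constructing a child Linear layer, so each leaf ends up knowing its fully qualified name (matching the checkpoint's state_dict keys), which quantization configs can use for per-layer decisions. A minimal sketch of the pattern, using illustrative `Leaf`/`Block` classes rather than the actual vLLM ones:

    import torch.nn as nn

    class Leaf(nn.Module):
        # Stand-in for QKVParallelLinear / RowParallelLinear: the leaf
        # records its fully qualified name so a quantization config can
        # later be consulted per layer.
        def __init__(self, prefix: str = ""):
            super().__init__()
            self.prefix = prefix

    class Block(nn.Module):
        # Each parent extends its own prefix when building children,
        # mirroring the `prefix=f"{prefix}.up_proj"` lines in this diff.
        def __init__(self, prefix: str = ""):
            super().__init__()
            self.up_proj = Leaf(prefix=f"{prefix}.up_proj")
            self.down_proj = Leaf(prefix=f"{prefix}.down_proj")

    block = Block(prefix="transformer.blocks.0.ffn")
    print(block.up_proj.prefix)    # transformer.blocks.0.ffn.up_proj
    print(block.down_proj.prefix)  # transformer.blocks.0.ffn.down_proj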


@@ -83,6 +83,7 @@ class MPTAttention(nn.Module):
             self.total_num_kv_heads,
             bias=not config.no_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.Wqkv",
         )
         if self.qk_ln:
             self.q_ln = nn.LayerNorm(self.d_model)
@@ -92,6 +93,7 @@
             self.d_model,
             bias=not config.no_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )

         tp_world_size = get_tensor_model_parallel_world_size()
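
Why the Linear layers want a prefix at all: vLLM resolves a quantization method per layer via `QuantizationConfig.get_quant_method(layer, prefix)`, and configs for partially quantized checkpoints decide based on the layer's name. A toy version of that lookup (the `ignored_layers` matching below is illustrative, not vLLM's implementation):

    class ToyQuantConfig:
        # Toy per-layer lookup keyed on the layer's fully qualified name.
        # `ignored_layers` and the prefix-matching rule are illustrative,
        # not vLLM's actual implementation.
        def __init__(self, ignored_layers: list[str]):
            self.ignored_layers = ignored_layers

        def get_quant_method(self, layer, prefix: str) -> str:
            # Layers whose name is listed stay unquantized; the rest
            # get the quantized kernel.
            if any(prefix.startswith(p) for p in self.ignored_layers):
                return "unquantized"
            return "fp8"

    cfg = ToyQuantConfig(ignored_layers=["transformer.blocks.0.attn.out_proj"])
    print(cfg.get_quant_method(None, "transformer.blocks.0.attn.out_proj"))  # unquantized
    print(cfg.get_quant_method(None, "transformer.blocks.1.attn.out_proj"))  # fp8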
@@ -152,6 +154,7 @@ class MPTMLP(nn.Module):
         self,
         config: MptConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.d_model
@@ -162,6 +165,7 @@
             intermediate_size,
             bias=not config.no_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
         )
         self.act = get_act_fn("gelu")
         self.down_proj = RowParallelLinear(
@@ -169,6 +173,7 @@
             hidden_size,
             bias=not config.no_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -193,7 +198,7 @@ class MPTBlock(nn.Module):
             config, cache_config, quant_config, prefix=f"{prefix}.attn"
         )
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MPTMLP(config, quant_config)
+        self.ffn = MPTMLP(config, quant_config, prefix=f"{prefix}.ffn")

     def forward(
         self,
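
With the `MPTBlock` hunk above, prefixes compose all the way down. Assuming the blocks live under a `transformer.blocks.{i}` prefix (the HF MPT checkpoint layout; not shown in these hunks), the leaf layers now receive names like the following; a quick sanity check:

    # Hypothetical reconstruction of the fully qualified names the MPT
    # leaf layers receive after this commit, assuming a
    # "transformer.blocks.{i}" parent prefix.
    for i in range(2):
        for leaf in ("attn.Wqkv", "attn.out_proj", "ffn.up_proj", "ffn.down_proj"):
            print(f"transformer.blocks.{i}.{leaf}")
    # transformer.blocks.0.attn.Wqkv
    # transformer.blocks.0.attn.out_proj
    # transformer.blocks.0.ffn.up_proj
    # transformer.blocks.0.ffn.down_proj
    # ... and the same for block 1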