[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-11-07 19:38:38 +08:00
committed by GitHub
parent 7bdb42b2f2
commit 1958bda9b4
26 changed files with 190 additions and 25 deletions

View File

@@ -98,13 +98,22 @@ class BaiChuanMLP(nn.Module):
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(
@@ -152,12 +161,14 @@ class BaiChuanAttention(nn.Module):
self.total_num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.W_pack",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
# Create the alibi slopes and slice them.
if self.position_embedding == "ALIBI":
@@ -235,6 +246,7 @@ class BaiChuanDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(