[Misc][Model][Refactor] Pass the prefix into Linear layers (#31669)

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Author: Wang Kunpeng
Date: 2026-01-06 04:03:18 +08:00
Committed by: GitHub
Parent: 02dbb933cb
Commit: 5708297e4e
17 changed files with 181 additions and 40 deletions
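Context for reviewers (not part of the diff): vLLM's Linear layers accept a dotted `prefix` naming the module's position in the checkpoint, e.g. `transformer.h.0.mlp.c_proj`, and hand it to the quantization config's `get_quant_method(layer, prefix)` hook so the method can be chosen per layer. Before this change, QWen's Linear layers always received an empty prefix, so name-based per-layer rules could never match. A minimal sketch of such a rule; `SketchQuantConfig` and its ignore list are hypothetical, and the `UnquantizedLinearMethod` import path is as of recent vLLM and may differ by version:

    from vllm.model_executor.layers.linear import UnquantizedLinearMethod

    class SketchQuantConfig:  # stand-in for a QuantizationConfig subclass
        def __init__(self, ignored_layers: list[str]):
            # Fully qualified layer names to leave unquantized,
            # matched against the prefix passed down by the model.
            self.ignored_layers = ignored_layers

        def get_quant_method(self, layer, prefix: str):
            # With prefix="" (the old behavior) this check never fires.
            if prefix in self.ignored_layers:
                return UnquantizedLinearMethod()
            return None  # fall back to the default method in this sketch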


@@ -56,13 +56,22 @@ class QWenMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str = "silu",
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.c_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
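With the hunk above, a caller that names the MLP gives both projections fully qualified names. A hedged usage sketch (assumes a vLLM environment with the distributed/tensor-parallel state already initialized; the concrete sizes are made up):

    mlp = QWenMLP(
        hidden_size=2048,
        intermediate_size=5504,
        quant_config=None,
        prefix="transformer.h.0.mlp",
    )
    # The sub-layers are now constructed as
    #   transformer.h.0.mlp.gate_up_proj  and  transformer.h.0.mlp.c_proj,
    # matching the keys in the checkpoint's state dict.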
@@ -163,7 +172,10 @@ class QWenBlock(nn.Module):
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = QWenMLP(
-            config.hidden_size, config.intermediate_size // 2, quant_config=quant_config
+            config.hidden_size,
+            config.intermediate_size // 2,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
 
     def forward(
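The same pattern repeats at every level: each module appends its own attribute name before handing the prefix down, so the leaf Linear layers end up holding their full checkpoint key. A self-contained toy illustration of that threading (plain Python, no vLLM types; all names are hypothetical):

    class ToyMLP:
        def __init__(self, prefix: str = ""):
            # Mirrors the first hunk: the leaf records its qualified name.
            self.c_proj_name = f"{prefix}.c_proj"

    class ToyBlock:
        def __init__(self, prefix: str = ""):
            # Mirrors QWenBlock: append ".mlp" before passing the prefix down.
            self.mlp = ToyMLP(prefix=f"{prefix}.mlp")

    block = ToyBlock(prefix="transformer.h.0")
    print(block.mlp.c_proj_name)  # -> transformer.h.0.mlp.c_proj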