[Misc][Model][Refactor] Pass the prefix into Linear layers (#31669)
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
@@ -56,13 +56,22 @@ class QWenMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str = "silu",
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.c_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -163,7 +172,10 @@ class QWenBlock(nn.Module):
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
         self.mlp = QWenMLP(
-            config.hidden_size, config.intermediate_size // 2, quant_config=quant_config
+            config.hidden_size,
+            config.intermediate_size // 2,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
 
     def forward(
Reference in New Issue
Block a user