[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)

Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2025-11-07 19:38:38 +08:00
committed by GitHub
parent 7bdb42b2f2
commit 1958bda9b4
26 changed files with 190 additions and 25 deletions

View File

@@ -99,11 +99,13 @@ class PhiAttention(nn.Module):
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
scaling = self.head_size**-0.5
@@ -148,7 +150,10 @@ class PhiAttention(nn.Module):
class PhiMLP(nn.Module):
def __init__(
self, config: PhiConfig, quant_config: QuantizationConfig | None = None
self,
config: PhiConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
@@ -159,11 +164,13 @@ class PhiMLP(nn.Module):
config.hidden_size,
n_inner,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
n_inner,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
self.act = get_act_fn(config.hidden_act)
@@ -189,7 +196,7 @@ class PhiLayer(nn.Module):
self.self_attn = PhiAttention(
config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
)
self.mlp = PhiMLP(config, quant_config)
self.mlp = PhiMLP(config, quant_config, prefix=f"{prefix}.mlp")
def forward(
self,