[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)
Signed-off-by: MengqingCao <cmq0113@163.com>
@@ -99,11 +99,13 @@ class PhiAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )

         scaling = self.head_size**-0.5
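The effect of this hunk: each parallel linear layer now receives its fully qualified, dotted module path. Below is a minimal Python sketch (toy classes, not vLLM's actual QKVParallelLinear/RowParallelLinear) of how the prefix composes as parent modules construct their children, assuming each parent appends the child's attribute name before passing the prefix down:

# A minimal sketch of prefix threading; ToyLinear and ToyAttention are
# invented stand-ins, not vLLM classes.
class ToyLinear:
    def __init__(self, prefix: str = ""):
        # The fully qualified name, e.g. "model.layers.0.self_attn.qkv_proj",
        # lines up with the parameter names in the checkpoint's state dict.
        self.prefix = prefix

class ToyAttention:
    def __init__(self, prefix: str = ""):
        # Mirrors the diff: the parent appends ".qkv_proj" / ".dense".
        self.qkv_proj = ToyLinear(prefix=f"{prefix}.qkv_proj")
        self.dense = ToyLinear(prefix=f"{prefix}.dense")

attn = ToyAttention(prefix="model.layers.0.self_attn")
print(attn.qkv_proj.prefix)  # model.layers.0.self_attn.qkv_proj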
@@ -148,7 +150,10 @@ class PhiAttention(nn.Module):


 class PhiMLP(nn.Module):
     def __init__(
-        self, config: PhiConfig, quant_config: QuantizationConfig | None = None
+        self,
+        config: PhiConfig,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
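A note on the widened PhiMLP.__init__ signature: prefix defaults to "", so any call site that does not pass a prefix stays source-compatible. With an empty prefix the f-strings in the next hunk would yield names like ".fc1", so in practice callers are expected to thread a real prefix, as the PhiLayer hunk further down does.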
@@ -159,11 +164,13 @@ class PhiMLP(nn.Module):
             config.hidden_size,
             n_inner,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
         )
         self.fc2 = RowParallelLinear(
             n_inner,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
        )
         self.act = get_act_fn(config.hidden_act)
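Why a plain linear layer wants to know its own name at all: a quantization config can key per-layer decisions (quantize, skip, pick a scheme) on the fully qualified name; in vLLM the linear layers hand the prefix to the quantization config when resolving the quant method. A hypothetical sketch of that dispatch follows; ToyQuantConfig and get_method_for are invented names, not vLLM's API:

# Hypothetical per-layer quantization dispatch keyed on the layer's
# fully qualified name.
class ToyQuantConfig:
    def __init__(self, ignored_layers: list[str]):
        self.ignored_layers = ignored_layers

    def get_method_for(self, prefix: str) -> str:
        # Layers on the ignore list are left in full precision.
        if any(prefix == n or prefix.endswith("." + n) for n in self.ignored_layers):
            return "unquantized"
        return "int8"

cfg = ToyQuantConfig(ignored_layers=["fc2"])
print(cfg.get_method_for("model.layers.0.mlp.fc1"))  # int8
print(cfg.get_method_for("model.layers.0.mlp.fc2"))  # unquantized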
@@ -189,7 +196,7 @@ class PhiLayer(nn.Module):
         self.self_attn = PhiAttention(
             config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
         )
-        self.mlp = PhiMLP(config, quant_config)
+        self.mlp = PhiMLP(config, quant_config, prefix=f"{prefix}.mlp")

     def forward(
         self,
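Putting the hunks together: PhiLayer threads f"{prefix}.self_attn" and f"{prefix}.mlp" into its children, which in turn append .qkv_proj/.dense and .fc1/.fc2. Assuming the model stacks layers with per-index prefixes such as f"model.layers.{i}" (the usual convention, not shown in this diff), the resulting names line up with checkpoint keys like model.layers.0.mlp.fc1 and model.layers.0.self_attn.qkv_proj.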