[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)

Signed-off-by: MengqingCao <cmq0113@163.com>
Author: Mengqing Cao
Date: 2025-11-07 19:38:38 +08:00
Committed by: GitHub
Parent: 7bdb42b2f2
Commit: 1958bda9b4
26 changed files with 190 additions and 25 deletions
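
For context: vLLM's parallel Linear layers accept a `prefix` naming their position in the model, and quantization configs can match on that dotted name to make per-layer decisions (for example, leaving selected layers unquantized). Below is a minimal sketch of the idea, using hypothetical ToyQuantConfig/ToyLinear stand-ins rather than vLLM's real classes:

# Minimal sketch of why a Linear layer wants its dotted prefix.
# ToyQuantConfig and ToyLinear are hypothetical stand-ins, not vLLM's
# actual API; only the prefix-matching idea is the point.
class ToyQuantConfig:
    def __init__(self, ignored_prefixes: list[str]):
        self.ignored_prefixes = ignored_prefixes

    def is_layer_skipped(self, prefix: str) -> bool:
        # Skip quantization for any layer whose fully qualified name
        # starts with an ignored prefix.
        return any(prefix.startswith(p) for p in self.ignored_prefixes)

class ToyLinear:
    def __init__(self, quant_config: ToyQuantConfig | None = None,
                 prefix: str = ""):
        # Without a real prefix (the pre-patch behavior), every layer
        # looks identical to the config and per-layer overrides misfire.
        self.quantized = (quant_config is not None
                          and not quant_config.is_layer_skipped(prefix))

cfg = ToyQuantConfig(ignored_prefixes=["transformer.h.0.mlp"])
skipped = ToyLinear(cfg, prefix="transformer.h.0.mlp.dense_h_to_4h")
kept = ToyLinear(cfg, prefix="transformer.h.0.self_attention.dense")
assert not skipped.quantized and kept.quantized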

vllm/model_executor/models/bloom.py

@@ -108,12 +108,14 @@ class BloomAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.query_key_value",
         )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense",
         )
 
         # Create the alibi slopes and slice them.
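
These names compose with whatever prefix BloomAttention itself receives. Assuming the enclosing block passes f"{prefix}.self_attention" and the model root follows HF BLOOM's checkpoint layout, the layers would report names like the following (hypothetical illustration; the parent prefix depends on how BloomBlock and BloomModel are wired):

prefix = "transformer.h.0.self_attention"
print(f"{prefix}.query_key_value")  # transformer.h.0.self_attention.query_key_value
print(f"{prefix}.dense")            # transformer.h.0.self_attention.dense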
@@ -152,6 +154,7 @@ class BloomMLP(nn.Module):
         self,
         config: BloomConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -159,12 +162,14 @@ class BloomMLP(nn.Module):
             hidden_size,
             4 * hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_h_to_4h",
         )
         self.gelu_impl = get_act_fn("gelu")
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.dense_4h_to_h",
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
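
One subtlety: `prefix` defaults to "", so existing callers that omit it still construct, but the child names then come out with a leading dot and cannot match a real checkpoint name. A quick check of that behavior with a hypothetical helper mirroring the f-strings above, which is why the BloomBlock hunk below passes the prefix explicitly:

def child_name(prefix: str, name: str) -> str:
    return f"{prefix}.{name}"

assert child_name("", "dense_h_to_4h") == ".dense_h_to_4h"
assert (child_name("transformer.h.3.mlp", "dense_h_to_4h")
        == "transformer.h.3.mlp.dense_h_to_4h")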
@@ -192,7 +197,7 @@ class BloomBlock(nn.Module):
         self.post_attention_layernorm = nn.LayerNorm(
             hidden_size, eps=config.layer_norm_epsilon
         )
-        self.mlp = BloomMLP(config, quant_config)
+        self.mlp = BloomMLP(config, quant_config, prefix=f"{prefix}.mlp")
         self.apply_residual_connection_post_layernorm = (
             config.apply_residual_connection_post_layernorm
         )
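
Taken together, the prefix now threads from the block down to each Linear layer. A toy end-to-end sketch (stand-in classes; only the naming pattern mirrors the diff):

class ToyMLP:
    def __init__(self, prefix: str = ""):
        # Mirrors BloomMLP's new child names.
        self.dense_h_to_4h = f"{prefix}.dense_h_to_4h"
        self.dense_4h_to_h = f"{prefix}.dense_4h_to_h"

class ToyBlock:
    def __init__(self, prefix: str = ""):
        # Mirrors the new call site: BloomMLP(..., prefix=f"{prefix}.mlp")
        self.mlp = ToyMLP(prefix=f"{prefix}.mlp")

block = ToyBlock(prefix="transformer.h.0")
assert block.mlp.dense_h_to_4h == "transformer.h.0.mlp.dense_h_to_4h"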