[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)

commit 1958bda9b4 (parent 7bdb42b2f2)
Author: Mengqing Cao
Date: 2025-11-07 19:38:38 +08:00
Committed by: GitHub
Signed-off-by: MengqingCao <cmq0113@163.com>

26 changed files with 190 additions and 25 deletions

vllm/model_executor/models/jais.py

@@ -107,12 +107,14 @@ class JAISAttention(nn.Module):
             total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_attn",
         )
         self.c_proj = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
 
         tp_rank = get_tensor_model_parallel_rank()
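The prefix gives each parallel Linear layer its fully qualified module name (for example, transformer.h.0.attn.c_attn), which prefix-aware quantization configs can use to resolve per-layer settings. Below is a minimal, hypothetical sketch of that lookup pattern; PerLayerQuantConfig and skip_modules are illustrative names, not vLLM APIs.

# Hypothetical illustration of prefix-based per-layer quantization lookup.
class PerLayerQuantConfig:
    def __init__(self, skip_modules: list[str]):
        # Module-name fragments whose layers should stay unquantized.
        self.skip_modules = skip_modules

    def is_layer_skipped(self, prefix: str) -> bool:
        # Match against the layer's fully qualified dotted name.
        return any(m in prefix for m in self.skip_modules)

cfg = PerLayerQuantConfig(skip_modules=["c_proj"])
assert cfg.is_layer_skipped("transformer.h.0.attn.c_proj")
assert not cfg.is_layer_skipped("transformer.h.0.attn.c_attn")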
@@ -147,6 +149,7 @@ class JAISMLP(nn.Module):
         intermediate_size: int,
         config: JAISConfig,
         quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -156,6 +159,7 @@ class JAISMLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
         )
         self.c_fc2 = (
             ColumnParallelLinear(
@@ -163,6 +167,7 @@ class JAISMLP(nn.Module):
                 intermediate_size,
                 bias=True,
                 quant_config=quant_config,
+                prefix=f"{prefix}.c_fc2",
             )
             if self.swiglu
             else None
@@ -172,6 +177,7 @@ class JAISMLP(nn.Module):
             hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.act = SwiGLUActivation()
 
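The JAISMLP hunks above all apply the same threading pattern: a parent module extends its own prefix with the attribute name before constructing each child, so leaf layers end up with full dotted names. A minimal, framework-free sketch of the pattern (Linear, MLP, and Block here are stand-ins, not the vLLM classes):

# Each constructor appends its attribute name to the prefix it receives.
class Linear:
    def __init__(self, prefix: str = ""):
        self.prefix = prefix

class MLP:
    def __init__(self, prefix: str = ""):
        self.c_fc = Linear(prefix=f"{prefix}.c_fc")
        self.c_proj = Linear(prefix=f"{prefix}.c_proj")

class Block:
    def __init__(self, prefix: str = ""):
        self.mlp = MLP(prefix=f"{prefix}.mlp")

block = Block(prefix="transformer.h.0")
assert block.mlp.c_fc.prefix == "transformer.h.0.mlp.c_fc"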
@@ -206,7 +212,7 @@ class JAISBlock(nn.Module):
             config, cache_config, quant_config, prefix=f"{prefix}.attn"
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = JAISMLP(inner_dim, config, quant_config)
+        self.mlp = JAISMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
 
     def forward(
         self,
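With prefix=f"{prefix}.mlp" threaded through JAISBlock, every Linear layer in the JAIS model now receives a name that lines up with the checkpoint's parameter keys (for example, transformer.h.0.mlp.c_fc.weight for the layer prefixed transformer.h.0.mlp.c_fc), which is presumably what enables per-layer decisions in prefix-aware quantization methods.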