[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)

Signed-off-by: MengqingCao <cmq0113@163.com>
Mengqing Cao
2025-11-07 19:38:38 +08:00
committed by GitHub
parent 7bdb42b2f2
commit 1958bda9b4
26 changed files with 190 additions and 25 deletions

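For context on the change: in vLLM, each Linear layer receives a dotted prefix (its fully qualified module name) so that, among other things, quantization configs can make per-layer decisions such as excluding specific projections from quantization. Below is a minimal, self-contained sketch of that pattern; QuantConfig and the toy ColumnParallelLinear here are illustrative stand-ins, not vLLM's real classes.

    from __future__ import annotations


    class QuantConfig:
        """Toy stand-in for a quantization config with per-layer exclusions."""

        def __init__(self, ignored_layers: set[str]) -> None:
            self.ignored_layers = ignored_layers

        def is_quantized(self, prefix: str) -> bool:
            # Decide per layer from the dotted module name,
            # e.g. "model.layers.0.self_attn.o_proj".
            return prefix not in self.ignored_layers


    class ColumnParallelLinear:
        """Toy layer: keeps its prefix so quantization can be chosen per layer."""

        def __init__(
            self,
            input_dim: int,
            output_dim: int,
            *,
            quant_config: QuantConfig | None = None,
            prefix: str = "",
        ) -> None:
            self.prefix = prefix
            # Without a prefix, the layer cannot look itself up in the config.
            self.quantized = quant_config.is_quantized(prefix) if quant_config else False


    cfg = QuantConfig(ignored_layers={"model.layers.0.self_attn.o_proj"})
    proj = ColumnParallelLinear(
        1024, 1024, quant_config=cfg, prefix="model.layers.0.self_attn.o_proj"
    )
    assert proj.quantized is False  # excluded by fully qualified name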

@@ -75,7 +75,12 @@ class Zamba2LoRA(nn.Module):
         super().__init__()
         self.A = ColumnParallelLinear(
-            input_dim, rank, bias=False, quant_config=quant_config, gather_output=True
+            input_dim,
+            rank,
+            bias=False,
+            quant_config=quant_config,
+            gather_output=True,
+            prefix=f"{prefix}.A",
         )
         if isinstance(output_dim, list):
@@ -150,12 +155,14 @@ class Zamba2Attention(nn.Module):
             self.total_num_attention_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.attention_hidden_size,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         # Even though in Zamba2 weights are shared between attention layers, KV
@@ -197,18 +204,21 @@ class Zamba2Attention(nn.Module):
                 config.adapter_rank,
                 self.attention_hidden_size,
                 quant_config=quant_config,
+                prefix=f"{prefix}.linear_q_adapter",
             )
             linear_k_adapter = Zamba2LoRA(
                 self.attention_hidden_size,
                 config.adapter_rank,
                 self.attention_hidden_size,
                 quant_config=quant_config,
+                prefix=f"{prefix}.linear_k_adapter",
             )
             linear_v_adapter = Zamba2LoRA(
                 self.attention_hidden_size,
                 config.adapter_rank,
                 self.attention_hidden_size,
                 quant_config=quant_config,
+                prefix=f"{prefix}.linear_v_adapter",
             )
         else:
             linear_q_adapter = nn.Identity()
@@ -312,6 +322,7 @@ class Zamba2MLP(nn.Module):
             2 * [self.intermediate_size],  # 2x for gate and input projections
             bias=self.config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.down_proj = RowParallelLinear(
@@ -319,6 +330,7 @@ class Zamba2MLP(nn.Module):
             self.hidden_size,
             bias=self.config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         # Only allow GELU activations
@@ -418,6 +430,7 @@ class Zamba2AttentionDecoderLayer(nn.Module):
             bare_block_idx=bare_block_idx,
             num_hybrid_layers=num_hybrid_layers,
             quant_config=quant_config,
+            prefix=f"{prefix}.feed_forward",
         )
         # Initialize layer normalizations
@@ -599,6 +612,7 @@ class Zamba2HybridLayer(nn.Module):
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.linear",
         )
         self.mamba_decoder = Zamba2MambaDecoderLayer(
             config,
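Taken together, the hunks show the composition pattern: each parent module extends its own prefix when constructing children, so nested modules end up with fully qualified names that line up with checkpoint weight keys (for example, the Zamba2LoRA adapters above yield names like "...linear_q_adapter.A"). A minimal sketch of that composition, with toy class names standing in for the real modules:

    class MLP:
        def __init__(self, prefix: str) -> None:
            # As in the diff: each child gets "<parent prefix>.<attribute name>".
            self.down_proj_prefix = f"{prefix}.down_proj"


    class DecoderLayer:
        def __init__(self, prefix: str) -> None:
            self.feed_forward = MLP(prefix=f"{prefix}.feed_forward")


    layer = DecoderLayer(prefix="model.layers.3")
    print(layer.feed_forward.down_proj_prefix)
    # -> model.layers.3.feed_forward.down_proj, matching a checkpoint key
    #    such as "model.layers.3.feed_forward.down_proj.weight"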