[Misc][Model][Refactor] Pass the prefix into Linear layers (#28259)
Signed-off-by: MengqingCao <cmq0113@163.com>
@@ -75,7 +75,12 @@ class Zamba2LoRA(nn.Module):
         super().__init__()
 
         self.A = ColumnParallelLinear(
-            input_dim, rank, bias=False, quant_config=quant_config, gather_output=True
+            input_dim,
+            rank,
+            bias=False,
+            quant_config=quant_config,
+            gather_output=True,
+            prefix=f"{prefix}.A",
         )
 
         if isinstance(output_dim, list):
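The pattern in this hunk, forwarding a dotted prefix from a wrapper module into each of its linear sub-layers, is what gives every weight a stable fully qualified name. A minimal sketch of the idea using plain torch.nn stand-ins rather than vLLM's parallel linear classes (the PrefixedLinear and LoRABlock names below are hypothetical, not vLLM APIs):

import torch.nn as nn

class PrefixedLinear(nn.Linear):
    """Toy stand-in for a linear layer that records its fully qualified name."""
    def __init__(self, in_features, out_features, *, prefix: str = "", bias: bool = False):
        super().__init__(in_features, out_features, bias=bias)
        self.prefix = prefix  # e.g. "model.layers.0.self_attn.linear_q_adapter.A"

class LoRABlock(nn.Module):
    """Toy analogue of Zamba2LoRA: extends the prefix it receives for each sub-layer."""
    def __init__(self, input_dim, rank, output_dim, *, prefix: str = ""):
        super().__init__()
        self.A = PrefixedLinear(input_dim, rank, prefix=f"{prefix}.A")
        self.B = PrefixedLinear(rank, output_dim, prefix=f"{prefix}.B")

block = LoRABlock(16, 4, 16, prefix="model.layers.0.self_attn.linear_q_adapter")
print(block.A.prefix)  # model.layers.0.self_attn.linear_q_adapter.A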
@@ -150,12 +155,14 @@ class Zamba2Attention(nn.Module):
             self.total_num_attention_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.attention_hidden_size,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         # Even though in Zamba2 weights are shared between attention layers, KV
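Why the prefix matters: quantization configs can make per-layer decisions (for example, leaving certain layers unquantized) based on the layer's dotted name, so each linear needs to know where it sits in the module tree. A hedged sketch of that kind of lookup, with a hypothetical ToyQuantConfig standing in for a real vLLM quantization config:

from fnmatch import fnmatch

class ToyQuantConfig:
    """Hypothetical per-layer quantization policy keyed on dotted layer names."""
    def __init__(self, ignore_patterns):
        self.ignore_patterns = ignore_patterns

    def is_quantized(self, prefix: str) -> bool:
        # Layers whose fully qualified name matches an ignore pattern stay in
        # full precision; everything else uses the quantized kernel.
        return not any(fnmatch(prefix, pat) for pat in self.ignore_patterns)

cfg = ToyQuantConfig(ignore_patterns=["*.o_proj"])
print(cfg.is_quantized("model.layers.3.self_attn.qkv_proj"))  # True
print(cfg.is_quantized("model.layers.3.self_attn.o_proj"))    # False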
@@ -197,18 +204,21 @@ class Zamba2Attention(nn.Module):
                 config.adapter_rank,
                 self.attention_hidden_size,
                 quant_config=quant_config,
+                prefix=f"{prefix}.linear_q_adapter",
             )
             linear_k_adapter = Zamba2LoRA(
                 self.attention_hidden_size,
                 config.adapter_rank,
                 self.attention_hidden_size,
                 quant_config=quant_config,
+                prefix=f"{prefix}.linear_k_adapter",
             )
             linear_v_adapter = Zamba2LoRA(
                 self.attention_hidden_size,
                 config.adapter_rank,
                 self.attention_hidden_size,
                 quant_config=quant_config,
+                prefix=f"{prefix}.linear_v_adapter",
             )
         else:
             linear_q_adapter = nn.Identity()
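Because Zamba2LoRA itself forwards its prefix into its internal projections (first hunk), the adapter prefixes set here compose into names like "<attention prefix>.linear_q_adapter.A", which is what lets weights be matched to the right sub-layer by name. A small illustration of that composition (the concrete name strings are assumptions, not the exact Zamba2 weight names):

attn_prefix = "model.layers.0.self_attn"  # assumed shape of an attention prefix

adapter_prefixes = {
    which: f"{attn_prefix}.linear_{which}_adapter" for which in ("q", "k", "v")
}
# Each adapter then derives its own sub-layer names from the prefix it was given:
lora_param_names = [f"{p}.A" for p in adapter_prefixes.values()]
print(lora_param_names)
# ['model.layers.0.self_attn.linear_q_adapter.A',
#  'model.layers.0.self_attn.linear_k_adapter.A',
#  'model.layers.0.self_attn.linear_v_adapter.A']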
@@ -312,6 +322,7 @@ class Zamba2MLP(nn.Module):
             2 * [self.intermediate_size],  # 2x for gate and input projections
             bias=self.config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
 
         self.down_proj = RowParallelLinear(
@@ -319,6 +330,7 @@ class Zamba2MLP(nn.Module):
             self.hidden_size,
             bias=self.config.add_bias_linear,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
 
         # Only allow GELU activations
@@ -418,6 +430,7 @@ class Zamba2AttentionDecoderLayer(nn.Module):
             bare_block_idx=bare_block_idx,
             num_hybrid_layers=num_hybrid_layers,
             quant_config=quant_config,
+            prefix=f"{prefix}.feed_forward",
         )
 
         # Initialize layer normalizations
@@ -599,6 +612,7 @@ class Zamba2HybridLayer(nn.Module):
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.linear",
         )
         self.mamba_decoder = Zamba2MambaDecoderLayer(
             config,
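The remaining hunks repeat the same change for the MLP, attention decoder layer, and hybrid layer, so the dotted names stay consistent from the model root down to every linear. A rough sketch of how such a prefix typically originates at the top of a layer stack and propagates downward (module names here are illustrative, not vLLM's exact ones):

import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    def __init__(self, hidden: int, *, prefix: str = ""):
        super().__init__()
        # Each sub-module extends the prefix it was given with its own field name.
        self.down_proj_prefix = f"{prefix}.feed_forward.down_proj"
        self.down_proj = nn.Linear(hidden, hidden, bias=False)

class ToyModel(nn.Module):
    def __init__(self, hidden: int, num_layers: int, *, prefix: str = "model"):
        super().__init__()
        self.layers = nn.ModuleList(
            [ToyDecoderLayer(hidden, prefix=f"{prefix}.layers.{i}") for i in range(num_layers)]
        )

m = ToyModel(8, 2)
print([layer.down_proj_prefix for layer in m.layers])
# ['model.layers.0.feed_forward.down_proj', 'model.layers.1.feed_forward.down_proj']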