[Models] Lfm2Moe: minor name changes for resolving lora conflicts (#29063)
Signed-off-by: Paul Pak <paulpak58@gmail.com>
@@ -248,7 +248,7 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.layer_idx = layer_idx
-        self.conv = ShortConv(
+        self.short_conv = ShortConv(
             config=config,
             dim=config.conv_dim,
             layer_idx=layer_idx,
@@ -281,7 +281,7 @@ class Lfm2ShortConvDecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.operator_norm(hidden_states, residual)
         output = torch.empty_like(hidden_states)
-        self.conv(
+        self.short_conv(
             hidden_states,
             output,
         )
@@ -380,6 +380,9 @@ class Lfm2Model(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
+            if ".conv." in name:
+                name = name.replace(".conv.", ".short_conv.", 1)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
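The hunk above keeps the weight-loading loop in Lfm2Model compatible with checkpoints that still name the module ".conv.": incoming keys are rewritten to the renamed ".short_conv." attribute before parameter lookup. A minimal standalone sketch of that remapping; the helper name remap_legacy_keys and the sample key are hypothetical, not part of this diff:

from collections.abc import Iterable

import torch


def remap_legacy_keys(
    weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
    for name, tensor in weights:
        if ".conv." in name:
            # Replace only the first occurrence, mirroring the hunk above.
            name = name.replace(".conv.", ".short_conv.", 1)
        yield name, tensor


# Hypothetical checkpoint key using the pre-rename attribute name:
legacy = [("model.layers.0.conv.in_proj.weight", torch.zeros(2))]
remapped = dict(remap_legacy_keys(legacy))
assert "model.layers.0.short_conv.in_proj.weight" in remapped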
@@ -414,6 +417,7 @@ class Lfm2ForCausalLM(
             "w1",
             "w3",
         ],
         "in_proj": ["in_proj"],
     }
 
+    # LoRA specific attributes
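The mapping shown in this hunk follows vLLM's packed-modules convention: several source projections (here w1 and w3) are fused into a single matrix in the vLLM module, and the mapping tells the loader and the LoRA machinery how adapter or checkpoint weights saved under the unfused names map onto slices of the fused weight. A rough standalone sketch of that idea, with hypothetical names and shapes rather than vLLM's implementation:

import torch

# Fused gate/up projection: rows 0..inter hold "w1", rows inter..2*inter "w3".
hidden, inter = 4, 6
fused = torch.zeros(2 * inter, hidden)

packed_modules_mapping = {"gate_up_proj": ["w1", "w3"]}

# Copy per-source checkpoint shards into their slice of the fused weight.
shards = {"w1": torch.randn(inter, hidden), "w3": torch.randn(inter, hidden)}
for i, src in enumerate(packed_modules_mapping["gate_up_proj"]):
    fused[i * inter : (i + 1) * inter].copy_(shards[src])

assert torch.equal(fused[:inter], shards["w1"])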