[Models] LFM2: Support LoRA (#34921)

Co-authored-by: Piotr Mazurek <piotr635@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
tianshu-Michael-yu
2026-02-19 22:07:23 -08:00
committed by GitHub
parent f5432e35a3
commit ea37530b47
2 changed files with 36 additions and 16 deletions

View File

@@ -39,6 +39,7 @@ from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, Suppo
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
@@ -66,12 +67,12 @@ class Lfm2MLP(nn.Module):
ff_dim = int(ffn_dim_multiplier * ff_dim)
ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
self.w1 = MergedColumnParallelLinear(
self.w13 = MergedColumnParallelLinear(
input_size=dim,
output_sizes=[ff_dim] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.w1",
prefix=f"{prefix}.w13",
)
self.w2 = RowParallelLinear(
input_size=ff_dim,
@@ -83,7 +84,7 @@ class Lfm2MLP(nn.Module):
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.w1(x)
gate_up, _ = self.w13(x)
x = self.act_fn(gate_up)
x, _ = self.w2(x)
return x
@@ -376,8 +377,8 @@ class Lfm2Model(nn.Module):
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".w1", ".w1", 0),
(".w1", ".w3", 1),
(".w13", ".w1", 0),
(".w13", ".w3", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
@@ -386,9 +387,11 @@ class Lfm2Model(nn.Module):
name = name.replace(".conv.", ".short_conv.", 1)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
# Use segment-boundary matching (trailing dot) to prevent
# e.g. ".w1" from matching inside ".w13" in pre-fused keys.
if weight_name + "." not in name:
continue
name = name.replace(weight_name, param_name)
name = name.replace(weight_name + ".", param_name + ".")
if is_pp_missing_parameter(name, self):
continue
@@ -415,13 +418,20 @@ class Lfm2ForCausalLM(
"k_proj",
"v_proj",
],
"w1": [
"w13": [
"w1",
"w3",
],
"in_proj": ["in_proj"],
}
# HF uses .conv. but vLLM uses .short_conv. to avoid LoRA regex collision
# with the inner .conv.conv child (ShortConv has a child self.conv, so
# naming the container .conv too makes _match_target_modules match both)
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".conv.": ".short_conv."},
)
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",

View File

@@ -52,6 +52,7 @@ from .interfaces import (
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
@@ -69,12 +70,12 @@ class Lfm2MoeMlp(nn.Module):
prefix: str = "",
):
super().__init__()
self.w1 = MergedColumnParallelLinear(
self.w13 = MergedColumnParallelLinear(
input_size=dim,
output_sizes=[ff_dim] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.w1",
prefix=f"{prefix}.w13",
)
self.w2 = RowParallelLinear(
input_size=ff_dim,
@@ -86,7 +87,7 @@ class Lfm2MoeMlp(nn.Module):
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.w1(x)
gate_up, _ = self.w13(x)
x = self.act_fn(gate_up)
x, _ = self.w2(x)
return x
@@ -501,8 +502,8 @@ class Lfm2MoeModel(nn.Module):
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".w1", ".w1", 0),
(".w1", ".w3", 1),
(".w13", ".w1", 0),
(".w13", ".w3", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
@@ -516,12 +517,14 @@ class Lfm2MoeModel(nn.Module):
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
# Use segment-boundary matching (trailing dot) to prevent
# e.g. ".w1" from matching inside ".w13" in pre-fused keys.
if weight_name + "." not in name:
continue
if ("feed_forward.experts." in name) and name not in params_dict:
continue
name = name.replace(weight_name, param_name)
name = name.replace(weight_name + ".", param_name + ".")
# Skip loading extra bias for GPTQ models.
if (
name.endswith(".bias") or name.endswith("_bias")
@@ -596,13 +599,20 @@ class Lfm2MoeForCausalLM(
"k_proj",
"v_proj",
],
"w1": [
"w13": [
"w1",
"w3",
],
"in_proj": ["in_proj"],
}
# HF uses .conv. but vLLM uses .short_conv. to avoid LoRA regex collision
# with the inner .conv.conv child (ShortConv has a child self.conv, so
# naming the container .conv too makes _match_target_modules match both)
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".conv.": ".short_conv."},
)
# LoRA specific attributes
embedding_modules = {
"embed_tokens": "input_embeddings",