[Models] LFM2: Support LoRA (#34921)

Co-authored-by: Piotr Mazurek <piotr635@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Author: tianshu-Michael-yu
Committed by: GitHub
Date: 2026-02-19 22:07:23 -08:00
Parent: f5432e35a3
Commit: ea37530b47
2 changed files with 36 additions and 16 deletions

View File

@@ -39,6 +39,7 @@ from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, Suppo
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
+    WeightsMapper,
     extract_layer_index,
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
@@ -66,12 +67,12 @@ class Lfm2MLP(nn.Module):
             ff_dim = int(ffn_dim_multiplier * ff_dim)
             ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
-        self.w1 = MergedColumnParallelLinear(
+        self.w13 = MergedColumnParallelLinear(
             input_size=dim,
             output_sizes=[ff_dim] * 2,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.w1",
+            prefix=f"{prefix}.w13",
         )
         self.w2 = RowParallelLinear(
             input_size=ff_dim,
@@ -83,7 +84,7 @@ class Lfm2MLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        gate_up, _ = self.w1(x)
+        gate_up, _ = self.w13(x)
         x = self.act_fn(gate_up)
         x, _ = self.w2(x)
         return x
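Context for the rename: MergedColumnParallelLinear fuses the gate (w1) and up (w3) projections into one matmul, and SiluAndMul splits the concatenated output inside the activation, so "w13" describes the layer more accurately than "w1". A minimal standalone sketch of that fused gate/up pattern, using plain nn.Linear in place of vLLM's parallel layers (not the actual implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedGateUpMLP(nn.Module):
    """Toy stand-in: one matmul produces [gate | up], i.e. w1 and w3 stacked."""

    def __init__(self, dim: int, ff_dim: int) -> None:
        super().__init__()
        self.w13 = nn.Linear(dim, 2 * ff_dim, bias=False)  # fused w1 + w3
        self.w2 = nn.Linear(ff_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = self.w13(x)
        gate, up = gate_up.chunk(2, dim=-1)  # the split SiluAndMul performs
        return self.w2(F.silu(gate) * up)

x = torch.randn(2, 16)
print(FusedGateUpMLP(dim=16, ff_dim=64)(x).shape)  # torch.Size([2, 16])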
@@ -376,8 +377,8 @@ class Lfm2Model(nn.Module):
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"), (".qkv_proj", ".v_proj", "v"),
(".w1", ".w1", 0), (".w13", ".w1", 0),
(".w1", ".w3", 1), (".w13", ".w3", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
@@ -386,9 +387,11 @@ class Lfm2Model(nn.Module):
name = name.replace(".conv.", ".short_conv.", 1) name = name.replace(".conv.", ".short_conv.", 1)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: # Use segment-boundary matching (trailing dot) to prevent
# e.g. ".w1" from matching inside ".w13" in pre-fused keys.
if weight_name + "." not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name + ".", param_name + ".")
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
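The trailing-dot check above exists because ".w1" is a substring of ".w13", so a plain `in` test would also fire on keys that are already fused. A small illustration of the failure mode with hypothetical key names (not real checkpoint keys):

stacked_params_mapping = [
    (".w13", ".w1", 0),
    (".w13", ".w3", 1),
]

hf_key = "layers.0.feed_forward.w1.weight"      # per-shard key, should be rewritten
fused_key = "layers.0.feed_forward.w13.weight"  # pre-fused key, must be left alone

for param_name, weight_name, shard_id in stacked_params_mapping:
    # Plain substring test: ".w1" also matches inside ".w13.", so the fused
    # key would be rewritten into the nonexistent "...w133.weight".
    naive_hit = weight_name in fused_key
    # Segment-boundary test: requiring the trailing dot rejects the fused key.
    safe_hit = (weight_name + ".") in fused_key
    print(weight_name, "naive:", naive_hit, "boundary:", safe_hit)

# The genuine per-shard key still matches and is rewritten correctly:
print(hf_key.replace(".w1.", ".w13."))  # layers.0.feed_forward.w13.weight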
@@ -415,13 +418,20 @@ class Lfm2ForCausalLM(
"k_proj", "k_proj",
"v_proj", "v_proj",
], ],
"w1": [ "w13": [
"w1", "w1",
"w3", "w3",
], ],
"in_proj": ["in_proj"], "in_proj": ["in_proj"],
} }
# HF uses .conv. but vLLM uses .short_conv. to avoid LoRA regex collision
# with the inner .conv.conv child (ShortConv has a child self.conv, so
# naming the container .conv too makes _match_target_modules match both)
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".conv.": ".short_conv."},
)
# LoRA specific attributes # LoRA specific attributes
embedding_modules = { embedding_modules = {
"embed_tokens": "input_embeddings", "embed_tokens": "input_embeddings",

View File

@@ -52,6 +52,7 @@ from .interfaces import (
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
+    WeightsMapper,
     extract_layer_index,
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
@@ -69,12 +70,12 @@ class Lfm2MoeMlp(nn.Module):
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
self.w1 = MergedColumnParallelLinear( self.w13 = MergedColumnParallelLinear(
input_size=dim, input_size=dim,
output_sizes=[ff_dim] * 2, output_sizes=[ff_dim] * 2,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.w1", prefix=f"{prefix}.w13",
) )
self.w2 = RowParallelLinear( self.w2 = RowParallelLinear(
input_size=ff_dim, input_size=ff_dim,
@@ -86,7 +87,7 @@ class Lfm2MoeMlp(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        gate_up, _ = self.w1(x)
+        gate_up, _ = self.w13(x)
         x = self.act_fn(gate_up)
         x, _ = self.w2(x)
         return x
@@ -501,8 +502,8 @@ class Lfm2MoeModel(nn.Module):
(".qkv_proj", ".q_proj", "q"), (".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"), (".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"), (".qkv_proj", ".v_proj", "v"),
(".w1", ".w1", 0), (".w13", ".w1", 0),
(".w1", ".w3", 1), (".w13", ".w3", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: set[str] = set() loaded_params: set[str] = set()
@@ -516,12 +517,14 @@ class Lfm2MoeModel(nn.Module):
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
+                # Use segment-boundary matching (trailing dot) to prevent
+                # e.g. ".w1" from matching inside ".w13" in pre-fused keys.
+                if weight_name + "." not in name:
                     continue
                 if ("feed_forward.experts." in name) and name not in params_dict:
                     continue
-                name = name.replace(weight_name, param_name)
+                name = name.replace(weight_name + ".", param_name + ".")
                 # Skip loading extra bias for GPTQ models.
                 if (
                     name.endswith(".bias") or name.endswith("_bias")
@@ -596,13 +599,20 @@ class Lfm2MoeForCausalLM(
"k_proj", "k_proj",
"v_proj", "v_proj",
], ],
"w1": [ "w13": [
"w1", "w1",
"w3", "w3",
], ],
"in_proj": ["in_proj"], "in_proj": ["in_proj"],
} }
# HF uses .conv. but vLLM uses .short_conv. to avoid LoRA regex collision
# with the inner .conv.conv child (ShortConv has a child self.conv, so
# naming the container .conv too makes _match_target_modules match both)
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".conv.": ".short_conv."},
)
# LoRA specific attributes # LoRA specific attributes
embedding_modules = { embedding_modules = {
"embed_tokens": "input_embeddings", "embed_tokens": "input_embeddings",