[Models] LFM2: Support LoRA (#34921)

Co-authored-by: Piotr Mazurek <piotr635@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

parent f5432e35a3
commit ea37530b47
@@ -39,6 +39,7 @@ from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, Suppo
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
+    WeightsMapper,
     extract_layer_index,
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
@@ -66,12 +67,12 @@ class Lfm2MLP(nn.Module):
         ff_dim = int(ffn_dim_multiplier * ff_dim)
         ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
 
-        self.w1 = MergedColumnParallelLinear(
+        self.w13 = MergedColumnParallelLinear(
             input_size=dim,
             output_sizes=[ff_dim] * 2,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.w1",
+            prefix=f"{prefix}.w13",
         )
         self.w2 = RowParallelLinear(
             input_size=ff_dim,
@@ -83,7 +84,7 @@ class Lfm2MLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        gate_up, _ = self.w1(x)
+        gate_up, _ = self.w13(x)
         x = self.act_fn(gate_up)
         x, _ = self.w2(x)
         return x
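What the rename reflects: w1 (the gate projection) and w3 (the up projection) are fused into a single column-parallel matmul, so the attribute is now called w13. A minimal plain-torch sketch of the fused SwiGLU layout, single GPU and no tensor parallelism (FusedSwiGLU and the shapes are illustrative; only the split-then-multiply behavior of SiluAndMul is taken from vLLM):

```python
import torch
import torch.nn.functional as F

class FusedSwiGLU(torch.nn.Module):
    """Illustrative single-GPU analogue of Lfm2MLP's fused layout."""

    def __init__(self, dim: int, ff_dim: int):
        super().__init__()
        # w1 (gate) and w3 (up) concatenated into one matmul, mirroring
        # MergedColumnParallelLinear(output_sizes=[ff_dim] * 2) above.
        self.w13 = torch.nn.Linear(dim, 2 * ff_dim, bias=False)
        self.w2 = torch.nn.Linear(ff_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = self.w13(x)
        gate, up = gate_up.chunk(2, dim=-1)  # what SiluAndMul does internally
        return self.w2(F.silu(gate) * up)

print(FusedSwiGLU(16, 64)(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```

The fused name also matters below: it is what the weight-loading rename and the LoRA packed_modules_mapping key on.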
@@ -376,8 +377,8 @@ class Lfm2Model(nn.Module):
             (".qkv_proj", ".q_proj", "q"),
             (".qkv_proj", ".k_proj", "k"),
             (".qkv_proj", ".v_proj", "v"),
-            (".w1", ".w1", 0),
-            (".w1", ".w3", 1),
+            (".w13", ".w1", 0),
+            (".w13", ".w3", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
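Each triple reads (fused vLLM param, checkpoint weight name, shard id): keys containing .w1 or .w3 are rewritten to the fused .w13 parameter and loaded as shard 0 or 1. A standalone sketch of how such a mapping routes keys (route is a hypothetical helper; the real loop is in load_weights below):

```python
stacked_params_mapping = [
    # (fused param in vLLM, name in the checkpoint, shard id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".w13", ".w1", 0),
    (".w13", ".w3", 1),
]

def route(name: str):
    """Hypothetical helper: map a checkpoint key to (fused key, shard id)."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name + "." not in name:
            continue
        return name.replace(weight_name + ".", param_name + "."), shard_id
    return None  # not a stacked parameter; loaded directly

print(route("layers.0.feed_forward.w1.weight"))
# ('layers.0.feed_forward.w13.weight', 0)
print(route("layers.0.feed_forward.w3.weight"))
# ('layers.0.feed_forward.w13.weight', 1)
```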
@@ -386,9 +387,11 @@ class Lfm2Model(nn.Module):
                 name = name.replace(".conv.", ".short_conv.", 1)
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
+                # Use segment-boundary matching (trailing dot) to prevent
+                # e.g. ".w1" from matching inside ".w13" in pre-fused keys.
+                if weight_name + "." not in name:
                     continue
-                name = name.replace(weight_name, param_name)
+                name = name.replace(weight_name + ".", param_name + ".")
 
                 if is_pp_missing_parameter(name, self):
                     continue
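The new comment is worth unpacking: with a bare substring test, the entry (".w13", ".w1", 0) would also fire on keys that already carry the fused name (pre-fused or re-exported checkpoints), and the rename would then mangle them. A quick repro of the failure mode the trailing dot prevents:

```python
name = "layers.0.feed_forward.w13.weight"  # key that is already fused

# Old check: ".w1" is a substring of ".w13", so it falsely matches ...
assert ".w1" in name
# ... and the old rename then corrupts the key:
print(name.replace(".w1", ".w13"))  # layers.0.feed_forward.w133.weight

# New check: the trailing dot enforces a full name segment, so no match:
assert ".w1." not in name
```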
@@ -415,13 +418,20 @@ class Lfm2ForCausalLM(
             "k_proj",
             "v_proj",
         ],
-        "w1": [
+        "w13": [
             "w1",
             "w3",
         ],
         "in_proj": ["in_proj"],
     }
 
+    # HF uses .conv. but vLLM uses .short_conv. to avoid LoRA regex collision
+    # with the inner .conv.conv child (ShortConv has a child self.conv, so
+    # naming the container .conv too makes _match_target_modules match both)
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={".conv.": ".short_conv."},
+    )
+
     # LoRA specific attributes
     embedding_modules = {
         "embed_tokens": "input_embeddings",
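Two LoRA-facing attributes meet here. packed_modules_mapping tells the LoRA stack that adapters trained against the separate w1/w3 (and q/k/v) modules must be stacked onto the fused w13 (and qkv_proj) modules, matching the model rename above. hf_to_vllm_mapper then rewrites checkpoint keys before loading; a stand-in for that substring rewrite (remap_key is illustrative, not vLLM's WeightsMapper API):

```python
def remap_key(name: str, orig_to_new_substr: dict[str, str]) -> str:
    """Illustrative stand-in for WeightsMapper's substring rewrite."""
    for old, new in orig_to_new_substr.items():
        name = name.replace(old, new)
    return name

# str.replace does not rescan replaced text, so the HF key
# ".conv.conv." becomes ".short_conv.conv." in a single pass:
print(remap_key("model.layers.2.conv.conv.weight", {".conv.": ".short_conv."}))
# model.layers.2.short_conv.conv.weight
```

The remaining hunks apply the same treatment to the Lfm2Moe variant.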
@@ -52,6 +52,7 @@ from .interfaces import (
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
+    WeightsMapper,
     extract_layer_index,
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
@@ -69,12 +70,12 @@ class Lfm2MoeMlp(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
-        self.w1 = MergedColumnParallelLinear(
+        self.w13 = MergedColumnParallelLinear(
             input_size=dim,
             output_sizes=[ff_dim] * 2,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.w1",
+            prefix=f"{prefix}.w13",
         )
         self.w2 = RowParallelLinear(
             input_size=ff_dim,
@@ -86,7 +87,7 @@ class Lfm2MoeMlp(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        gate_up, _ = self.w1(x)
+        gate_up, _ = self.w13(x)
         x = self.act_fn(gate_up)
         x, _ = self.w2(x)
         return x
@@ -501,8 +502,8 @@ class Lfm2MoeModel(nn.Module):
             (".qkv_proj", ".q_proj", "q"),
             (".qkv_proj", ".k_proj", "k"),
             (".qkv_proj", ".v_proj", "v"),
-            (".w1", ".w1", 0),
-            (".w1", ".w3", 1),
+            (".w13", ".w1", 0),
+            (".w13", ".w3", 1),
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
@@ -516,12 +517,14 @@ class Lfm2MoeModel(nn.Module):
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
+                # Use segment-boundary matching (trailing dot) to prevent
+                # e.g. ".w1" from matching inside ".w13" in pre-fused keys.
+                if weight_name + "." not in name:
                     continue
 
                 if ("feed_forward.experts." in name) and name not in params_dict:
                     continue
-                name = name.replace(weight_name, param_name)
+                name = name.replace(weight_name + ".", param_name + ".")
                 # Skip loading extra bias for GPTQ models.
                 if (
                     name.endswith(".bias") or name.endswith("_bias")
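One MoE-specific wrinkle: per-expert keys such as ...feed_forward.experts.3.w1.weight also match the (".w13", ".w1", 0) entry textually, but they belong to the fused expert parameters handled later in load_weights, hence the params_dict guard before the rename. A condensed, illustrative sketch of the resulting dispatch (classify and the sample keys are made up for the example):

```python
stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".w13", ".w1", 0),
    (".w13", ".w3", 1),
]

def classify(name: str, params_dict: set[str]) -> str:
    """Condensed dispatch mirroring the guards in the loop above."""
    for param_name, weight_name, _shard_id in stacked_params_mapping:
        if weight_name + "." not in name:
            continue  # this mapping entry does not apply
        if "feed_forward.experts." in name and name not in params_dict:
            return "deferred to expert loader"
        return "stacked -> " + name.replace(weight_name + ".", param_name + ".")
    return "loaded directly"

params = {"layers.0.feed_forward.w13.weight"}  # dense-layer fused param
print(classify("layers.0.feed_forward.w1.weight", params))
# stacked -> layers.0.feed_forward.w13.weight
print(classify("layers.0.feed_forward.experts.3.w1.weight", params))
# deferred to expert loader
print(classify("layers.0.ffn_norm.weight", params))
# loaded directly
```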
@@ -596,13 +599,20 @@ class Lfm2MoeForCausalLM(
             "k_proj",
             "v_proj",
         ],
-        "w1": [
+        "w13": [
             "w1",
             "w3",
         ],
         "in_proj": ["in_proj"],
     }
 
+    # HF uses .conv. but vLLM uses .short_conv. to avoid LoRA regex collision
+    # with the inner .conv.conv child (ShortConv has a child self.conv, so
+    # naming the container .conv too makes _match_target_modules match both)
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={".conv.": ".short_conv."},
+    )
+
     # LoRA specific attributes
     embedding_modules = {
         "embed_tokens": "input_embeddings",