[Bugfix] Fix missing lora name mapping for lora without prefix (#17793)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import NamedTuple, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -9,52 +10,96 @@ from torch import nn
|
||||
|
||||
from vllm.lora.utils import (get_adapter_absolute_path,
|
||||
parse_fine_tuned_lora_name, replace_submodule)
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
|
||||
class LoRANameParserTestConfig(NamedTuple):
|
||||
name: str
|
||||
module_name: str
|
||||
is_lora_a: bool
|
||||
is_bias: bool
|
||||
weights_mapper: Optional[WeightsMapper] = None
|
||||
|
||||
|
||||
def test_parse_fine_tuned_lora_name_valid():
|
||||
fixture = {
|
||||
("base_model.model.lm_head.lora_A.weight", "lm_head", True, False),
|
||||
("base_model.model.lm_head.lora_B.weight", "lm_head", False, False),
|
||||
(
|
||||
fixture = [
|
||||
LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight",
|
||||
"lm_head", True, False),
|
||||
LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight",
|
||||
"lm_head", False, False),
|
||||
LoRANameParserTestConfig(
|
||||
"base_model.model.model.embed_tokens.lora_embedding_A",
|
||||
"model.embed_tokens",
|
||||
True,
|
||||
False,
|
||||
),
|
||||
(
|
||||
LoRANameParserTestConfig(
|
||||
"base_model.model.model.embed_tokens.lora_embedding_B",
|
||||
"model.embed_tokens",
|
||||
False,
|
||||
False,
|
||||
),
|
||||
(
|
||||
LoRANameParserTestConfig(
|
||||
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
|
||||
"model.layers.9.mlp.down_proj",
|
||||
True,
|
||||
False,
|
||||
),
|
||||
(
|
||||
LoRANameParserTestConfig(
|
||||
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
|
||||
"model.layers.9.mlp.down_proj",
|
||||
False,
|
||||
False,
|
||||
),
|
||||
(
|
||||
LoRANameParserTestConfig(
|
||||
"language_model.layers.9.mlp.down_proj.lora_A.weight",
|
||||
"language_model.layers.9.mlp.down_proj",
|
||||
True,
|
||||
False,
|
||||
),
|
||||
(
|
||||
LoRANameParserTestConfig(
|
||||
"language_model.layers.9.mlp.down_proj.lora_B.weight",
|
||||
"language_model.layers.9.mlp.down_proj",
|
||||
False,
|
||||
False,
|
||||
),
|
||||
}
|
||||
for name, module_name, is_lora_a, is_bias in fixture:
|
||||
# Test with WeightsMapper
|
||||
LoRANameParserTestConfig(
|
||||
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
|
||||
"language_model.model.layers.9.mlp.down_proj",
|
||||
True,
|
||||
False,
|
||||
weights_mapper=WeightsMapper(
|
||||
orig_to_new_prefix={"model.": "language_model.model."}),
|
||||
),
|
||||
LoRANameParserTestConfig(
|
||||
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
|
||||
"language_model.model.layers.9.mlp.down_proj",
|
||||
False,
|
||||
False,
|
||||
weights_mapper=WeightsMapper(
|
||||
orig_to_new_prefix={"model.": "language_model.model."}),
|
||||
),
|
||||
LoRANameParserTestConfig(
|
||||
"model.layers.9.mlp.down_proj.lora_A.weight",
|
||||
"language_model.model.layers.9.mlp.down_proj",
|
||||
True,
|
||||
False,
|
||||
weights_mapper=WeightsMapper(
|
||||
orig_to_new_prefix={"model.": "language_model.model."}),
|
||||
),
|
||||
LoRANameParserTestConfig(
|
||||
"model.layers.9.mlp.down_proj.lora_B.weight",
|
||||
"language_model.model.layers.9.mlp.down_proj",
|
||||
False,
|
||||
False,
|
||||
weights_mapper=WeightsMapper(
|
||||
orig_to_new_prefix={"model.": "language_model.model."}),
|
||||
),
|
||||
]
|
||||
for name, module_name, is_lora_a, is_bias, weights_mapper in fixture:
|
||||
assert (module_name, is_lora_a,
|
||||
is_bias) == parse_fine_tuned_lora_name(name)
|
||||
is_bias) == parse_fine_tuned_lora_name(name, weights_mapper)
|
||||
|
||||
|
||||
def test_parse_fine_tuned_lora_name_invalid():
|
||||
|
||||
Reference in New Issue
Block a user