Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -9,8 +9,11 @@ import pytest
|
||||
from huggingface_hub.utils import HfHubHTTPError
|
||||
from torch import nn
|
||||
|
||||
from vllm.lora.utils import (get_adapter_absolute_path,
|
||||
parse_fine_tuned_lora_name, replace_submodule)
|
||||
from vllm.lora.utils import (
|
||||
get_adapter_absolute_path,
|
||||
parse_fine_tuned_lora_name,
|
||||
replace_submodule,
|
||||
)
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
|
||||
@@ -24,10 +27,12 @@ class LoRANameParserTestConfig(NamedTuple):
|
||||
|
||||
def test_parse_fine_tuned_lora_name_valid():
|
||||
fixture = [
|
||||
LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight",
|
||||
"lm_head", True, False),
|
||||
LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight",
|
||||
"lm_head", False, False),
|
||||
LoRANameParserTestConfig(
|
||||
"base_model.model.lm_head.lora_A.weight", "lm_head", True, False
|
||||
),
|
||||
LoRANameParserTestConfig(
|
||||
"base_model.model.lm_head.lora_B.weight", "lm_head", False, False
|
||||
),
|
||||
LoRANameParserTestConfig(
|
||||
"base_model.model.model.embed_tokens.lora_embedding_A",
|
||||
"model.embed_tokens",
|
||||
@@ -71,7 +76,8 @@ def test_parse_fine_tuned_lora_name_valid():
|
||||
True,
|
||||
False,
|
||||
weights_mapper=WeightsMapper(
|
||||
orig_to_new_prefix={"model.": "language_model.model."}),
|
||||
orig_to_new_prefix={"model.": "language_model.model."}
|
||||
),
|
||||
),
|
||||
LoRANameParserTestConfig(
|
||||
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
|
||||
@@ -79,7 +85,8 @@ def test_parse_fine_tuned_lora_name_valid():
|
||||
False,
|
||||
False,
|
||||
weights_mapper=WeightsMapper(
|
||||
orig_to_new_prefix={"model.": "language_model.model."}),
|
||||
orig_to_new_prefix={"model.": "language_model.model."}
|
||||
),
|
||||
),
|
||||
LoRANameParserTestConfig(
|
||||
"model.layers.9.mlp.down_proj.lora_A.weight",
|
||||
@@ -87,7 +94,8 @@ def test_parse_fine_tuned_lora_name_valid():
|
||||
True,
|
||||
False,
|
||||
weights_mapper=WeightsMapper(
|
||||
orig_to_new_prefix={"model.": "language_model.model."}),
|
||||
orig_to_new_prefix={"model.": "language_model.model."}
|
||||
),
|
||||
),
|
||||
LoRANameParserTestConfig(
|
||||
"model.layers.9.mlp.down_proj.lora_B.weight",
|
||||
@@ -95,12 +103,14 @@ def test_parse_fine_tuned_lora_name_valid():
|
||||
False,
|
||||
False,
|
||||
weights_mapper=WeightsMapper(
|
||||
orig_to_new_prefix={"model.": "language_model.model."}),
|
||||
orig_to_new_prefix={"model.": "language_model.model."}
|
||||
),
|
||||
),
|
||||
]
|
||||
for name, module_name, is_lora_a, is_bias, weights_mapper in fixture:
|
||||
assert (module_name, is_lora_a,
|
||||
is_bias) == parse_fine_tuned_lora_name(name, weights_mapper)
|
||||
assert (module_name, is_lora_a, is_bias) == parse_fine_tuned_lora_name(
|
||||
name, weights_mapper
|
||||
)
|
||||
|
||||
|
||||
def test_parse_fine_tuned_lora_name_invalid():
|
||||
@@ -115,22 +125,28 @@ def test_parse_fine_tuned_lora_name_invalid():
|
||||
|
||||
def test_replace_submodule():
|
||||
model = nn.Sequential(
|
||||
OrderedDict([
|
||||
("dense1", nn.Linear(764, 100)),
|
||||
("act1", nn.ReLU()),
|
||||
("dense2", nn.Linear(100, 50)),
|
||||
(
|
||||
"seq1",
|
||||
nn.Sequential(
|
||||
OrderedDict([
|
||||
("dense1", nn.Linear(100, 10)),
|
||||
("dense2", nn.Linear(10, 50)),
|
||||
])),
|
||||
),
|
||||
("act2", nn.ReLU()),
|
||||
("output", nn.Linear(50, 10)),
|
||||
("outact", nn.Sigmoid()),
|
||||
]))
|
||||
OrderedDict(
|
||||
[
|
||||
("dense1", nn.Linear(764, 100)),
|
||||
("act1", nn.ReLU()),
|
||||
("dense2", nn.Linear(100, 50)),
|
||||
(
|
||||
"seq1",
|
||||
nn.Sequential(
|
||||
OrderedDict(
|
||||
[
|
||||
("dense1", nn.Linear(100, 10)),
|
||||
("dense2", nn.Linear(10, 50)),
|
||||
]
|
||||
)
|
||||
),
|
||||
),
|
||||
("act2", nn.ReLU()),
|
||||
("output", nn.Linear(50, 10)),
|
||||
("outact", nn.Sigmoid()),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
sigmoid = nn.Sigmoid()
|
||||
|
||||
@@ -143,52 +159,51 @@ def test_replace_submodule():
|
||||
|
||||
|
||||
# Unit tests for get_adapter_absolute_path
|
||||
@patch('os.path.isabs')
|
||||
@patch("os.path.isabs")
|
||||
def test_get_adapter_absolute_path_absolute(mock_isabs):
|
||||
path = '/absolute/path/to/lora'
|
||||
path = "/absolute/path/to/lora"
|
||||
mock_isabs.return_value = True
|
||||
assert get_adapter_absolute_path(path) == path
|
||||
|
||||
|
||||
@patch('os.path.expanduser')
|
||||
@patch("os.path.expanduser")
|
||||
def test_get_adapter_absolute_path_expanduser(mock_expanduser):
|
||||
# Path with ~ that needs to be expanded
|
||||
path = '~/relative/path/to/lora'
|
||||
absolute_path = '/home/user/relative/path/to/lora'
|
||||
path = "~/relative/path/to/lora"
|
||||
absolute_path = "/home/user/relative/path/to/lora"
|
||||
mock_expanduser.return_value = absolute_path
|
||||
assert get_adapter_absolute_path(path) == absolute_path
|
||||
|
||||
|
||||
@patch('os.path.exists')
|
||||
@patch('os.path.abspath')
|
||||
@patch("os.path.exists")
|
||||
@patch("os.path.abspath")
|
||||
def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
|
||||
# Relative path that exists locally
|
||||
path = 'relative/path/to/lora'
|
||||
absolute_path = '/absolute/path/to/lora'
|
||||
path = "relative/path/to/lora"
|
||||
absolute_path = "/absolute/path/to/lora"
|
||||
mock_exist.return_value = True
|
||||
mock_abspath.return_value = absolute_path
|
||||
assert get_adapter_absolute_path(path) == absolute_path
|
||||
|
||||
|
||||
@patch('huggingface_hub.snapshot_download')
|
||||
@patch('os.path.exists')
|
||||
def test_get_adapter_absolute_path_huggingface(mock_exist,
|
||||
mock_snapshot_download):
|
||||
@patch("huggingface_hub.snapshot_download")
|
||||
@patch("os.path.exists")
|
||||
def test_get_adapter_absolute_path_huggingface(mock_exist, mock_snapshot_download):
|
||||
# Hugging Face model identifier
|
||||
path = 'org/repo'
|
||||
absolute_path = '/mock/snapshot/path'
|
||||
path = "org/repo"
|
||||
absolute_path = "/mock/snapshot/path"
|
||||
mock_exist.return_value = False
|
||||
mock_snapshot_download.return_value = absolute_path
|
||||
assert get_adapter_absolute_path(path) == absolute_path
|
||||
|
||||
|
||||
@patch('huggingface_hub.snapshot_download')
|
||||
@patch('os.path.exists')
|
||||
def test_get_adapter_absolute_path_huggingface_error(mock_exist,
|
||||
mock_snapshot_download):
|
||||
@patch("huggingface_hub.snapshot_download")
|
||||
@patch("os.path.exists")
|
||||
def test_get_adapter_absolute_path_huggingface_error(
|
||||
mock_exist, mock_snapshot_download
|
||||
):
|
||||
# Hugging Face model identifier with download error
|
||||
path = 'org/repo'
|
||||
path = "org/repo"
|
||||
mock_exist.return_value = False
|
||||
mock_snapshot_download.side_effect = HfHubHTTPError(
|
||||
"failed to query model info")
|
||||
mock_snapshot_download.side_effect = HfHubHTTPError("failed to query model info")
|
||||
assert get_adapter_absolute_path(path) == path
|
||||
|
||||
Reference in New Issue
Block a user