[Misc] Improve error message when LoRA parsing fails (#5194)

This commit is contained in:
Cyrus Leung
2024-06-10 19:38:49 +08:00
committed by GitHub
parent c81da5f56d
commit 0bfa1c4f13
2 changed files with 20 additions and 9 deletions

View File

@@ -1,12 +1,13 @@
from collections import OrderedDict
import pytest
from torch import nn
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
from vllm.utils import LRUCache
def test_parse_fine_tuned_lora_name():
def test_parse_fine_tuned_lora_name_valid():
fixture = {
("base_model.model.lm_head.lora_A.weight", "lm_head", True),
("base_model.model.lm_head.lora_B.weight", "lm_head", False),
@@ -35,6 +36,17 @@ def test_parse_fine_tuned_lora_name():
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
def test_parse_fine_tuned_lora_name_invalid():
fixture = {
"weight",
"base_model.weight",
"base_model.model.weight",
}
for name in fixture:
with pytest.raises(ValueError, match="unsupported LoRA weight"):
parse_fine_tuned_lora_name(name)
def test_replace_submodule():
model = nn.Sequential(
OrderedDict([