[Misc][LoRA] Improve the readability of LoRA error messages (#12102)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import List
|
||||
import pytest
|
||||
|
||||
from vllm.lora.models import LoRAModel
|
||||
from vllm.lora.peft_helper import PEFTHelper
|
||||
from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
@@ -30,11 +31,14 @@ def test_load_checkpoints(
|
||||
else:
|
||||
expected_lora_modules.append(module)
|
||||
if lora_name == "baichuan7B":
|
||||
peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
|
||||
max_position_embeddings=4096)
|
||||
# For the baichuan7B model, load it's LoRA,
|
||||
# and the test should pass.
|
||||
LoRAModel.from_local_checkpoint(
|
||||
baichuan_lora_files,
|
||||
expected_lora_modules,
|
||||
peft_helper=peft_helper,
|
||||
lora_model_id=1,
|
||||
device="cpu",
|
||||
embedding_modules=embedding_modules,
|
||||
@@ -43,9 +47,12 @@ def test_load_checkpoints(
|
||||
# Test that the target_modules contain prefix
|
||||
# such as "model.layers.0.self_atten.W_pack", and
|
||||
# the test should pass.
|
||||
peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files,
|
||||
max_position_embeddings=4096)
|
||||
LoRAModel.from_local_checkpoint(
|
||||
baichuan_zero_lora_files,
|
||||
expected_lora_modules,
|
||||
peft_helper=peft_helper,
|
||||
lora_model_id=1,
|
||||
device="cpu",
|
||||
embedding_modules=embedding_modules,
|
||||
@@ -53,9 +60,12 @@ def test_load_checkpoints(
|
||||
elif lora_name == "baichuan7B-zero-regex":
|
||||
# Test that the `target_modules` in the form of regular expressions,
|
||||
# such as `model\\..*(W_pack|o_proj)`, and the test should pass.
|
||||
peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files,
|
||||
max_position_embeddings=4096)
|
||||
LoRAModel.from_local_checkpoint(
|
||||
baichuan_regex_lora_files,
|
||||
expected_lora_modules,
|
||||
peft_helper=peft_helper,
|
||||
lora_model_id=1,
|
||||
device="cpu",
|
||||
embedding_modules=embedding_modules,
|
||||
@@ -64,10 +74,13 @@ def test_load_checkpoints(
|
||||
# For the baichuan7B model, load chatglm3-6b's LoRA,
|
||||
# and the test should raise the following error.
|
||||
expected_error = "Please verify that the loaded LoRA module is correct" # noqa: E501
|
||||
peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files,
|
||||
max_position_embeddings=4096)
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
LoRAModel.from_local_checkpoint(
|
||||
chatglm3_lora_files,
|
||||
expected_lora_modules,
|
||||
peft_helper=peft_helper,
|
||||
lora_model_id=1,
|
||||
device="cpu",
|
||||
embedding_modules=embedding_modules,
|
||||
@@ -94,9 +107,12 @@ def test_lora_weights_mapping(baichuan_lora_files):
|
||||
".layers.": ".baichuan_layers.",
|
||||
},
|
||||
)
|
||||
peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
|
||||
max_position_embeddings=4096)
|
||||
lora_model = LoRAModel.from_local_checkpoint(
|
||||
baichuan_lora_files,
|
||||
expected_lora_modules,
|
||||
peft_helper=peft_helper,
|
||||
lora_model_id=1,
|
||||
device="cpu",
|
||||
embedding_modules=embedding_modules,
|
||||
|
||||
Reference in New Issue
Block a user