[Misc][LoRA] Support loading LoRA weights for target_modules in reg format (#9275)

commit 36ea79079b (parent e808156f30)
Author: Jee Jee Li, committed by GitHub
Date: 2024-10-11 20:31:21 +08:00
4 changed files with 59 additions and 5 deletions
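The change teaches `LoRAModel.from_local_checkpoint` to accept a `target_modules` entry written as a single regular-expression string (e.g. `model\\..*(W_pack|o_proj)`), in addition to the usual list of module names or prefixed module paths. A minimal sketch of the distinction; the `matches_target` helper below is hypothetical, not vLLM's implementation:

import re

# Hypothetical helper: decide whether a module in the model is selected by
# `target_modules`, which may be a list of names/prefixes or a regex string.
def matches_target(module_name: str, target_modules) -> bool:
    if isinstance(target_modules, str):
        # Regex form: match the full dotted module path against the pattern.
        return re.fullmatch(target_modules, module_name) is not None
    # List form: match the final component or a full-path suffix.
    return any(module_name == t or module_name.endswith("." + t)
               for t in target_modules)

# Both forms select Baichuan's fused attention projection:
name = "model.layers.0.self_attn.W_pack"
assert matches_target(name, ["W_pack", "o_proj"])
assert matches_target(name, r"model\..*(W_pack|o_proj)")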


@@ -5,7 +5,9 @@ import pytest
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 
-lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
+lora_lst = [
+    "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
+]
 
 
 @pytest.mark.parametrize("lora_name", lora_lst)
@@ -13,6 +15,7 @@ def test_load_checkpoints(
     lora_name,
     baichuan_lora_files,
     baichuan_zero_lora_files,
+    baichuan_regex_lora_files,
     chatglm3_lora_files,
 ):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
@@ -36,7 +39,7 @@ def test_load_checkpoints(
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
     elif lora_name == "baichuan7B-zero":
-        #Test that the target_modules contain prefix
+        # Test that the target_modules contain a prefix
         # such as "model.layers.0.self_attn.W_pack", and
         # the test should pass.
         LoRAModel.from_local_checkpoint(
@@ -46,6 +49,16 @@ def test_load_checkpoints(
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+    elif lora_name == "baichuan7B-zero-regex":
+        # Test that `target_modules` given as a regular expression, such
+        # as `model\\..*(W_pack|o_proj)`, is accepted and the test passes.
+        LoRAModel.from_local_checkpoint(
+            baichuan_regex_lora_files,
+            expected_lora_modules,
+            lora_model_id=1,
+            device="cpu",
+            embedding_modules=embedding_modules,
+            embedding_padding_modules=embed_padding_modules)
     else:
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
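For context, a regex-form `target_modules` is what PEFT writes to adapter_config.json when the field is given as a single string instead of a list. A minimal sketch of producing such an adapter config on the PEFT side (parameter values here are illustrative, not taken from the test fixtures):

from peft import LoraConfig

config = LoraConfig(
    r=8,
    lora_alpha=16,
    # A plain string is treated by PEFT as a regular expression and is
    # serialized unchanged into adapter_config.json -- the form the new
    # "baichuan7B-zero-regex" fixture exercises.
    target_modules=r"model\..*(W_pack|o_proj)",
)
config.save_pretrained("baichuan7B-zero-regex")  # writes adapter_config.json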