[Bugfix] Fix Qwen2-VL LoRA weight loading (#11430)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -4,6 +4,7 @@ import pytest
 
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
+from vllm.model_executor.models.utils import WeightsMapper
 
 lora_lst = [
     "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
@@ -71,3 +72,32 @@ def test_load_checkpoints(
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+
+
+def test_lora_weights_mapping(baichuan_lora_files, ):
+    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
+    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
+    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
+    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
+    expected_lora_modules: List[str] = []
+    for module in supported_lora_modules:
+        if module in packed_modules_mapping:
+            expected_lora_modules.extend(packed_modules_mapping[module])
+        else:
+            expected_lora_modules.append(module)
+
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+        "model.": "language_model.model.",
+    }, )
+
+    lora_model = LoRAModel.from_local_checkpoint(
+        baichuan_lora_files,
+        expected_lora_modules,
+        lora_model_id=1,
+        device="cpu",
+        embedding_modules=embedding_modules,
+        embedding_padding_modules=embed_padding_modules,
+        weights_mapper=hf_to_vllm_mapper,
+    )
+    for name in lora_model.loras:
+        assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
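
For context, the new test exercises the path this bugfix adds: LoRA tensor names saved against the Hugging Face layout ("model. ...") are passed through a WeightsMapper so they resolve against vLLM's layout for multimodal models ("language_model.model. ..."), which is how Qwen2-VL nests its language tower. Below is a minimal standalone sketch of that prefix remapping, assuming only the mapping dict used in the test above; the helper name remap_name is hypothetical, for illustration, and is not part of vLLM's API.

# Minimal sketch of prefix-based weight-name remapping, independent of vLLM.
# `orig_to_new_prefix` mirrors the mapping in the test above; `remap_name`
# is a hypothetical helper, not vLLM's actual WeightsMapper implementation.
orig_to_new_prefix = {"model.": "language_model.model."}

def remap_name(name: str) -> str:
    # Rewrite the first matching prefix; leave unmatched names untouched.
    for old, new in orig_to_new_prefix.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

# A LoRA tensor saved against the HF layout...
hf_name = "model.layers.0.self_attn.qkv_proj.lora_A.weight"
# ...resolves to the name the nested language model expects.
assert remap_name(hf_name) == (
    "language_model.model.layers.0.self_attn.qkv_proj.lora_A.weight")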