[Misc][LoRA] Add PEFTHelper for LoRA (#11003)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
@@ -13,6 +14,7 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
|
||||
LRUCacheLoRAModelManager)
|
||||
from vllm.lora.peft_helper import PEFTHelper
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
|
||||
WorkerLoRAManager)
|
||||
@@ -30,18 +32,68 @@ CUDA_DEVICES = [
|
||||
]
|
||||
|
||||
|
||||
def test_peft_helper(sql_lora_files):
    """Check PEFTHelper parses adapter_config.json correctly and rejects
    PEFT options that vLLM does not support (modules_to_save, RSLoRA, DoRA).
    """
    config_path = os.path.join(sql_lora_files, "adapter_config.json")
    with open(config_path) as config_file:
        adapter_config = json.load(config_file)

    helper = PEFTHelper.from_dict(adapter_config)
    assert helper.r == 8
    assert helper.lora_alpha == 16
    assert helper.target_modules == [
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]

    # Each unsupported PEFT feature must be rejected with a descriptive
    # ValueError; the match string pins the exact error message.
    with pytest.raises(ValueError,
                       match="vLLM only supports modules_to_save being None."):
        PEFTHelper.from_dict(
            dict(r=8,
                 lora_alpha=16,
                 target_modules=["gate_proj"],
                 modules_to_save=["lm_head"]))

    with pytest.raises(ValueError,
                       match="vLLM does not yet support RSLoRA."):
        PEFTHelper.from_dict(
            dict(r=8,
                 lora_alpha=16,
                 target_modules=["gate_proj"],
                 use_rslora=True))

    with pytest.raises(ValueError,
                       match="vLLM does not yet support DoRA."):
        PEFTHelper.from_dict(
            dict(r=8,
                 lora_alpha=16,
                 target_modules=["gate_proj"],
                 use_dora=True))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    """Build a LoRAModel from on-disk safetensors, with adapter metadata
    (rank, alpha, target modules) supplied via PEFTHelper.
    """
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))

    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
    with open(lora_config_path) as f:
        config = json.load(f)

    peft_helper = PEFTHelper.from_dict(config)
    # Rank/alpha are carried by peft_helper rather than passed positionally,
    # and `device` is passed exactly once, as a keyword. (The original text
    # passed `device` both positionally and as a keyword — a TypeError at
    # call time — and also kept the stale `8, 16` positional rank/alpha.)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
        peft_helper=peft_helper,
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
|
||||
|
||||
Reference in New Issue
Block a user