[Misc][LoRA] Add PEFTHelper for LoRA (#11003)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date: 2024-12-10 19:12:01 +08:00
Committed by: GitHub
Parent: beb16b2c81
Commit: d05f88679b

4 changed files with 160 additions and 28 deletions


@@ -4,6 +4,7 @@ from typing import Sequence as GenericSequence
 import torch
 import torch.types
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.utils import is_pin_memory_available
@@ -59,6 +60,23 @@ class LoRALayerWeights:
         return self.embeddings_tensor.shape[
             0] if self.embeddings_tensor is not None else 0
 
+    @classmethod
+    def from_config(
+        cls,
+        module_name: str,
+        peft_helper: PEFTHelper,
+        embeddings_tensor: Optional[torch.Tensor] = None,
+    ) -> "LoRALayerWeights":
+        return cls(
+            module_name,
+            peft_helper.r,
+            peft_helper.lora_alpha,
+            None,
+            None,
+            None,
+            embeddings_tensor,
+        )
+
     @classmethod
     def create_dummy_lora_weights(
         cls,
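Editor's note: the from_config classmethod added above replaces the positional pattern LoRALayerWeights(module_name, rank, lora_alpha, None, None, None, ...) used by the loader. A minimal usage sketch (not part of the commit; the module name and hyperparameter values are made up, and it assumes a working vLLM install):

# Editor's sketch, not part of the patch: building a LoRALayerWeights shell
# from a parsed PEFT config.  Only the hyperparameters are carried over;
# the lora_a / lora_b tensors are attached later by the checkpoint loader.
from vllm.lora.lora import LoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper

peft_helper = PEFTHelper(r=8, lora_alpha=16,
                         target_modules=["q_proj", "v_proj"])
weights = LoRALayerWeights.from_config("model.layers.0.self_attn.q_proj",
                                       peft_helper)
assert weights.rank == 8 and weights.lora_alpha == 16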


@@ -21,6 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
                               LinearScalingRotaryEmbeddingWithLora,
                               LoRAMapping)
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              is_regex_target_modules,
@@ -104,14 +105,12 @@ class LoRAModel(AdapterModel):
     def from_lora_tensors(
         cls,
         lora_model_id: int,
-        rank: int,
-        lora_alpha: int,
         tensors: Dict[str, torch.Tensor],
+        peft_helper: PEFTHelper,
         device: str = "cuda",
         dtype: Optional[torch.dtype] = None,
         embeddings: Optional[Dict[str, torch.Tensor]] = None,
         target_embedding_padding: Optional[int] = None,
-        scaling_factor: Optional[float] = None,
         embedding_modules: Optional[Dict[str, str]] = None,
         embedding_padding_modules: Optional[List[str]] = None,
     ) -> "LoRAModel":
@@ -135,10 +134,9 @@ class LoRAModel(AdapterModel):
                         if pin_memory:
                             lora_embeddings_tensor = (
                                 lora_embeddings_tensor.pin_memory())
-                loras[module_name] = LoRALayerWeights(module_name, rank,
-                                                      lora_alpha, None, None,
-                                                      None,
-                                                      lora_embeddings_tensor)
+                loras[module_name] = LoRALayerWeights.from_config(
+                    module_name, peft_helper, lora_embeddings_tensor)
             if is_bias:
                 loras[module_name].bias = tensor.to(device=device,
                                                     dtype=dtype).t()
@@ -170,7 +168,11 @@ class LoRAModel(AdapterModel):
         for lora in loras.values():
             lora.optimize()
-        return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)
+        return cls(lora_model_id,
+                   peft_helper.r,
+                   loras,
+                   scaling_factor=peft_helper.vllm_scaling_factor)
 
     @classmethod
     def from_local_checkpoint(
@@ -212,6 +214,9 @@ class LoRAModel(AdapterModel):
                                                     "new_embeddings.bin")
         with open(lora_config_path) as f:
             config = json.load(f)
+
+        config["vllm_max_position_embeddings"] = max_position_embeddings
+        peft_helper = PEFTHelper.from_dict(config)
         if os.path.isfile(lora_tensor_path):
             tensors: Dict[str, torch.Tensor] = {}
             # Find unexpected modules.
@@ -242,7 +247,7 @@ class LoRAModel(AdapterModel):
             # When a bin file is provided, we rely on config to find unexpected
             # modules.
             unexpected_modules = []
-            target_modules = config["target_modules"]
+            target_modules = peft_helper.target_modules
             if not isinstance(target_modules, list):
                 target_modules = [target_modules]
             for module in target_modules:
@@ -256,7 +261,7 @@ class LoRAModel(AdapterModel):
             # https://github.com/vllm-project/vllm/pull/5909. But there's no
             # other better mechanism.
             if unexpected_modules and not is_regex_target_modules(
-                    config["target_modules"], expected_lora_modules):
+                    peft_helper.target_modules, expected_lora_modules):
                 raise ValueError(
                     f"While loading {lora_dir}, expected"
                     f" target modules in {expected_lora_modules}"
@@ -274,30 +279,17 @@ class LoRAModel(AdapterModel):
             embeddings = torch.load(new_embeddings_bin_file_path,
                                     map_location=device)
-        rank = config["r"]
-        lora_alpha = config["lora_alpha"]
-        context_length = config.get("context_length", None)
-        scaling_factor = None
-        if context_length:
-            if max_position_embeddings is None:
-                max_position_embeddings = context_length
-            scaling_factor = float(
-                math.ceil(context_length / max_position_embeddings))
 
         return cls.from_lora_tensors(
             lora_model_id=get_lora_id()
             if lora_model_id is None else lora_model_id,
-            rank=rank,
-            lora_alpha=lora_alpha,
             tensors=tensors,
+            peft_helper=peft_helper,
             device=device,
             dtype=dtype,
             embeddings=embeddings,
             target_embedding_padding=target_embedding_padding,
-            scaling_factor=scaling_factor,
             embedding_modules=embedding_modules,
-            embedding_padding_modules=embedding_padding_modules,
-        )
+            embedding_padding_modules=embedding_padding_modules)
class LoRAModelManager(AdapterModelManager):
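Editor's note (illustrative, not part of the diff): with the changes above, from_local_checkpoint parses adapter_config.json into a single PEFTHelper and hands it to from_lora_tensors, instead of plumbing rank, lora_alpha and scaling_factor through separately. A hedged sketch of the resulting call shape; the adapter path and module list are hypothetical, and the keyword names assume the existing from_local_checkpoint signature:

# Editor's sketch, not part of the patch: post-change call shape of the
# loading entry point.  The adapter directory is made up; internally the
# config is parsed once into a PEFTHelper, which supplies the rank and the
# long-LoRA scaling factor exposed on the returned LoRAModel.
from vllm.lora.models import LoRAModel

lora_model = LoRAModel.from_local_checkpoint(
    "/path/to/lora-adapter",          # hypothetical adapter directory
    expected_lora_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    max_position_embeddings=4096,
    device="cpu",
)
print(lora_model.rank, lora_model.scaling_factor)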

vllm/lora/peft_helper.py (new file, 70 lines)

@@ -0,0 +1,70 @@
+# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py
+
+import math
+from dataclasses import MISSING, dataclass, field, fields
+from typing import Literal, Optional, Union
+
+
+@dataclass
+class PEFTHelper:
+    # Required fields
+    r: int
+    lora_alpha: int
+    target_modules: Union[list[str], str]
+
+    bias: Literal["none", "all", "lora_only"] = field(default="none")
+    modules_to_save: Optional[list[str]] = field(default=None)
+    use_rslora: bool = field(default=False)
+    use_dora: bool = field(default=False)
+    # long LoRA field
+    context_length: int = field(default=0)
+    # Extra vLLM fields, prefixed with 'vllm_' to avoid name conflicts
+    vllm_max_position_embeddings: Optional[int] = field(default=None)
+    vllm_scaling_factor: Optional[float] = field(default=None)
+
+    def _validate_features(self):
+        error_msg = []
+        if self.modules_to_save:
+            error_msg.append("vLLM only supports modules_to_save being None.")
+        if self.use_rslora:
+            error_msg.append("vLLM does not yet support RSLoRA.")
+        if self.use_dora:
+            error_msg.append("vLLM does not yet support DoRA.")
+        if error_msg:
+            raise ValueError(f"{', '.join(error_msg)}")
+
+    def __post_init__(self):
+        self._validate_features()
+        if self.context_length:
+            if self.vllm_max_position_embeddings is None:
+                self.vllm_max_position_embeddings = self.context_length
+            self.vllm_scaling_factor = float(
+                math.ceil(self.context_length /
+                          self.vllm_max_position_embeddings))
+
+    @classmethod
+    def from_dict(cls, config_dict: dict) -> "PEFTHelper":
+        # Get all field information from the class
+        class_fields = {f.name: f for f in fields(cls)}
+        # Check for required fields
+        required_fields = {
+            name
+            for name, f in class_fields.items()
+            if f.default is MISSING and f.default_factory is MISSING
+        }
+
+        # Identify any missing required fields
+        missing_fields = required_fields - set(config_dict.keys())
+        if missing_fields:
+            raise ValueError(
+                f"Missing required configuration fields: {missing_fields}")
+
+        # Filter out fields that aren't defined in the class
+        filtered_dict = {
+            k: v
+            for k, v in config_dict.items() if k in class_fields
+        }
+        return cls(**filtered_dict)
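Editor's note: a quick sanity check of PEFTHelper in isolation (values are made up, not part of the commit). The vllm_max_position_embeddings entry mimics what from_local_checkpoint injects before calling from_dict; a non-zero context_length marks a long-context adapter, and keys that are not PEFTHelper fields are filtered out:

# Editor's sketch, not part of the patch: exercising PEFTHelper.from_dict
# on a made-up adapter_config.json dictionary.
from vllm.lora.peft_helper import PEFTHelper

config = {
    "r": 16,
    "lora_alpha": 32,
    "target_modules": ["q_proj", "v_proj"],
    "context_length": 16384,                 # long-LoRA adapter
    "vllm_max_position_embeddings": 4096,    # injected by the loader
    "base_model_name_or_path": "some/base",  # unknown key, silently dropped
}
helper = PEFTHelper.from_dict(config)
print(helper.r, helper.lora_alpha)    # 16 32
print(helper.vllm_scaling_factor)     # ceil(16384 / 4096) -> 4.0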