[Misc][LoRA] Improve the readability of LoRA error messages (#12102)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from dataclasses import MISSING, dataclass, field, fields
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -11,6 +14,12 @@ logger = init_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class PEFTHelper:
|
||||
"""
|
||||
A helper class for PEFT configurations, specifically designed for LoRA.
|
||||
This class handles configuration validation, compatibility checks for
|
||||
various LoRA implementations.
|
||||
"""
|
||||
|
||||
# Required fields
|
||||
r: int
|
||||
lora_alpha: int
|
||||
@@ -29,20 +38,18 @@ class PEFTHelper:
|
||||
vllm_max_position_embeddings: Optional[int] = field(default=False)
|
||||
vllm_long_context_scaling_factor: Optional[float] = field(default=None)
|
||||
|
||||
def _validate_features(self):
|
||||
def _validate_features(self) -> List[str]:
|
||||
"""
|
||||
Check if there are any unsupported Lora features.
|
||||
"""
|
||||
error_msg = []
|
||||
|
||||
if self.modules_to_save:
|
||||
error_msg.append("vLLM only supports modules_to_save being None.")
|
||||
|
||||
if self.use_dora:
|
||||
error_msg.append("vLLM does not yet support DoRA.")
|
||||
|
||||
if error_msg:
|
||||
raise ValueError(f"{', '.join(error_msg)}")
|
||||
return error_msg
|
||||
|
||||
def __post_init__(self):
|
||||
self._validate_features()
|
||||
if self.use_rslora:
|
||||
logger.info_once("Loading LoRA weights trained with rsLoRA.")
|
||||
self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
|
||||
@@ -78,3 +85,29 @@ class PEFTHelper:
|
||||
for k, v in config_dict.items() if k in class_fields
|
||||
}
|
||||
return cls(**filtered_dict)
|
||||
|
||||
@classmethod
|
||||
def from_local_dir(cls, lora_path: str,
|
||||
max_position_embeddings: Optional[int]) -> "PEFTHelper":
|
||||
lora_config_path = os.path.join(lora_path, "adapter_config.json")
|
||||
|
||||
with open(lora_config_path) as f:
|
||||
config = json.load(f)
|
||||
config["vllm_max_position_embeddings"] = max_position_embeddings
|
||||
return cls.from_dict(config)
|
||||
|
||||
def validate_legal(self, lora_config: LoRAConfig) -> None:
|
||||
"""
|
||||
Validates the LoRA configuration settings against application
|
||||
constraints and requirements.
|
||||
"""
|
||||
error_msg = self._validate_features()
|
||||
if self.r > lora_config.max_lora_rank:
|
||||
error_msg.append(
|
||||
f"LoRA rank {self.r} is greater than max_lora_rank"
|
||||
f" {lora_config.max_lora_rank}.")
|
||||
if self.bias != "none" and not lora_config.bias_enabled:
|
||||
error_msg.append(
|
||||
"Adapter bias cannot be used without bias_enabled.")
|
||||
if error_msg:
|
||||
raise ValueError(f"{' '.join(error_msg)}")
|
||||
|
||||
Reference in New Issue
Block a user