[Misc][LoRA] Improve the readability of LoRA error messages (#12102)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
commit 07934cc237 (parent 69d765f5a5)
Author: Jee Jee Li
Date: 2025-01-17 19:32:28 +08:00
Committed by: GitHub
10 changed files with 243 additions and 114 deletions


@@ -1,9 +1,12 @@
 # Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py
+import json
 import math
+import os
 from dataclasses import MISSING, dataclass, field, fields
-from typing import Literal, Optional, Union
+from typing import List, Literal, Optional, Union
 
+from vllm.config import LoRAConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -11,6 +14,12 @@ logger = init_logger(__name__)
 @dataclass
 class PEFTHelper:
+    """
+    A helper class for PEFT configurations, specifically designed for LoRA.
+    This class handles configuration validation and compatibility checks
+    for various LoRA implementations.
+    """
+
     # Required fields
     r: int
     lora_alpha: int
@@ -29,20 +38,18 @@ class PEFTHelper:
     vllm_max_position_embeddings: Optional[int] = field(default=None)
     vllm_long_context_scaling_factor: Optional[float] = field(default=None)
 
-    def _validate_features(self):
+    def _validate_features(self) -> List[str]:
         """
         Check if there are any unsupported LoRA features.
         """
         error_msg = []
         if self.modules_to_save:
             error_msg.append("vLLM only supports modules_to_save being None.")
         if self.use_dora:
             error_msg.append("vLLM does not yet support DoRA.")
-        if error_msg:
-            raise ValueError(f"{', '.join(error_msg)}")
+        return error_msg
 
     def __post_init__(self):
-        self._validate_features()
         if self.use_rslora:
             logger.info_once("Loading LoRA weights trained with rsLoRA.")
             self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
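
Note: with this hunk, _validate_features no longer raises on the first unsupported feature; it returns the list of problems so that validate_legal (added below) can report them all in one exception. A minimal sketch of that aggregated-error pattern, using hypothetical adapter settings rather than values from this commit:

    # Hypothetical adapter settings, for illustration only.
    modules_to_save = ["lm_head"]
    use_dora = True

    errors = []
    if modules_to_save:
        errors.append("vLLM only supports modules_to_save being None.")
    if use_dora:
        errors.append("vLLM does not yet support DoRA.")
    if errors:
        # Both problems surface in a single exception instead of failing fast.
        raise ValueError(" ".join(errors))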
@@ -78,3 +85,29 @@ class PEFTHelper:
             for k, v in config_dict.items() if k in class_fields
         }
         return cls(**filtered_dict)
+
+    @classmethod
+    def from_local_dir(cls, lora_path: str,
+                       max_position_embeddings: Optional[int]) -> "PEFTHelper":
+        lora_config_path = os.path.join(lora_path, "adapter_config.json")
+        with open(lora_config_path) as f:
+            config = json.load(f)
+        config["vllm_max_position_embeddings"] = max_position_embeddings
+        return cls.from_dict(config)
+
+    def validate_legal(self, lora_config: LoRAConfig) -> None:
+        """
+        Validate the LoRA adapter's settings against vLLM's
+        constraints and requirements.
+        """
+        error_msg = self._validate_features()
+        if self.r > lora_config.max_lora_rank:
+            error_msg.append(
+                f"LoRA rank {self.r} is greater than max_lora_rank"
+                f" {lora_config.max_lora_rank}.")
+        if self.bias != "none" and not lora_config.bias_enabled:
+            error_msg.append(
+                "Adapter bias cannot be used without bias_enabled.")
+        if error_msg:
+            raise ValueError(" ".join(error_msg))
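
Note: a typical call path for the two new methods might look like the sketch below. The adapter directory, the rank values, and the assumption that LoRAConfig can be constructed with just max_lora_rank (other fields taking their defaults) are illustrative, not taken from this commit:

    from vllm.config import LoRAConfig
    from vllm.lora.peft_helper import PEFTHelper  # module path assumed

    # Parse adapter_config.json from a hypothetical local adapter directory.
    helper = PEFTHelper.from_local_dir("/path/to/adapter",
                                       max_position_embeddings=4096)

    try:
        helper.validate_legal(LoRAConfig(max_lora_rank=16))
    except ValueError as e:
        # For an adapter trained with r=64, prints e.g.:
        # "LoRA rank 64 is greater than max_lora_rank 16."
        print(e)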