[Misc][LoRA] Support Rank Stabilized LoRA (RSLoRA) (#6909)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -4,6 +4,8 @@ import math
 from dataclasses import MISSING, dataclass, field, fields
 from typing import Literal, Optional, Union
 
+from vllm.utils import print_info_once
+
 
 @dataclass
 class PEFTHelper:
@@ -14,21 +16,22 @@ class PEFTHelper:
 
     bias: Literal["none", "all", "lora_only"] = field(default="none")
     modules_to_save: Optional[list[str]] = field(default=None)
+    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
     use_rslora: bool = field(default=False)
+    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
     use_dora: bool = field(default=False)
-    # long lora field
+    # long context lora field
     context_length: int = field(default=0)
     # Extra vllm field, start with 'vllm_' to avoid conflict
+    vllm_lora_scaling_factor: float = field(default=1.0)
     vllm_max_position_embeddings: Optional[int] = field(default=False)
-    vllm_scaling_factor: Optional[float] = field(default=None)
+    vllm_long_context_scaling_factor: Optional[float] = field(default=None)
 
     def _validate_features(self):
         error_msg = []
 
         if self.modules_to_save:
             error_msg.append("vLLM only supports modules_to_save being None.")
-        if self.use_rslora:
-            error_msg.append("vLLM does not yet support RSLoRA.")
 
         if self.use_dora:
             error_msg.append("vLLM does not yet support DoRA.")
@@ -38,10 +41,15 @@ class PEFTHelper:
 
     def __post_init__(self):
         self._validate_features()
+        if self.use_rslora:
+            print_info_once("Loading LoRA weights trained with rsLoRA.")
+            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
+        else:
+            self.vllm_lora_scaling_factor = self.lora_alpha / self.r
         if self.context_length:
             if self.vllm_max_position_embeddings is None:
                 self.vllm_max_position_embeddings = self.context_length
-            self.vllm_scaling_factor = float(
+            self.vllm_long_context_scaling_factor = float(
                 math.ceil(self.context_length /
                           self.vllm_max_position_embeddings))
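For reference, the scaling rule this change wires into PEFTHelper can be sketched in a few standalone lines of Python. This is only an illustrative sketch, not vLLM code: the lora_scaling helper and the example numbers are made up, but the two formulas mirror the branches added to __post_init__ (classic LoRA scales the low-rank update by lora_alpha / r, rsLoRA by lora_alpha / sqrt(r), see https://arxiv.org/abs/2312.03732):

import math

def lora_scaling(lora_alpha: int, r: int, use_rslora: bool) -> float:
    # Mirrors the rule in PEFTHelper.__post_init__: rsLoRA divides by sqrt(r)
    # instead of r, so the effective update is not shrunk as aggressively at
    # higher ranks.
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r

# Hypothetical adapter with lora_alpha=16, r=64:
print(lora_scaling(16, 64, use_rslora=False))  # 0.25 (classic LoRA: 16 / 64)
print(lora_scaling(16, 64, use_rslora=True))   # 2.0  (rsLoRA: 16 / sqrt(64))

Because the result is stored once in vllm_lora_scaling_factor, downstream LoRA code can consume a single scaling value without caring which variant the adapter was trained with.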
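The long-context handling is otherwise unchanged: vllm_scaling_factor is only renamed to vllm_long_context_scaling_factor, and the value is still ceil(context_length / max_position_embeddings). A small worked example with made-up numbers (a 16k-context adapter on a 4k base model, both values hypothetical):

import math

context_length = 16384               # from the adapter config (hypothetical)
vllm_max_position_embeddings = 4096  # base model's max positions (hypothetical)

vllm_long_context_scaling_factor = float(
    math.ceil(context_length / vllm_max_position_embeddings))
print(vllm_long_context_scaling_factor)  # 4.0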