[Lora] Support long context lora (#4787)
Currently we need to call the rotary embedding kernel once per LoRA, which makes it hard to serve multiple long-context LoRAs. This change adds a batched rotary embedding kernel and pipes it through, replacing the rotary embedding layer with one that is aware of multiple cos-sin caches, one per scaling factor. Follow-up of https://github.com/vllm-project/vllm/pull/3095/files
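For intuition, here is a minimal sketch of the idea (not the actual vLLM kernel; build_caches and the offset scheme are illustrative): one linear-scaling cos/sin cache is built per scaling factor and concatenated, so a single batched lookup can serve sequences that use different factors.

    # Illustrative sketch only: stack one linear-scaling cos/sin cache per
    # factor, then shift each token's position by the offset of its factor's
    # slice so one batched lookup serves mixed-adapter batches.
    from typing import List

    import torch


    def build_caches(head_dim: int, max_len: int, base: float,
                     scaling_factors: List[float]) -> torch.Tensor:
        inv_freq = 1.0 / (base**(
            torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
        caches = []
        for factor in scaling_factors:
            # Linear RoPE scaling: positions are divided by the factor.
            t = torch.arange(int(max_len * factor),
                             dtype=torch.float32) / factor
            freqs = torch.outer(t, inv_freq)
            caches.append(torch.cat([freqs.cos(), freqs.sin()], dim=-1))
        # Rows [0, max_len) belong to the first factor's slice, the next
        # max_len * factor_1 rows to the second, and so on.
        return torch.cat(caches, dim=0)


    caches = build_caches(head_dim=64, max_len=2048, base=10000.0,
                          scaling_factors=[1.0, 4.0])
    # A token at position p in a sequence served by the 4x adapter reads
    # row offset + p, where offset = 2048 (the size of the 1.0x slice).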
@@ -1,7 +1,7 @@
 import argparse
 import dataclasses
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
@@ -63,6 +63,7 @@ class EngineArgs:
     max_lora_rank: int = 16
     fully_sharded_loras: bool = False
     lora_extra_vocab_size: int = 256
+    long_lora_scaling_factors: Optional[Tuple[float]] = None
     lora_dtype = 'auto'
     max_cpu_loras: Optional[int] = None
     device: str = 'auto'
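Assuming the field lands on EngineArgs in vllm.engine.arg_utils (as these hunks suggest), it could also be set programmatically; the model name and factor values below are illustrative:

    # Illustrative values; assumes EngineArgs lives in vllm.engine.arg_utils.
    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
                             enable_lora=True,
                             long_lora_scaling_factors=(4.0, 8.0))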
@@ -397,6 +398,17 @@ class EngineArgs:
             choices=['auto', 'float16', 'bfloat16', 'float32'],
             help=('Data type for LoRA. If auto, will default to '
                   'base model dtype.'))
+        parser.add_argument(
+            '--long-lora-scaling-factors',
+            type=nullable_str,
+            default=EngineArgs.long_lora_scaling_factors,
+            help=('Specify multiple scaling factors (which can '
+                  'be different from base model scaling factor '
+                  '- see eg. Long LoRA) to allow for multiple '
+                  'LoRA adapters trained with those scaling '
+                  'factors to be used at the same time. If not '
+                  'specified, only adapters trained with the '
+                  'base model scaling factor are allowed.'))
         parser.add_argument(
             '--max-cpu-loras',
             type=int,
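The hunk registers the flag with type=nullable_str, so the raw CLI value is a string; the diff does not show how it becomes a Tuple[float]. A plausible comma-separated parse (an assumption, not the PR's code):

    from typing import Optional, Tuple


    def parse_scaling_factors(
            value: Optional[str]) -> Optional[Tuple[float, ...]]:
        # Hypothetical helper: "4.0,8.0" -> (4.0, 8.0); None stays None.
        if value is None:
            return None
        return tuple(float(v) for v in value.split(','))


    assert parse_scaling_factors('4.0,8.0') == (4.0, 8.0)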
@@ -593,6 +605,7 @@ class EngineArgs:
             max_loras=self.max_loras,
             fully_sharded_loras=self.fully_sharded_loras,
             lora_extra_vocab_size=self.lora_extra_vocab_size,
+            long_lora_scaling_factors=self.long_lora_scaling_factors,
             lora_dtype=self.lora_dtype,
             max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
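With the factors piped into LoRAConfig as above, a launch that should accept adapters trained at 4x and 8x of the base context might pass --enable-lora --long-lora-scaling-factors 4.0,8.0 (comma-separated form assumed); per the help text, omitting the flag restricts serving to adapters trained with the base model's scaling factor.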