[Experimental] Add multi-LoRA support (#1804)

Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
Co-authored-by: Avnish Narayan <avnish@anyscale.com>
Author: Antoni Baum
Date: 2024-01-24 00:26:37 +01:00
Committed by: GitHub
Commit: 9b945daaf1 (parent: 9c1352eb57)
52 changed files with 8035 additions and 126 deletions


@@ -4,7 +4,7 @@ from dataclasses import dataclass
 from typing import Optional, Tuple
 
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig)
+                         SchedulerConfig, LoRAConfig)
 
 
 @dataclass
@@ -35,6 +35,12 @@ class EngineArgs:
     quantization: Optional[str] = None
     enforce_eager: bool = False
     max_context_len_to_capture: int = 8192
+    enable_lora: bool = False
+    max_loras: int = 1
+    max_lora_rank: int = 16
+    lora_extra_vocab_size: int = 256
+    lora_dtype: str = 'auto'
+    max_cpu_loras: Optional[int] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
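These six fields map one-to-one onto the LoRAConfig constructed in the last hunk below. As a quick orientation, a minimal sketch of setting them programmatically instead of via the CLI (the model name is only a placeholder):

from vllm.engine.arg_utils import EngineArgs

# Allow up to 4 adapters of rank <= 16 in a single batch.
engine_args = EngineArgs(model='facebook/opt-125m',
                         enable_lora=True,
                         max_loras=4,
                         max_lora_rank=16)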
@@ -202,6 +208,39 @@ class EngineArgs:
                             help='maximum context length covered by CUDA '
                             'graphs. When a sequence has context length '
                             'larger than this, we fall back to eager mode.')
+        # LoRA related configs
+        parser.add_argument('--enable-lora',
+                            action='store_true',
+                            help='If True, enable handling of LoRA adapters.')
+        parser.add_argument('--max-loras',
+                            type=int,
+                            default=EngineArgs.max_loras,
+                            help='Max number of LoRAs in a single batch.')
+        parser.add_argument('--max-lora-rank',
+                            type=int,
+                            default=EngineArgs.max_lora_rank,
+                            help='Max LoRA rank.')
+        parser.add_argument(
+            '--lora-extra-vocab-size',
+            type=int,
+            default=EngineArgs.lora_extra_vocab_size,
+            help=('Maximum size of extra vocabulary that can be '
+                  'present in a LoRA adapter (added to the base '
+                  'model vocabulary).'))
+        parser.add_argument(
+            '--lora-dtype',
+            type=str,
+            default=EngineArgs.lora_dtype,
+            choices=['auto', 'float16', 'bfloat16', 'float32'],
+            help=('Data type for LoRA. If auto, will default to '
+                  'base model dtype.'))
+        parser.add_argument(
+            '--max-cpu-loras',
+            type=int,
+            default=EngineArgs.max_cpu_loras,
+            help=('Maximum number of LoRAs to store in CPU memory. '
+                  'Must be >= max_num_seqs. '
+                  'Defaults to max_num_seqs.'))
         return parser
 
     @classmethod
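To sanity-check the new flags, a sketch that drives the parser round-trip directly; it assumes EngineArgs.from_cli_args, the classmethod hinted at by the decorator above (not shown in this hunk):

import argparse

from vllm.engine.arg_utils import EngineArgs

# Build the full engine CLI, parse LoRA flags, and rebuild EngineArgs.
parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(['--enable-lora',
                          '--max-loras', '4',
                          '--lora-dtype', 'bfloat16'])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.enable_lora and engine_args.max_loras == 4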
@@ -214,7 +253,8 @@ class EngineArgs:
 
     def create_engine_configs(
         self,
-    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
+    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
+               Optional[LoRAConfig]]:
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
@@ -234,7 +274,14 @@ class EngineArgs:
                                            self.max_num_seqs,
                                            model_config.max_model_len,
                                            self.max_paddings)
-        return model_config, cache_config, parallel_config, scheduler_config
+        lora_config = LoRAConfig(
+            max_lora_rank=self.max_lora_rank,
+            max_loras=self.max_loras,
+            lora_extra_vocab_size=self.lora_extra_vocab_size,
+            lora_dtype=self.lora_dtype,
+            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
+            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
+        return model_config, cache_config, parallel_config, scheduler_config, lora_config
 
 
 @dataclass
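Putting it together: create_engine_configs now returns a fifth element that is None whenever --enable-lora is off, so callers must treat it as Optional. A sketch of the new contract, assuming LoRAConfig keeps its constructor kwargs as attributes (the unpacking below is illustrative, not part of the commit):

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model='facebook/opt-125m', enable_lora=True)
(model_config, cache_config, parallel_config, scheduler_config,
 lora_config) = engine_args.create_engine_configs()

# lora_config would be None if enable_lora were False.
assert lora_config is not None and lora_config.max_loras == 1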