[CORE] Adding support for insertion of soft-tuned prompts (#4645)

Co-authored-by: Swapnil Parekh <swapnilp@ibm.com>
Co-authored-by: Joe G <joseph.granados@h2o.ai>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
Swapnil Parekh
2024-07-09 16:26:36 -04:00
committed by GitHub
parent a0550cbc80
commit 4d6ada947c
48 changed files with 1952 additions and 519 deletions

View File

@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple, Union
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig,
TokenizerPoolConfig)
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
@@ -66,6 +66,9 @@ class EngineArgs:
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
enable_prompt_adapter: bool = False
max_prompt_adapters: int = 1
max_prompt_adapter_token: int = 0
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
@@ -449,6 +452,17 @@ class EngineArgs:
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
parser.add_argument('--enable-prompt-adapter',
action='store_true',
help='If True, enable handling of PromptAdapters.')
parser.add_argument('--max-prompt-adapters',
type=int,
default=EngineArgs.max_prompt_adapters,
help='Max number of PromptAdapters in a batch.')
parser.add_argument('--max-prompt-adapter-token',
type=int,
default=EngineArgs.max_prompt_adapter_token,
help='Max number of PromptAdapters tokens')
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
@@ -726,6 +740,11 @@ class EngineArgs:
model_loader_extra_config=self.model_loader_extra_config,
)
prompt_adapter_config = PromptAdapterConfig(
max_prompt_adapters=self.max_prompt_adapters,
max_prompt_adapter_token=self.max_prompt_adapter_token) \
if self.enable_prompt_adapter else None
decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
@@ -751,6 +770,7 @@ class EngineArgs:
load_config=load_config,
decoding_config=decoding_config,
observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
)