[CORE] Adding support for insertion of soft-tuned prompts (#4645)
Co-authored-by: Swapnil Parekh <swapnilp@ibm.com> Co-authored-by: Joe G <joseph.granados@h2o.ai> Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple, Union
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                          MultiModalConfig, ObservabilityConfig, ParallelConfig,
-                         SchedulerConfig, SpeculativeConfig,
-                         TokenizerPoolConfig)
+                         PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig, TokenizerPoolConfig)
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser
|
||||
@@ -66,6 +66,9 @@ class EngineArgs:
     enable_lora: bool = False
     max_loras: int = 1
     max_lora_rank: int = 16
+    enable_prompt_adapter: bool = False
+    max_prompt_adapters: int = 1
+    max_prompt_adapter_token: int = 0
     fully_sharded_loras: bool = False
     lora_extra_vocab_size: int = 256
     long_lora_scaling_factors: Optional[Tuple[float]] = None
|
||||
@@ -449,6 +452,17 @@ class EngineArgs:
                 'Enabling this will use the fully sharded layers. '
                 'At high sequence length, max rank or '
                 'tensor parallel size, this is likely faster.'))
+        parser.add_argument('--enable-prompt-adapter',
+                            action='store_true',
+                            help='If True, enable handling of PromptAdapters.')
+        parser.add_argument('--max-prompt-adapters',
+                            type=int,
+                            default=EngineArgs.max_prompt_adapters,
+                            help='Max number of PromptAdapters in a batch.')
+        parser.add_argument('--max-prompt-adapter-token',
+                            type=int,
+                            default=EngineArgs.max_prompt_adapter_token,
+                            help='Max number of PromptAdapters tokens')
         parser.add_argument("--device",
                             type=str,
                             default=EngineArgs.device,
|
||||
@@ -726,6 +740,11 @@ class EngineArgs:
             model_loader_extra_config=self.model_loader_extra_config,
         )

+        prompt_adapter_config = PromptAdapterConfig(
+            max_prompt_adapters=self.max_prompt_adapters,
+            max_prompt_adapter_token=self.max_prompt_adapter_token) \
+            if self.enable_prompt_adapter else None
+
         decoding_config = DecodingConfig(
             guided_decoding_backend=self.guided_decoding_backend)
|
||||
|
||||
@@ -751,6 +770,7 @@ class EngineArgs:
             load_config=load_config,
             decoding_config=decoding_config,
             observability_config=observability_config,
+            prompt_adapter_config=prompt_adapter_config,
         )
Reference in New Issue
Block a user