[CORE] Adding support for insertion of soft-tuned prompts (#4645)
Co-authored-by: Swapnil Parekh <swapnilp@ibm.com> Co-authored-by: Joe G <joseph.granados@h2o.ai> Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
@@ -1285,6 +1285,39 @@ class LoRAConfig:
|
||||
raise ValueError("LoRA is not supported with chunked prefill yet.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptAdapterConfig:
|
||||
max_prompt_adapters: int
|
||||
max_prompt_adapter_token: int
|
||||
max_cpu_prompt_adapters: Optional[int] = None
|
||||
prompt_adapter_dtype: Optional[torch.dtype] = None
|
||||
|
||||
def __post_init__(self):
|
||||
library_name = 'peft'
|
||||
try:
|
||||
__import__(library_name)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
f"'{library_name}' is not installed for prompt adapter support."
|
||||
f"Please install it using 'pip install {library_name}'."
|
||||
) from e
|
||||
|
||||
if self.max_prompt_adapters < 1:
|
||||
raise ValueError(f"max_prompt_adapters "
|
||||
f"({self.max_prompt_adapters}) must be >= 1.")
|
||||
if self.max_prompt_adapter_token == 0:
|
||||
raise ValueError("max_prompt_adapter_token must be set.")
|
||||
if self.max_cpu_prompt_adapters is None:
|
||||
self.max_cpu_prompt_adapters = self.max_prompt_adapters
|
||||
|
||||
def verify_with_model_config(self, model_config: ModelConfig):
|
||||
if self.prompt_adapter_dtype in (None, "auto"):
|
||||
self.prompt_adapter_dtype = model_config.dtype
|
||||
elif isinstance(self.prompt_adapter_dtype, str):
|
||||
self.prompt_adapter_dtype = getattr(torch,
|
||||
self.prompt_adapter_dtype)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiModalConfig:
|
||||
"""Configs the input data format and how models should run for
|
||||
@@ -1518,6 +1551,7 @@ class EngineConfig:
|
||||
speculative_config: Optional[SpeculativeConfig]
|
||||
decoding_config: Optional[DecodingConfig]
|
||||
observability_config: Optional[ObservabilityConfig]
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig]
|
||||
|
||||
def __post_init__(self):
|
||||
"""Verify configs are valid & consistent with each other.
|
||||
@@ -1529,6 +1563,9 @@ class EngineConfig:
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
self.lora_config.verify_with_scheduler_config(
|
||||
self.scheduler_config)
|
||||
if self.prompt_adapter_config:
|
||||
self.prompt_adapter_config.verify_with_model_config(
|
||||
self.model_config)
|
||||
|
||||
def to_dict(self):
|
||||
"""Return the configs as a dictionary, for use in **kwargs.
|
||||
|
||||
Reference in New Issue
Block a user