Improve configs - LoRAConfig + PromptAdapterConfig (#16980)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -2565,18 +2565,41 @@ class SpeculativeConfig:
         return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
 
 
+LoRADType = Literal["auto", "float16", "bfloat16"]
+
+
+@config
 @dataclass
 class LoRAConfig:
-    max_lora_rank: int
-    max_loras: int
+    """Configuration for LoRA."""
+
+    max_lora_rank: int = 16
+    """Max LoRA rank."""
+    max_loras: int = 1
+    """Max number of LoRAs in a single batch."""
     fully_sharded_loras: bool = False
+    """By default, only half of the LoRA computation is sharded with tensor
+    parallelism. Enabling this will use the fully sharded layers. At high
+    sequence length, max rank or tensor parallel size, this is likely faster.
+    """
     max_cpu_loras: Optional[int] = None
-    lora_dtype: Optional[Union[torch.dtype, str]] = None
+    """Maximum number of LoRAs to store in CPU memory. Must be >= than
+    `max_loras`."""
+    lora_dtype: Union[torch.dtype, LoRADType] = "auto"
+    """Data type for LoRA. If auto, will default to base model dtype."""
     lora_extra_vocab_size: int = 256
+    """Maximum size of extra vocabulary that can be present in a LoRA adapter
+    (added to the base model vocabulary)."""
     # This is a constant.
     lora_vocab_padding_size: ClassVar[int] = 256
-    long_lora_scaling_factors: Optional[tuple[float]] = None
+    long_lora_scaling_factors: Optional[tuple[float, ...]] = None
+    """Specify multiple scaling factors (which can be different from base model
+    scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters
+    trained with those scaling factors to be used at the same time. If not
+    specified, only adapters trained with the base model scaling factor are
+    allowed."""
     bias_enabled: bool = False
+    """Enable bias for LoRA adapters."""
 
     def compute_hash(self) -> str:
         """
@@ -2641,12 +2664,19 @@ class LoRAConfig:
                 "V1 LoRA does not support long LoRA, please use V0.")
 
 
+@config
 @dataclass
 class PromptAdapterConfig:
-    max_prompt_adapters: int
-    max_prompt_adapter_token: int
+    max_prompt_adapters: int = 1
+    """Max number of PromptAdapters in a batch."""
+    max_prompt_adapter_token: int = 0
+    """Max number of PromptAdapters tokens."""
     max_cpu_prompt_adapters: Optional[int] = None
-    prompt_adapter_dtype: Optional[torch.dtype] = None
+    """Maximum number of PromptAdapters to store in CPU memory. Must be >= than
+    `max_prompt_adapters`."""
+    prompt_adapter_dtype: Union[torch.dtype, str] = "auto"
+    """Data type for PromptAdapter. If auto, will default to base model dtype.
+    """
 
     def compute_hash(self) -> str:
         """
@@ -2678,7 +2708,7 @@ class PromptAdapterConfig:
             self.max_cpu_prompt_adapters = self.max_prompt_adapters
 
     def verify_with_model_config(self, model_config: ModelConfig):
-        if self.prompt_adapter_dtype in (None, "auto"):
+        if self.prompt_adapter_dtype == "auto":
             self.prompt_adapter_dtype = model_config.dtype
         elif isinstance(self.prompt_adapter_dtype, str):
             self.prompt_adapter_dtype = getattr(torch,
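(Illustration, not part of the diff.) The verify_with_model_config change swaps the old None sentinel for the new "auto" default, so dtype resolution follows the same rule in both configs: "auto" falls back to the base model dtype, a dtype name is looked up on torch, and a concrete torch.dtype passes through. A standalone sketch of that resolution logic, using a hypothetical resolve_dtype helper rather than the real method:

    from typing import Union

    import torch

    def resolve_dtype(requested: Union[torch.dtype, str],
                      model_dtype: torch.dtype) -> torch.dtype:
        if isinstance(requested, str):
            # "auto" defers to the base model dtype, mirroring the diff above.
            if requested == "auto":
                return model_dtype
            # A name such as "float16" is resolved via getattr on the torch module.
            return getattr(torch, requested)
        # Already a concrete torch.dtype.
        return requested

    assert resolve_dtype("auto", torch.bfloat16) is torch.bfloat16
    assert resolve_dtype("float16", torch.bfloat16) is torch.float16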