Improve configs - LoRAConfig + PromptAdapterConfig (#16980)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -7,7 +7,7 @@ import json
 import re
 import threading
 from dataclasses import MISSING, dataclass, fields
-from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Type,
+from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
                     TypeVar, Union, cast, get_args, get_origin)
 
 import torch
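
The typing.Tuple import is dropped because the annotation it served moves to
the PEP 585 builtin generic below. The switch also fixes the arity of the
annotation: Tuple[float] describes a tuple of exactly one float, while
tuple[float, ...] describes a variable-length tuple. A quick standard-library
check (illustrative, not part of the diff):

    from typing import Tuple, get_args

    print(get_args(Tuple[float]))       # (float,) -> exactly one element
    print(get_args(tuple[float, ...]))  # (float, Ellipsis) -> any length
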
@@ -192,18 +192,23 @@ class EngineArgs:
         get_field(MultiModalConfig, "limit_per_prompt")
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
+    # LoRA fields
     enable_lora: bool = False
-    enable_lora_bias: bool = False
-    max_loras: int = 1
-    max_lora_rank: int = 16
+    enable_lora_bias: bool = LoRAConfig.bias_enabled
+    max_loras: int = LoRAConfig.max_loras
+    max_lora_rank: int = LoRAConfig.max_lora_rank
+    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
+    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
+    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
+    lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
+    long_lora_scaling_factors: Optional[tuple[float, ...]] = \
+        LoRAConfig.long_lora_scaling_factors
+    # PromptAdapter fields
     enable_prompt_adapter: bool = False
-    max_prompt_adapters: int = 1
-    max_prompt_adapter_token: int = 0
-    fully_sharded_loras: bool = False
-    lora_extra_vocab_size: int = 256
-    long_lora_scaling_factors: Optional[Tuple[float]] = None
-    lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
-    max_cpu_loras: Optional[int] = None
+    max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
+    max_prompt_adapter_token: int = \
+        PromptAdapterConfig.max_prompt_adapter_token
+
     device: Device = DeviceConfig.device
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
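
With this change the EngineArgs defaults reference the config classes instead
of repeating literal values, making LoRAConfig and PromptAdapterConfig the
single source of truth for LoRA and prompt-adapter defaults. A minimal sketch
of the pattern, using hypothetical stand-in classes rather than the real vLLM
ones:

    from dataclasses import dataclass

    @dataclass
    class LoRAConfigSketch:
        # Stand-in for vllm.config.LoRAConfig.
        max_loras: int = 1
        max_lora_rank: int = 16

    @dataclass
    class EngineArgsSketch:
        # Dataclass defaults are plain class attributes, so they can be
        # read straight off the config class; changing the default in
        # LoRAConfigSketch changes the CLI-facing default here too.
        max_loras: int = LoRAConfigSketch.max_loras
        max_lora_rank: int = LoRAConfigSketch.max_lora_rank

    print(EngineArgsSketch())  # EngineArgsSketch(max_loras=1, max_lora_rank=16)
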
@@ -338,10 +343,21 @@ class EngineArgs:
             kwargs[name]["choices"] = choices
             choice_type = type(choices[0])
             assert all(type(c) is choice_type for c in choices), (
-                f"All choices must be of the same type. "
+                "All choices must be of the same type. "
                 f"Got {choices} with types {[type(c) for c in choices]}"
             )
             kwargs[name]["type"] = choice_type
+        elif can_be_type(field_type, tuple):
+            if is_type_in_union(field_type, tuple):
+                field_type = get_type_from_union(field_type, tuple)
+            dtypes = get_args(field_type)
+            dtype = dtypes[0]
+            assert all(
+                d is dtype for d in dtypes if d is not Ellipsis
+            ), ("All non-Ellipsis tuple elements must be of the same "
+                f"type. Got {dtypes}.")
+            kwargs[name]["type"] = dtype
+            kwargs[name]["nargs"] = "+"
         elif can_be_type(field_type, int):
             kwargs[name]["type"] = optional_int if optional else int
         elif can_be_type(field_type, float):
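
The new tuple branch lets get_kwargs turn a tuple-typed field such as
long_lora_scaling_factors: Optional[tuple[float, ...]] into a repeated
argparse option. A simplified, self-contained illustration of that mapping
(the can_be_type/is_type_in_union/get_type_from_union helpers are elided and
the Optional wrapper is assumed to be already unwrapped):

    import argparse
    from typing import get_args

    field_type = tuple[float, ...]  # annotation from the diff above
    dtypes = get_args(field_type)   # -> (float, Ellipsis)
    dtype = dtypes[0]
    # Mirrors the assertion in the diff: every non-Ellipsis element
    # type must match the first one.
    assert all(d is dtype for d in dtypes if d is not Ellipsis)

    parser = argparse.ArgumentParser()
    parser.add_argument("--long-lora-scaling-factors",
                        type=dtype, nargs="+")
    args = parser.parse_args(["--long-lora-scaling-factors", "4.0", "8.0"])
    print(args.long_lora_scaling_factors)  # [4.0, 8.0]

Note that argparse collects nargs="+" values into a list, not a tuple, so
consumers of the parsed namespace still convert if they need a real tuple.
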
@@ -685,70 +701,49 @@ class EngineArgs:
                 'inputs.')
 
         # LoRA related configs
-        parser.add_argument('--enable-lora',
-                            action='store_true',
-                            help='If True, enable handling of LoRA adapters.')
-        parser.add_argument('--enable-lora-bias',
-                            action='store_true',
-                            help='If True, enable bias for LoRA adapters.')
-        parser.add_argument('--max-loras',
-                            type=int,
-                            default=EngineArgs.max_loras,
-                            help='Max number of LoRAs in a single batch.')
-        parser.add_argument('--max-lora-rank',
-                            type=int,
-                            default=EngineArgs.max_lora_rank,
-                            help='Max LoRA rank.')
-        parser.add_argument(
-            '--lora-extra-vocab-size',
-            type=int,
-            default=EngineArgs.lora_extra_vocab_size,
-            help=('Maximum size of extra vocabulary that can be '
-                  'present in a LoRA adapter (added to the base '
-                  'model vocabulary).'))
-        parser.add_argument(
-            '--lora-dtype',
-            type=str,
-            default=EngineArgs.lora_dtype,
-            choices=['auto', 'float16', 'bfloat16'],
-            help=('Data type for LoRA. If auto, will default to '
-                  'base model dtype.'))
-        parser.add_argument(
-            '--long-lora-scaling-factors',
-            type=optional_str,
-            default=EngineArgs.long_lora_scaling_factors,
-            help=('Specify multiple scaling factors (which can '
-                  'be different from base model scaling factor '
-                  '- see eg. Long LoRA) to allow for multiple '
-                  'LoRA adapters trained with those scaling '
-                  'factors to be used at the same time. If not '
-                  'specified, only adapters trained with the '
-                  'base model scaling factor are allowed.'))
-        parser.add_argument(
-            '--max-cpu-loras',
-            type=int,
-            default=EngineArgs.max_cpu_loras,
-            help=('Maximum number of LoRAs to store in CPU memory. '
-                  'Must be >= than max_loras.'))
-        parser.add_argument(
-            '--fully-sharded-loras',
-            action='store_true',
-            help=('By default, only half of the LoRA computation is '
-                  'sharded with tensor parallelism. '
-                  'Enabling this will use the fully sharded layers. '
-                  'At high sequence length, max rank or '
-                  'tensor parallel size, this is likely faster.'))
-        parser.add_argument('--enable-prompt-adapter',
-                            action='store_true',
-                            help='If True, enable handling of PromptAdapters.')
-        parser.add_argument('--max-prompt-adapters',
-                            type=int,
-                            default=EngineArgs.max_prompt_adapters,
-                            help='Max number of PromptAdapters in a batch.')
-        parser.add_argument('--max-prompt-adapter-token',
-                            type=int,
-                            default=EngineArgs.max_prompt_adapter_token,
-                            help='Max number of PromptAdapters tokens')
+        lora_kwargs = get_kwargs(LoRAConfig)
+        lora_group = parser.add_argument_group(
+            title="LoRAConfig",
+            description=LoRAConfig.__doc__,
+        )
+        lora_group.add_argument(
+            '--enable-lora',
+            action=argparse.BooleanOptionalAction,
+            help='If True, enable handling of LoRA adapters.')
+        lora_group.add_argument('--enable-lora-bias',
+                                **lora_kwargs["bias_enabled"])
+        lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
+        lora_group.add_argument('--max-lora-rank',
+                                **lora_kwargs["max_lora_rank"])
+        lora_group.add_argument('--lora-extra-vocab-size',
+                                **lora_kwargs["lora_extra_vocab_size"])
+        lora_group.add_argument(
+            '--lora-dtype',
+            **lora_kwargs["lora_dtype"],
+        )
+        lora_group.add_argument('--long-lora-scaling-factors',
+                                **lora_kwargs["long_lora_scaling_factors"])
+        lora_group.add_argument('--max-cpu-loras',
+                                **lora_kwargs["max_cpu_loras"])
+        lora_group.add_argument('--fully-sharded-loras',
+                                **lora_kwargs["fully_sharded_loras"])
+
+        # PromptAdapter related configs
+        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
+        prompt_adapter_group = parser.add_argument_group(
+            title="PromptAdapterConfig",
+            description=PromptAdapterConfig.__doc__,
+        )
+        prompt_adapter_group.add_argument(
+            '--enable-prompt-adapter',
+            action=argparse.BooleanOptionalAction,
+            help='If True, enable handling of PromptAdapters.')
+        prompt_adapter_group.add_argument(
+            '--max-prompt-adapters',
+            **prompt_adapter_kwargs["max_prompt_adapters"])
+        prompt_adapter_group.add_argument(
+            '--max-prompt-adapter-token',
+            **prompt_adapter_kwargs["max_prompt_adapter_token"])
 
         # Device arguments
         device_kwargs = get_kwargs(DeviceConfig)
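
The hand-written parser.add_argument calls are replaced by kwargs generated
from the config dataclasses, and each config gets its own argument group so
--help mirrors the config structure. argparse.BooleanOptionalAction also
gives every boolean flag a generated --no-* form (e.g. --no-enable-lora). A
rough sketch of the idea; get_kwargs_sketch is a hypothetical stand-in for
vLLM's get_kwargs, which is far more capable (Literal choices, tuples,
optionals, help text):

    import argparse
    from dataclasses import dataclass, fields

    @dataclass
    class PromptAdapterConfigSketch:
        """Stand-in for vllm.config.PromptAdapterConfig."""
        max_prompt_adapters: int = 1
        max_prompt_adapter_token: int = 0

    def get_kwargs_sketch(cls):
        # One argparse-kwargs dict per field; inferring the type from
        # the default is much cruder than the real implementation.
        return {f.name: {"type": type(f.default), "default": f.default}
                for f in fields(cls)}

    kwargs = get_kwargs_sketch(PromptAdapterConfigSketch)
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(
        title="PromptAdapterConfig",
        description=PromptAdapterConfigSketch.__doc__,
    )
    group.add_argument('--enable-prompt-adapter',
                       action=argparse.BooleanOptionalAction)
    group.add_argument('--max-prompt-adapters',
                       **kwargs["max_prompt_adapters"])
    group.add_argument('--max-prompt-adapter-token',
                       **kwargs["max_prompt_adapter_token"])

    args = parser.parse_args(['--no-enable-prompt-adapter',
                              '--max-prompt-adapters', '4'])
    print(args.enable_prompt_adapter, args.max_prompt_adapters)  # False 4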