Improve configs - TokenizerPoolConfig + DeviceConfig (#16603)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2025-04-17 12:19:42 +01:00
committed by GitHub
parent 99ed526101
commit d27ea94034
3 changed files with 136 additions and 81 deletions

vllm/engine/arg_utils.py

@@ -16,15 +16,15 @@ from typing_extensions import TypeIs
 import vllm.envs as envs
 from vllm import version
-from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
-                         DecodingConfig, DeviceConfig,
+from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
+                         DecodingConfig, Device, DeviceConfig,
                          DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerPoolConfig, VllmConfig,
-                         get_attr_docs)
+                         ParallelConfig, PoolerConfig, PoolType,
+                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
+                         VllmConfig, get_attr_docs, get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -44,27 +44,17 @@ logger = init_logger(__name__)
 ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

-DEVICE_OPTIONS = [
-    "auto",
-    "cuda",
-    "neuron",
-    "cpu",
-    "tpu",
-    "xpu",
-    "hpu",
-]

 # object is used to allow for special typing forms
 T = TypeVar("T")
 TypeHint = Union[type[Any], object]
 TypeHintT = Union[type[T], object]


-def optional_arg(val: str, return_type: type[T]) -> Optional[T]:
+def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
     if val == "" or val == "None":
         return None
     try:
-        return cast(Callable, return_type)(val)
+        return return_type(val)
     except ValueError as e:
         raise argparse.ArgumentTypeError(
             f"Value {val} cannot be converted to {return_type}.") from e
@@ -82,8 +72,11 @@ def optional_float(val: str) -> Optional[float]:
     return optional_arg(val, float)


-def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
-    """Parses a string containing comma separate key [str] to value [int]
+def nullable_kvs(val: str) -> Optional[dict[str, int]]:
+    """NOTE: This function is deprecated, args should be passed as JSON
+    strings instead.
+
+    Parses a string containing comma separate key [str] to value [int]
     pairs into a dictionary.

     Args:
@@ -117,6 +110,17 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
     return out_dict


+def optional_dict(val: str) -> Optional[dict[str, int]]:
+    try:
+        return optional_arg(val, json.loads)
+    except ValueError:
+        logger.warning(
+            "Failed to parse JSON string. Attempting to parse as "
+            "comma-separated key=value pairs. This will be deprecated in a "
+            "future release.")
+        return nullable_kvs(val)


 @dataclass
 class EngineArgs:
     """Arguments for vLLM engine."""
@@ -178,12 +182,14 @@ class EngineArgs:
     enforce_eager: Optional[bool] = None
     max_seq_len_to_capture: int = 8192
     disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
-    tokenizer_pool_size: int = 0
+    tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
     # Note: Specifying a tokenizer pool by passing a class
     # is intended for expert use only. The API may change without
     # notice.
-    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
-    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
+    tokenizer_pool_type: Union[PoolType, Type["BaseTokenizerGroup"]] = \
+        TokenizerPoolConfig.pool_type
+    tokenizer_pool_extra_config: dict[str, Any] = \
+        get_field(TokenizerPoolConfig, "extra_config")
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
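
`tokenizer_pool_extra_config` changes from `Optional[Dict[str, Any]] = None` to a real `dict` default. A dataclass field cannot default to a mutable `{}` directly, so `get_field` lifts the field definition, including its `default_factory`, off the config class; the same trick is applied to `model_loader_extra_config` two hunks below. A hypothetical sketch of such a helper, assuming the config classes are dataclasses (the real implementation lives in `vllm/config.py` and may differ):

```python
from dataclasses import MISSING, Field, field, fields


def get_field(cls: type, name: str) -> Field:
    """Copy a dataclass field's default so another dataclass can reuse it."""
    named_field: Field = next(f for f in fields(cls) if f.name == name)
    if named_field.default_factory is not MISSING:
        # Mutable defaults (e.g. {}) must stay behind a factory.
        return field(default_factory=named_field.default_factory)
    return field(default=named_field.default)
```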
@@ -199,14 +205,14 @@ class EngineArgs:
     long_lora_scaling_factors: Optional[Tuple[float]] = None
     lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
     max_cpu_loras: Optional[int] = None
-    device: str = 'auto'
+    device: Device = DeviceConfig.device
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
     ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
-    model_loader_extra_config: Optional[
-        dict] = LoadConfig.model_loader_extra_config
+    model_loader_extra_config: dict = \
+        get_field(LoadConfig, "model_loader_extra_config")
     ignore_patterns: Optional[Union[str,
                                     List[str]]] = LoadConfig.ignore_patterns
     preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
@@ -294,14 +300,15 @@ class EngineArgs:
"""Check if the class is a custom type."""
return cls.__module__ != "builtins"
def get_kwargs(cls: type[Any]) -> dict[str, Any]:
def get_kwargs(cls: type[Config]) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
name = field.name
# One of these will always be present
default = (field.default_factory
if field.default is MISSING else field.default)
default = field.default
# This will only be True if default is MISSING
if field.default_factory is not MISSING:
default = field.default_factory()
kwargs[name] = {"default": default, "help": cls_docs[name]}
# Make note of if the field is optional and get the actual
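
The old expression handed `field.default_factory` itself, a callable, to argparse whenever `default` was `MISSING`; the rewrite calls the factory so the CLI default is the produced value. A runnable illustration of the difference:

```python
from dataclasses import MISSING, dataclass, field, fields


@dataclass
class Cfg:  # stand-in for a vLLM config class
    extra_config: dict = field(default_factory=dict)


f = fields(Cfg)[0]

# Old logic: leaks the factory object when default is MISSING.
old = f.default_factory if f.default is MISSING else f.default
# New logic: materialize the factory's value.
new = f.default
if f.default_factory is not MISSING:
    new = f.default_factory()

print(old)  # <class 'dict'> -- the factory itself became the CLI default
print(new)  # {} -- the actual default value
```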
@@ -331,8 +338,9 @@ class EngineArgs:
                 elif can_be_type(field_type, float):
                     kwargs[name][
                         "type"] = optional_float if optional else float
+                elif can_be_type(field_type, dict):
+                    kwargs[name]["type"] = optional_dict
                 elif (can_be_type(field_type, str)
-                      or can_be_type(field_type, dict)
                       or is_custom_type(field_type)):
                     kwargs[name]["type"] = optional_str if optional else str
                 else:
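
`dict`-typed fields now get a dedicated branch that routes them through `optional_dict`; it must sit before the `str` fallback, otherwise a `dict[str, Any]` field would reach the engine as an unparsed string. What this buys on the command line, assuming the `optional_dict` defined above is in scope (the flag value is a made-up example):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer-pool-extra-config", type=optional_dict)
args = parser.parse_args(["--tokenizer-pool-extra-config", '{"some_key": 1}'])
assert args.tokenizer_pool_extra_config == {"some_key": 1}  # dict, not str
```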
@@ -674,25 +682,19 @@ class EngineArgs:
                             'Additionally for encoder-decoder models, if the '
                             'sequence length of the encoder input is larger '
                             'than this, we fall back to the eager mode.')
-        parser.add_argument('--tokenizer-pool-size',
-                            type=int,
-                            default=EngineArgs.tokenizer_pool_size,
-                            help='Size of tokenizer pool to use for '
-                            'asynchronous tokenization. If 0, will '
-                            'use synchronous tokenization.')
-        parser.add_argument('--tokenizer-pool-type',
-                            type=str,
-                            default=EngineArgs.tokenizer_pool_type,
-                            help='Type of tokenizer pool to use for '
-                            'asynchronous tokenization. Ignored '
-                            'if tokenizer_pool_size is 0.')
-        parser.add_argument('--tokenizer-pool-extra-config',
-                            type=optional_str,
-                            default=EngineArgs.tokenizer_pool_extra_config,
-                            help='Extra config for tokenizer pool. '
-                            'This should be a JSON string that will be '
-                            'parsed into a dictionary. Ignored if '
-                            'tokenizer_pool_size is 0.')
+
+        # Tokenizer arguments
+        tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
+        tokenizer_group = parser.add_argument_group(
+            title="TokenizerPoolConfig",
+            description=TokenizerPoolConfig.__doc__,
+        )
+        tokenizer_group.add_argument('--tokenizer-pool-size',
+                                     **tokenizer_kwargs["pool_size"])
+        tokenizer_group.add_argument('--tokenizer-pool-type',
+                                     **tokenizer_kwargs["pool_type"])
+        tokenizer_group.add_argument('--tokenizer-pool-extra-config',
+                                     **tokenizer_kwargs["extra_config"])

         # Multimodal related configs
         parser.add_argument(
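
Instead of hand-written flags whose defaults and help strings could drift from `TokenizerPoolConfig`, the parser is now generated from the config class itself: `get_kwargs` harvests defaults and attribute docstrings, and `add_argument_group` titles the section after the class. A condensed standalone sketch of the pattern (`PoolConfig` and the stubbed help lookup are illustrative; the real code derives help text via `get_attr_docs`):

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class PoolConfig:  # stand-in for TokenizerPoolConfig
    pool_size: int = 0


def get_kwargs(cls: type) -> dict:
    # Real code also attaches per-field help text from attribute docstrings.
    return {
        f.name: {"default": f.default, "type": type(f.default), "help": ""}
        for f in fields(cls)
    }


parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="PoolConfig",
                                  description=PoolConfig.__doc__)
kwargs = get_kwargs(PoolConfig)
group.add_argument("--tokenizer-pool-size", **kwargs["pool_size"])

args = parser.parse_args(["--tokenizer-pool-size", "8"])
assert args.tokenizer_pool_size == 8  # typed and defaulted from the config
```

The same pattern is repeated for `DeviceConfig` in the next hunk.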
@@ -784,11 +786,15 @@ class EngineArgs:
                             type=int,
                             default=EngineArgs.max_prompt_adapter_token,
                             help='Max number of PromptAdapters tokens')
-        parser.add_argument("--device",
-                            type=str,
-                            default=EngineArgs.device,
-                            choices=DEVICE_OPTIONS,
-                            help='Device type for vLLM execution.')
+
+        # Device arguments
+        device_kwargs = get_kwargs(DeviceConfig)
+        device_group = parser.add_argument_group(
+            title="DeviceConfig",
+            description=DeviceConfig.__doc__,
+        )
+        device_group.add_argument("--device", **device_kwargs["device"])
+
         parser.add_argument('--num-scheduler-steps',
                             type=int,
                             default=1,
@@ -1302,8 +1308,6 @@ class EngineArgs:
         if self.qlora_adapter_name_or_path is not None and \
                 self.qlora_adapter_name_or_path != "":
-            if self.model_loader_extra_config is None:
-                self.model_loader_extra_config = {}
             self.model_loader_extra_config[
                 "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path