Improve configs - the rest! (#17562)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-09 23:18:44 +01:00
committed by GitHub
parent 7e3571134f
commit 4b2ed7926a
14 changed files with 456 additions and 340 deletions

View File

@@ -7,10 +7,10 @@ import json
import re
import threading
import warnings
from dataclasses import MISSING, dataclass, fields
from dataclasses import MISSING, dataclass, fields, is_dataclass
from itertools import permutations
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
TypeVar, Union, cast, get_args, get_origin)
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
Type, TypeVar, Union, cast, get_args, get_origin)
import torch
from typing_extensions import TypeIs, deprecated
@@ -36,7 +36,8 @@ from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build,
is_in_ray_actor)
# yapf: enable
@@ -48,12 +49,9 @@ TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]
def optional_type(
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
def _parse_type(val: str) -> T:
try:
if return_type is json.loads and not re.match("^{.*}$", val):
return cast(T, nullable_kvs(val))
@@ -62,14 +60,24 @@ def optional_type(
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
return _parse_type
def optional_type(
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
return parse_type(return_type)(val)
return _optional_type
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
if not re.match("^{.*}$", val):
return str(val)
else:
return optional_type(json.loads)(val)
return optional_type(json.loads)(val)
@deprecated(
@@ -144,10 +152,25 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) in {Union, Annotated}:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# If the field is a dataclass, we can use the model_validate_json
generator = (th for th in type_hints if is_dataclass(th))
dataclass_cls = next(generator, None)
# Get the default value of the field
default = field.default
if field.default_factory is not MISSING:
default = field.default_factory()
if field.default is not MISSING:
default = field.default
elif field.default_factory is not MISSING:
if is_dataclass(field.default_factory) and is_in_doc_build():
default = {}
else:
default = field.default_factory()
# Get the help text for the field
name = field.name
@@ -158,16 +181,17 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
# Initialise the kwargs dictionary for the field
kwargs[name] = {"default": default, "help": help}
# Get the set of possible types for the field
type_hints: set[TypeHint] = set()
if get_origin(field.type) is Union:
type_hints.update(get_args(field.type))
else:
type_hints.add(field.type)
# Set other kwargs based on the type hints
json_tip = "\n\nShould be a valid JSON string."
if contains_type(type_hints, bool):
if dataclass_cls is not None:
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
# Special case for configs with a from_cli method
if hasattr(dataclass_cls, "from_cli"):
from_cli = dataclass_cls.from_cli
dataclass_init = lambda x, f=from_cli: f(x)
kwargs[name]["type"] = dataclass_init
kwargs[name]["help"] += json_tip
elif contains_type(type_hints, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
elif contains_type(type_hints, Literal):
@@ -202,7 +226,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name]["type"] = union_dict_and_str
elif contains_type(type_hints, dict):
# Dict arguments will always be optional
kwargs[name]["type"] = optional_type(json.loads)
kwargs[name]["type"] = parse_type(json.loads)
kwargs[name]["help"] += json_tip
elif (contains_type(type_hints, str)
or any(is_not_builtin(th) for th in type_hints)):
@@ -771,63 +795,20 @@ class EngineArgs:
scheduler_group.add_argument("--scheduler-cls",
**scheduler_kwargs["scheduler_cls"])
# Compilation arguments
# compilation_kwargs = get_kwargs(CompilationConfig)
compilation_group = parser.add_argument_group(
title="CompilationConfig",
description=CompilationConfig.__doc__,
)
compilation_group.add_argument(
"--compilation-config",
"-O",
type=CompilationConfig.from_cli,
default=None,
help="torch.compile configuration for the model. "
"When it is a number (0, 1, 2, 3), it will be "
"interpreted as the optimization level.\n"
"NOTE: level 0 is the default level without "
"any optimization. level 1 and 2 are for internal "
"testing only. level 3 is the recommended level "
"for production.\n"
"To specify the full compilation config, "
"use a JSON string, e.g. ``{\"level\": 3, "
"\"cudagraph_capture_sizes\": [1, 2, 4, 8]}``\n"
"Following the convention of traditional "
"compilers, using ``-O`` without space is also "
"supported. ``-O3`` is equivalent to ``-O 3``.")
# KVTransfer arguments
# kv_transfer_kwargs = get_kwargs(KVTransferConfig)
kv_transfer_group = parser.add_argument_group(
title="KVTransferConfig",
description=KVTransferConfig.__doc__,
)
kv_transfer_group.add_argument(
"--kv-transfer-config",
type=KVTransferConfig.from_cli,
default=None,
help="The configurations for distributed KV cache "
"transfer. Should be a JSON string.")
kv_transfer_group.add_argument(
'--kv-events-config',
type=KVEventsConfig.from_cli,
default=None,
help='The configurations for event publishing.')
# vLLM arguments
# vllm_kwargs = get_kwargs(VllmConfig)
vllm_kwargs = get_kwargs(VllmConfig)
vllm_group = parser.add_argument_group(
title="VllmConfig",
description=VllmConfig.__doc__,
)
vllm_group.add_argument(
"--additional-config",
type=json.loads,
default=None,
help="Additional config for specified platform in JSON format. "
"Different platforms may support different configs. Make sure the "
"configs are valid for the platform you are using. The input format"
" is like '{\"config_key\":\"config_value\"}'")
vllm_group.add_argument("--kv-transfer-config",
**vllm_kwargs["kv_transfer_config"])
vllm_group.add_argument('--kv-events-config',
**vllm_kwargs["kv_events_config"])
vllm_group.add_argument("--compilation-config", "-O",
**vllm_kwargs["compilation_config"])
vllm_group.add_argument("--additional-config",
**vllm_kwargs["additional_config"])
# Other arguments
parser.add_argument('--use-v2-block-manager',