Improve configs - the rest! (#17562)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -7,10 +7,10 @@ import json
|
||||
import re
|
||||
import threading
|
||||
import warnings
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from dataclasses import MISSING, dataclass, fields, is_dataclass
|
||||
from itertools import permutations
|
||||
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
|
||||
TypeVar, Union, cast, get_args, get_origin)
|
||||
from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
|
||||
Type, TypeVar, Union, cast, get_args, get_origin)
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypeIs, deprecated
|
||||
@@ -36,7 +36,8 @@ from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
|
||||
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build,
|
||||
is_in_ray_actor)
|
||||
|
||||
# yapf: enable
|
||||
|
||||
@@ -48,12 +49,9 @@ TypeHint = Union[type[Any], object]
|
||||
TypeHintT = Union[type[T], object]
|
||||
|
||||
|
||||
def optional_type(
|
||||
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
|
||||
def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]:
|
||||
|
||||
def _optional_type(val: str) -> Optional[T]:
|
||||
if val == "" or val == "None":
|
||||
return None
|
||||
def _parse_type(val: str) -> T:
|
||||
try:
|
||||
if return_type is json.loads and not re.match("^{.*}$", val):
|
||||
return cast(T, nullable_kvs(val))
|
||||
@@ -62,14 +60,24 @@ def optional_type(
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Value {val} cannot be converted to {return_type}.") from e
|
||||
|
||||
return _parse_type
|
||||
|
||||
|
||||
def optional_type(
|
||||
return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
|
||||
|
||||
def _optional_type(val: str) -> Optional[T]:
|
||||
if val == "" or val == "None":
|
||||
return None
|
||||
return parse_type(return_type)(val)
|
||||
|
||||
return _optional_type
|
||||
|
||||
|
||||
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
|
||||
if not re.match("^{.*}$", val):
|
||||
return str(val)
|
||||
else:
|
||||
return optional_type(json.loads)(val)
|
||||
return optional_type(json.loads)(val)
|
||||
|
||||
|
||||
@deprecated(
|
||||
@@ -144,10 +152,25 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
cls_docs = get_attr_docs(cls)
|
||||
kwargs = {}
|
||||
for field in fields(cls):
|
||||
# Get the set of possible types for the field
|
||||
type_hints: set[TypeHint] = set()
|
||||
if get_origin(field.type) in {Union, Annotated}:
|
||||
type_hints.update(get_args(field.type))
|
||||
else:
|
||||
type_hints.add(field.type)
|
||||
|
||||
# If the field is a dataclass, its CLI value is parsed from a JSON string
|
||||
generator = (th for th in type_hints if is_dataclass(th))
|
||||
dataclass_cls = next(generator, None)
|
||||
|
||||
# Get the default value of the field
|
||||
default = field.default
|
||||
if field.default_factory is not MISSING:
|
||||
default = field.default_factory()
|
||||
if field.default is not MISSING:
|
||||
default = field.default
|
||||
elif field.default_factory is not MISSING:
|
||||
if is_dataclass(field.default_factory) and is_in_doc_build():
|
||||
default = {}
|
||||
else:
|
||||
default = field.default_factory()
|
||||
|
||||
# Get the help text for the field
|
||||
name = field.name
|
||||
@@ -158,16 +181,17 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
# Initialise the kwargs dictionary for the field
|
||||
kwargs[name] = {"default": default, "help": help}
|
||||
|
||||
# Get the set of possible types for the field
|
||||
type_hints: set[TypeHint] = set()
|
||||
if get_origin(field.type) is Union:
|
||||
type_hints.update(get_args(field.type))
|
||||
else:
|
||||
type_hints.add(field.type)
|
||||
|
||||
# Set other kwargs based on the type hints
|
||||
json_tip = "\n\nShould be a valid JSON string."
|
||||
if contains_type(type_hints, bool):
|
||||
if dataclass_cls is not None:
|
||||
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
|
||||
# Special case for configs with a from_cli method
|
||||
if hasattr(dataclass_cls, "from_cli"):
|
||||
from_cli = dataclass_cls.from_cli
|
||||
dataclass_init = lambda x, f=from_cli: f(x)
|
||||
kwargs[name]["type"] = dataclass_init
|
||||
kwargs[name]["help"] += json_tip
|
||||
elif contains_type(type_hints, bool):
|
||||
# Creates --no-<name> and --<name> flags
|
||||
kwargs[name]["action"] = argparse.BooleanOptionalAction
|
||||
elif contains_type(type_hints, Literal):
|
||||
@@ -202,7 +226,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
kwargs[name]["type"] = union_dict_and_str
|
||||
elif contains_type(type_hints, dict):
|
||||
# Dict arguments will always be optional
|
||||
kwargs[name]["type"] = optional_type(json.loads)
|
||||
kwargs[name]["type"] = parse_type(json.loads)
|
||||
kwargs[name]["help"] += json_tip
|
||||
elif (contains_type(type_hints, str)
|
||||
or any(is_not_builtin(th) for th in type_hints)):
|
||||
@@ -771,63 +795,20 @@ class EngineArgs:
|
||||
scheduler_group.add_argument("--scheduler-cls",
|
||||
**scheduler_kwargs["scheduler_cls"])
|
||||
|
||||
# Compilation arguments
|
||||
# compilation_kwargs = get_kwargs(CompilationConfig)
|
||||
compilation_group = parser.add_argument_group(
|
||||
title="CompilationConfig",
|
||||
description=CompilationConfig.__doc__,
|
||||
)
|
||||
compilation_group.add_argument(
|
||||
"--compilation-config",
|
||||
"-O",
|
||||
type=CompilationConfig.from_cli,
|
||||
default=None,
|
||||
help="torch.compile configuration for the model. "
|
||||
"When it is a number (0, 1, 2, 3), it will be "
|
||||
"interpreted as the optimization level.\n"
|
||||
"NOTE: level 0 is the default level without "
|
||||
"any optimization. level 1 and 2 are for internal "
|
||||
"testing only. level 3 is the recommended level "
|
||||
"for production.\n"
|
||||
"To specify the full compilation config, "
|
||||
"use a JSON string, e.g. ``{\"level\": 3, "
|
||||
"\"cudagraph_capture_sizes\": [1, 2, 4, 8]}``\n"
|
||||
"Following the convention of traditional "
|
||||
"compilers, using ``-O`` without space is also "
|
||||
"supported. ``-O3`` is equivalent to ``-O 3``.")
|
||||
|
||||
# KVTransfer arguments
|
||||
# kv_transfer_kwargs = get_kwargs(KVTransferConfig)
|
||||
kv_transfer_group = parser.add_argument_group(
|
||||
title="KVTransferConfig",
|
||||
description=KVTransferConfig.__doc__,
|
||||
)
|
||||
kv_transfer_group.add_argument(
|
||||
"--kv-transfer-config",
|
||||
type=KVTransferConfig.from_cli,
|
||||
default=None,
|
||||
help="The configurations for distributed KV cache "
|
||||
"transfer. Should be a JSON string.")
|
||||
kv_transfer_group.add_argument(
|
||||
'--kv-events-config',
|
||||
type=KVEventsConfig.from_cli,
|
||||
default=None,
|
||||
help='The configurations for event publishing.')
|
||||
|
||||
# vLLM arguments
|
||||
# vllm_kwargs = get_kwargs(VllmConfig)
|
||||
vllm_kwargs = get_kwargs(VllmConfig)
|
||||
vllm_group = parser.add_argument_group(
|
||||
title="VllmConfig",
|
||||
description=VllmConfig.__doc__,
|
||||
)
|
||||
vllm_group.add_argument(
|
||||
"--additional-config",
|
||||
type=json.loads,
|
||||
default=None,
|
||||
help="Additional config for specified platform in JSON format. "
|
||||
"Different platforms may support different configs. Make sure the "
|
||||
"configs are valid for the platform you are using. The input format"
|
||||
" is like '{\"config_key\":\"config_value\"}'")
|
||||
vllm_group.add_argument("--kv-transfer-config",
|
||||
**vllm_kwargs["kv_transfer_config"])
|
||||
vllm_group.add_argument('--kv-events-config',
|
||||
**vllm_kwargs["kv_events_config"])
|
||||
vllm_group.add_argument("--compilation-config", "-O",
|
||||
**vllm_kwargs["compilation_config"])
|
||||
vllm_group.add_argument("--additional-config",
|
||||
**vllm_kwargs["additional_config"])
|
||||
|
||||
# Other arguments
|
||||
parser.add_argument('--use-v2-block-manager',
|
||||
|
||||
Reference in New Issue
Block a user