Enable Pydantic mypy checks and convert configs to Pydantic dataclasses (#17599)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from pydantic import SkipValidation, TypeAdapter, ValidationError
|
||||
from typing_extensions import TypeIs, deprecated
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -38,7 +39,7 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
|
||||
GiB_bytes, is_in_doc_build, is_in_ray_actor)
|
||||
GiB_bytes, is_in_ray_actor)
|
||||
|
||||
# yapf: enable
|
||||
|
||||
@@ -156,7 +157,8 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
# Get the set of possible types for the field
|
||||
type_hints: set[TypeHint] = set()
|
||||
if get_origin(field.type) in {Union, Annotated}:
|
||||
type_hints.update(get_args(field.type))
|
||||
predicate = lambda arg: not isinstance(arg, SkipValidation)
|
||||
type_hints.update(filter(predicate, get_args(field.type)))
|
||||
else:
|
||||
type_hints.add(field.type)
|
||||
|
||||
@@ -168,10 +170,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
if field.default is not MISSING:
|
||||
default = field.default
|
||||
elif field.default_factory is not MISSING:
|
||||
if is_dataclass(field.default_factory) and is_in_doc_build():
|
||||
default = {}
|
||||
else:
|
||||
default = field.default_factory()
|
||||
default = field.default_factory()
|
||||
|
||||
# Get the help text for the field
|
||||
name = field.name
|
||||
@@ -189,12 +188,16 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
|
||||
- `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n"""
|
||||
if dataclass_cls is not None:
|
||||
dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
|
||||
# Special case for configs with a from_cli method
|
||||
if hasattr(dataclass_cls, "from_cli"):
|
||||
from_cli = dataclass_cls.from_cli
|
||||
dataclass_init = lambda x, f=from_cli: f(x)
|
||||
kwargs[name]["type"] = dataclass_init
|
||||
|
||||
def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
|
||||
try:
|
||||
if hasattr(cls, "from_cli"):
|
||||
return cls.from_cli(val)
|
||||
return TypeAdapter(cls).validate_json(val)
|
||||
except ValidationError as e:
|
||||
raise argparse.ArgumentTypeError(repr(e)) from e
|
||||
|
||||
kwargs[name]["type"] = parse_dataclass
|
||||
kwargs[name]["help"] += json_tip
|
||||
elif contains_type(type_hints, bool):
|
||||
# Creates --no-<name> and --<name> flags
|
||||
@@ -225,12 +228,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
|
||||
kwargs[name]["type"] = human_readable_int
|
||||
elif contains_type(type_hints, float):
|
||||
kwargs[name]["type"] = float
|
||||
elif contains_type(type_hints,
|
||||
dict) and (contains_type(type_hints, str) or any(
|
||||
is_not_builtin(th) for th in type_hints)):
|
||||
elif (contains_type(type_hints, dict)
|
||||
and (contains_type(type_hints, str)
|
||||
or any(is_not_builtin(th) for th in type_hints))):
|
||||
kwargs[name]["type"] = union_dict_and_str
|
||||
elif contains_type(type_hints, dict):
|
||||
# Dict arguments will always be optional
|
||||
kwargs[name]["type"] = parse_type(json.loads)
|
||||
kwargs[name]["help"] += json_tip
|
||||
elif (contains_type(type_hints, str)
|
||||
@@ -317,8 +319,7 @@ class EngineArgs:
|
||||
rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
|
||||
rope_theta: Optional[float] = ModelConfig.rope_theta
|
||||
hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
|
||||
hf_overrides: Optional[HfOverrides] = \
|
||||
get_field(ModelConfig, "hf_overrides")
|
||||
hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
|
||||
tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
|
||||
quantization: Optional[QuantizationMethods] = ModelConfig.quantization
|
||||
enforce_eager: bool = ModelConfig.enforce_eager
|
||||
@@ -398,7 +399,8 @@ class EngineArgs:
|
||||
get_field(ModelConfig, "override_neuron_config")
|
||||
override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
|
||||
ModelConfig.override_pooler_config
|
||||
compilation_config: Optional[CompilationConfig] = None
|
||||
compilation_config: CompilationConfig = \
|
||||
get_field(VllmConfig, "compilation_config")
|
||||
worker_cls: str = ParallelConfig.worker_cls
|
||||
worker_extension_cls: str = ParallelConfig.worker_extension_cls
|
||||
|
||||
@@ -413,7 +415,8 @@ class EngineArgs:
|
||||
|
||||
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
|
||||
|
||||
additional_config: Optional[Dict[str, Any]] = None
|
||||
additional_config: dict[str, Any] = \
|
||||
get_field(VllmConfig, "additional_config")
|
||||
enable_reasoning: Optional[bool] = None # DEPRECATED
|
||||
reasoning_parser: str = DecodingConfig.reasoning_backend
|
||||
|
||||
|
||||
Reference in New Issue
Block a user