Enable Pydantic mypy checks and convert configs to Pydantic dataclasses (#17599)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor, 2025-05-28 13:46:04 +01:00 (committed by GitHub)
parent d781930f90 · commit 4c2b38ce9e
11 changed files with 115 additions and 102 deletions

vllm/config.py

@@ -11,8 +11,8 @@ import uuid
import warnings
from collections import Counter
from contextlib import contextmanager
-from dataclasses import (MISSING, Field, asdict, dataclass, field, fields,
-                         is_dataclass, replace)
+from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
+                         replace)
from functools import cached_property
from importlib.util import find_spec
from pathlib import Path
@@ -21,9 +21,12 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
import regex as re
import torch
+from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
+                      model_validator)
+from pydantic.dataclasses import dataclass
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
-from typing_extensions import deprecated
+from typing_extensions import deprecated, runtime_checkable
import vllm.envs as envs
from vllm import version
@@ -57,10 +60,15 @@ if TYPE_CHECKING:
    from vllm.model_executor.layers.quantization.base_config import (
        QuantizationConfig)
    from vllm.model_executor.model_loader import BaseModelLoader
+    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

    ConfigType = type[DataclassInstance]
else:
    PlacementGroup = Any
    ExecutorBase = Any
    QuantizationConfig = Any
    BaseModelLoader = Any
+    TensorizerConfig = Any

    ConfigType = type
logger = init_logger(__name__)
@@ -92,6 +100,7 @@ HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
PretrainedConfig]]
+@runtime_checkable
class SupportsHash(Protocol):
def compute_hash(self) -> str:
@@ -223,7 +232,7 @@ ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
@config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ModelConfig:
"""Configuration for the model."""
@@ -236,7 +245,7 @@ class ModelConfig:
task, even if the same model can be used for multiple tasks. When the model
only supports one task, "auto" can be used to select it; otherwise, you
must specify explicitly which task to use."""
-    tokenizer: str = None  # type: ignore
+    tokenizer: SkipValidation[str] = None  # type: ignore
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used."""
tokenizer_mode: TokenizerMode = "auto"
@@ -284,7 +293,7 @@ class ModelConfig:
"""The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version."""
-    max_model_len: int = None  # type: ignore
+    max_model_len: SkipValidation[int] = None  # type: ignore
"""Model context length (prompt and output). If unspecified, will be
automatically derived from the model config.
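The `SkipValidation[...]` wrapper used in these fields is worth a note: the annotation still reads as `str`/`int` for mypy, but Pydantic skips validating that field, so the `None` sentinel default survives construction and `__post_init__` can fill in the derived value. A hedged toy sketch of the pattern (hypothetical names):

```python
# Toy sketch of the SkipValidation + __post_init__ pattern; names are made up.
from pydantic import SkipValidation
from pydantic.dataclasses import dataclass


@dataclass
class ToyConfig:
    # Typed as int for type checkers, but Pydantic won't reject the None default.
    max_len: SkipValidation[int] = None  # type: ignore

    def __post_init__(self):
        # Pydantic dataclasses still run __post_init__ after field validation.
        if self.max_len is None:
            self.max_len = 2048  # stand-in for a value derived elsewhere


assert ToyConfig().max_len == 2048
```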
@@ -602,6 +611,22 @@ class ModelConfig:
        self._verify_cuda_graph()
        self._verify_bnb_config()

+    @field_validator("quantization", mode="before")
+    @classmethod
+    def validate_quantization_before(cls, value: Any) -> Any:
+        if isinstance(value, str):
+            return value.lower()
+        return value
+
+    @model_validator(mode="after")
+    def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
+        if not isinstance(self.tokenizer, str):
+            raise ValueError("tokenizer must be a string after __post_init__.")
+        if not isinstance(self.max_model_len, int):
+            raise ValueError(
+                "max_model_len must be an integer after __post_init__.")
+        return self

    @property
    def registry(self):
        return ModelRegistry
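These two hooks complement `__post_init__`: a `mode="before"` field validator runs on the raw input before type checking (which is why the lower-casing can move out of the quantization verification, as the next hunk shows), while a `mode="after"` model validator sees the fully built instance and can assert that the `SkipValidation` sentinels were actually replaced. A self-contained toy version:

```python
# Toy reproduction of the before/after validator pattern; not vLLM's classes.
from typing import Any, Optional

from pydantic import field_validator, model_validator
from pydantic.dataclasses import dataclass


@dataclass
class ToyQuantConfig:
    quantization: Optional[str] = None

    @field_validator("quantization", mode="before")
    @classmethod
    def _lowercase(cls, value: Any) -> Any:
        # Runs on the raw input, before type validation.
        return value.lower() if isinstance(value, str) else value

    @model_validator(mode="after")
    def _check(self) -> "ToyQuantConfig":
        # Runs on the fully constructed instance.
        if self.quantization == "forbidden":
            raise ValueError("unsupported quantization method")
        return self


assert ToyQuantConfig(quantization="FP8").quantization == "fp8"
```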
@@ -823,8 +848,7 @@ class ModelConfig:
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
]
if self.quantization is not None:
self.quantization = cast(QuantizationMethods,
self.quantization.lower())
self.quantization = cast(QuantizationMethods, self.quantization)
# Parse quantization method from the HF model config, if available.
quant_cfg = self._parse_quant_hf_config()
@@ -1397,7 +1421,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256"]
class CacheConfig:
"""Configuration for the KV cache."""
-    block_size: BlockSize = None  # type: ignore
+    block_size: SkipValidation[BlockSize] = None  # type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
@@ -1619,7 +1643,8 @@ class LoadConfig:
download_dir: Optional[str] = None
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
-    model_loader_extra_config: dict = field(default_factory=dict)
+    model_loader_extra_config: Union[dict, TensorizerConfig] = field(
+        default_factory=dict)
"""Extra config for model loader. This will be passed to the model loader
corresponding to the chosen load_format."""
ignore_patterns: Optional[Union[list[str], str]] = None
@@ -1929,19 +1954,19 @@ class SchedulerConfig:
runner_type: RunnerType = "generate"
"""The runner type to launch for the model."""
-    max_num_batched_tokens: int = None  # type: ignore
+    max_num_batched_tokens: SkipValidation[int] = None  # type: ignore
"""Maximum number of tokens to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
-    max_num_seqs: int = None  # type: ignore
+    max_num_seqs: SkipValidation[int] = None  # type: ignore
"""Maximum number of sequences to be processed in a single iteration.
This config has no static default. If left unspecified by the user, it will
be set in `EngineArgs.create_engine_config` based on the usage context."""
-    max_model_len: int = None  # type: ignore
+    max_model_len: SkipValidation[int] = None  # type: ignore
"""Maximum length of a sequence (including prompt and generated text). This
is primarily set in `ModelConfig` and that value should be manually
duplicated here."""
@@ -1980,7 +2005,7 @@ class SchedulerConfig:
"""Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt."""
-    enable_chunked_prefill: bool = None  # type: ignore
+    enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
@@ -2202,7 +2227,7 @@ Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
@config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig:
"""Configuration for the device to use for vLLM execution."""
@@ -2260,8 +2285,8 @@ class DeviceConfig:
self.device = torch.device(self.device_type)
SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
"draft_model", "deepseek_mtp"]
SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp"]
SpeculativeAcceptanceMethod = Literal["rejection_sampler",
"typical_acceptance_sampler"]
@@ -2272,8 +2297,7 @@ class SpeculativeConfig:
"""Configuration for speculative decoding."""
# General speculative decoding control
-    num_speculative_tokens: int = field(default=None,
-                                        init=True)  # type: ignore
+    num_speculative_tokens: SkipValidation[int] = None  # type: ignore
"""The number of speculative tokens, if provided. It will default to the
number in the draft model config if present, otherwise, it is required."""
model: Optional[str] = None
@@ -2349,26 +2373,23 @@ class SpeculativeConfig:
"""Specifies the tree structure for speculative token generation.
"""
# required configuration params passed from engine
-    target_model_config: ModelConfig = field(default=None,
-                                             init=True)  # type: ignore
+    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the target model."""
-    target_parallel_config: ParallelConfig = field(default=None,
-                                                   init=True)  # type: ignore
+    target_parallel_config: SkipValidation[
+        ParallelConfig] = None  # type: ignore
    """The parallel configuration for the target model."""
-    enable_chunked_prefill: bool = field(default=None,
-                                         init=True)  # type: ignore
+    enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
    """Whether vLLM is configured to use chunked prefill or not. Used for
    raising an error since it's not yet compatible with speculative decode."""
-    disable_log_stats: bool = field(default=None, init=True)  # type: ignore
+    disable_log_stats: SkipValidation[bool] = None  # type: ignore
    """Whether to disable the periodic printing of stage times in speculative
    decoding."""

    # params generated in the post-init stage
-    draft_model_config: ModelConfig = field(default=None,
-                                            init=True)  # type: ignore
+    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
    """The configuration of the draft model initialized internal."""
-    draft_parallel_config: ParallelConfig = field(default=None,
-                                                  init=True)  # type: ignore
+    draft_parallel_config: SkipValidation[
+        ParallelConfig] = None  # type: ignore
"""The parallel configuration for the draft model initialized internal."""
def compute_hash(self) -> str:
@@ -2766,7 +2787,7 @@ LoRADType = Literal["auto", "float16", "bfloat16"]
@config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class LoRAConfig:
"""Configuration for LoRA."""
@@ -2863,7 +2884,7 @@ class LoRAConfig:
@config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class PromptAdapterConfig:
"""Configuration for PromptAdapters."""
@@ -3892,17 +3913,11 @@ class CompilationConfig:
"pass_config",
"traced_files",
}
include = dict()
for k, v in asdict(self).items():
if k in exclude:
continue
f = get_field(CompilationConfig, k)
if (d := f.default) is not MISSING and d == v:
continue
if (df := f.default_factory) is not MISSING and df() == v:
continue
include[k] = v
return json.dumps(include)
# The cast to string is necessary because Pydantic is mocked in docs
# builds and sphinx-argparse doesn't know the return type of decode()
return str(
TypeAdapter(CompilationConfig).dump_json(
self, exclude=exclude, exclude_unset=True).decode())
__str__ = __repr__
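`TypeAdapter` is Pydantic's handle for attaching validation and serialization to a type it doesn't own, such as a dataclass; `exclude_unset=True` drops every field the caller never set explicitly, which is what replaces the hand-rolled compare-against-default loop above. A rough, self-contained illustration with a hypothetical mini-config:

```python
# Hypothetical mini-config to demonstrate TypeAdapter.dump_json semantics.
from dataclasses import field

from pydantic import TypeAdapter
from pydantic.dataclasses import dataclass


@dataclass
class ToyCompileConfig:
    level: int = 0
    use_inductor: bool = True
    traced_files: set[str] = field(default_factory=set)


cfg = ToyCompileConfig(level=3)
# Only explicitly-set fields, minus the excluded ones, end up in the JSON.
print(TypeAdapter(ToyCompileConfig).dump_json(
    cfg, exclude={"traced_files"}, exclude_unset=True).decode())
# -> {"level":3}
```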
@@ -3911,7 +3926,7 @@ class CompilationConfig:
"""Parse the CLI value for the compilation config."""
if cli_value in ["0", "1", "2", "3"]:
return cls(level=int(cli_value))
return cls(**json.loads(cli_value))
return TypeAdapter(CompilationConfig).validate_json(cli_value)
def __post_init__(self) -> None:
count_none = self.custom_ops.count("none")
@@ -4037,7 +4052,7 @@ class CompilationConfig:
@config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase.
@@ -4294,9 +4309,6 @@ class VllmConfig:
"To workaround this limitation, vLLM will set 'ieee' input "
"precision for chunked prefill triton kernels.")
-        if self.compilation_config is None:
-            self.compilation_config = CompilationConfig()
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp: