Turn @config into a dataclass_transform (#31541)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-02-03 17:40:59 +00:00
committed by GitHub
parent b1bb18de8d
commit 61e632aea1
32 changed files with 153 additions and 191 deletions

View File

@@ -10,14 +10,17 @@ import json
import pathlib
import textwrap
from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from dataclasses import MISSING, Field, field, fields, is_dataclass
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
import regex as re
import torch
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
from pydantic.fields import Field as PydanticField
from pydantic.fields import FieldInfo
from typing_extensions import runtime_checkable
from typing_extensions import dataclass_transform, runtime_checkable
from vllm.logger import init_logger
@@ -29,23 +32,39 @@ else:
DataclassInstance = Any
ConfigType = type[DataclassInstance]
ConfigT = TypeVar("ConfigT", bound=ConfigType)
ConfigT = TypeVar("ConfigT", bound=DataclassInstance)
def config(cls: ConfigT) -> ConfigT:
"""
A decorator that ensures all fields in a dataclass have default values
and that each field has a docstring.
@dataclass_transform(field_specifiers=(PydanticField,))
def config(
cls: type[ConfigT] | None = None,
*,
config: ConfigDict | None = None,
**kwargs: Any,
) -> type[ConfigT] | Callable[[type[ConfigT]], type[ConfigT]]:
"""Decorator to create a pydantic dataclass with default config. The default config
for the dataclass forbids extra fields.
If a `ConfigT` is used as a CLI argument itself, the `type` keyword argument
provided by `get_kwargs` will be
`pydantic.TypeAdapter(ConfigT).validate_json(cli_arg)` which treats the
`cli_arg` as a JSON string which gets validated by `pydantic`.
All config classes in vLLM should use this decorator.
Config validation is performed by the tools/pre_commit/validate_config.py
script, which is invoked during the pre-commit checks.
"""
return cls
Args:
cls: The class to decorate
config: The pydantic ConfigDict to use. If provided, it will be merged with
the default config.
**kwargs: Additional arguments to pass to pydantic.dataclass."""
# Extra fields are forbidden by default
merged_config = ConfigDict(extra="forbid")
if config is not None:
merged_config.update(config)
def decorator(cls):
return dataclass(cls, config=merged_config, **kwargs)
# Called with arguments: @config(config=...)
if cls is None:
return decorator
# Called without arguments: @config
return decorator(cls)
def get_field(cls: ConfigType, name: str) -> Field:
@@ -53,24 +72,46 @@ def get_field(cls: ConfigType, name: str) -> Field:
default factory fields in `EngineArgs`."""
if not is_dataclass(cls):
raise TypeError("The given class is not a dataclass.")
cls_fields = {f.name: f for f in fields(cls)}
if name not in cls_fields:
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
named_field: Field = cls_fields[name]
if (default_factory := named_field.default_factory) is not MISSING:
return field(default_factory=default_factory)
if (default := named_field.default) is not MISSING:
if isinstance(default, FieldInfo):
# Handle pydantic.Field defaults
if default.default_factory is not None:
return field(default_factory=default.default_factory)
else:
default = default.default
return field(default=default)
try:
named_field = next(f for f in fields(cls) if f.name == name)
except StopIteration as e:
raise ValueError(f"Field '{name}' not found in {cls.__name__}.") from e
raise ValueError(
f"{cls.__name__}.{name} must have a default value or default factory."
)
# The arguments to copy to the new field
default = named_field.default
default_factory = named_field.default_factory
init = named_field.init
# Handle pydantic.Field
if isinstance(default, FieldInfo):
if default.init is not None:
init = default.init
if default.default_factory is not None:
default_factory = cast(Callable[[], Any], default.default_factory)
default = MISSING
else:
default = default.default
if default is MISSING and default_factory is MISSING:
logger.warning_once(
"%s.%s has no default or default factory.", cls.__name__, name
)
return field(default=default, default_factory=default_factory, init=init)
def is_init_field(cls: ConfigType, name: str) -> bool:
return get_field(cls, name).init
def replace(dataclass_instance: ConfigT, /, **kwargs) -> ConfigT:
"""Like [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace),
but compatible with Pydantic dataclasses which use `pydantic.fields.Field` instead
of `dataclasses.field`"""
cls = type(dataclass_instance)
dataclass_dict = dataclass_instance.__dict__
dataclass_dict = {k: v for k, v in dataclass_dict.items() if is_init_field(cls, k)}
dataclass_dict.update(kwargs)
return cls(**dataclass_dict)
def getattr_iter(
@@ -172,10 +213,6 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
return out
def is_init_field(cls: ConfigType, name: str) -> bool:
return next(f for f in fields(cls) if f.name == name).init
@runtime_checkable
class SupportsHash(Protocol):
def compute_hash(self) -> str: ...