Turn @config into a dataclass_transform (#31541)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -10,14 +10,17 @@ import json
|
||||
import pathlib
|
||||
import textwrap
|
||||
from collections.abc import Callable, Mapping, Sequence, Set
|
||||
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
|
||||
from dataclasses import MISSING, Field, field, fields, is_dataclass
|
||||
from itertools import pairwise
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic.fields import Field as PydanticField
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import runtime_checkable
|
||||
from typing_extensions import dataclass_transform, runtime_checkable
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@@ -29,23 +32,39 @@ else:
|
||||
DataclassInstance = Any
|
||||
|
||||
ConfigType = type[DataclassInstance]
|
||||
ConfigT = TypeVar("ConfigT", bound=ConfigType)
|
||||
ConfigT = TypeVar("ConfigT", bound=DataclassInstance)
|
||||
|
||||
|
||||
def config(cls: ConfigT) -> ConfigT:
|
||||
"""
|
||||
A decorator that ensures all fields in a dataclass have default values
|
||||
and that each field has a docstring.
|
||||
@dataclass_transform(field_specifiers=(PydanticField,))
|
||||
def config(
|
||||
cls: type[ConfigT] | None = None,
|
||||
*,
|
||||
config: ConfigDict | None = None,
|
||||
**kwargs: Any,
|
||||
) -> type[ConfigT] | Callable[[type[ConfigT]], type[ConfigT]]:
|
||||
"""Decorator to create a pydantic dataclass with default config. The default config
|
||||
for the dataclass forbids extra fields.
|
||||
|
||||
If a `ConfigT` is used as a CLI argument itself, the `type` keyword argument
|
||||
provided by `get_kwargs` will be
|
||||
`pydantic.TypeAdapter(ConfigT).validate_json(cli_arg)` which treats the
|
||||
`cli_arg` as a JSON string which gets validated by `pydantic`.
|
||||
All config classes in vLLM should use this decorator.
|
||||
|
||||
Config validation is performed by the tools/pre_commit/validate_config.py
|
||||
script, which is invoked during the pre-commit checks.
|
||||
"""
|
||||
return cls
|
||||
Args:
|
||||
cls: The class to decorate
|
||||
config: The pydantic ConfigDict to use. If provided, it will be merged with
|
||||
the default config.
|
||||
**kwargs: Additional arguments to pass to pydantic.dataclass."""
|
||||
# Extra fields are forbidden by default
|
||||
merged_config = ConfigDict(extra="forbid")
|
||||
if config is not None:
|
||||
merged_config.update(config)
|
||||
|
||||
def decorator(cls):
|
||||
return dataclass(cls, config=merged_config, **kwargs)
|
||||
|
||||
# Called with arguments: @config(config=...)
|
||||
if cls is None:
|
||||
return decorator
|
||||
# Called without arguments: @config
|
||||
return decorator(cls)
|
||||
|
||||
|
||||
def get_field(cls: ConfigType, name: str) -> Field:
|
||||
@@ -53,24 +72,46 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
default factory fields in `EngineArgs`."""
|
||||
if not is_dataclass(cls):
|
||||
raise TypeError("The given class is not a dataclass.")
|
||||
cls_fields = {f.name: f for f in fields(cls)}
|
||||
if name not in cls_fields:
|
||||
raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
|
||||
named_field: Field = cls_fields[name]
|
||||
if (default_factory := named_field.default_factory) is not MISSING:
|
||||
return field(default_factory=default_factory)
|
||||
if (default := named_field.default) is not MISSING:
|
||||
if isinstance(default, FieldInfo):
|
||||
# Handle pydantic.Field defaults
|
||||
if default.default_factory is not None:
|
||||
return field(default_factory=default.default_factory)
|
||||
else:
|
||||
default = default.default
|
||||
return field(default=default)
|
||||
try:
|
||||
named_field = next(f for f in fields(cls) if f.name == name)
|
||||
except StopIteration as e:
|
||||
raise ValueError(f"Field '{name}' not found in {cls.__name__}.") from e
|
||||
|
||||
raise ValueError(
|
||||
f"{cls.__name__}.{name} must have a default value or default factory."
|
||||
)
|
||||
# The arguments to copy to the new field
|
||||
default = named_field.default
|
||||
default_factory = named_field.default_factory
|
||||
init = named_field.init
|
||||
|
||||
# Handle pydantic.Field
|
||||
if isinstance(default, FieldInfo):
|
||||
if default.init is not None:
|
||||
init = default.init
|
||||
if default.default_factory is not None:
|
||||
default_factory = cast(Callable[[], Any], default.default_factory)
|
||||
default = MISSING
|
||||
else:
|
||||
default = default.default
|
||||
|
||||
if default is MISSING and default_factory is MISSING:
|
||||
logger.warning_once(
|
||||
"%s.%s has no default or default factory.", cls.__name__, name
|
||||
)
|
||||
return field(default=default, default_factory=default_factory, init=init)
|
||||
|
||||
|
||||
def is_init_field(cls: ConfigType, name: str) -> bool:
|
||||
return get_field(cls, name).init
|
||||
|
||||
|
||||
def replace(dataclass_instance: ConfigT, /, **kwargs) -> ConfigT:
|
||||
"""Like [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace),
|
||||
but compatible with Pydantic dataclasses which use `pydantic.fields.Field` instead
|
||||
of `dataclasses.field`"""
|
||||
cls = type(dataclass_instance)
|
||||
dataclass_dict = dataclass_instance.__dict__
|
||||
dataclass_dict = {k: v for k, v in dataclass_dict.items() if is_init_field(cls, k)}
|
||||
dataclass_dict.update(kwargs)
|
||||
return cls(**dataclass_dict)
|
||||
|
||||
|
||||
def getattr_iter(
|
||||
@@ -172,10 +213,6 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]:
|
||||
return out
|
||||
|
||||
|
||||
def is_init_field(cls: ConfigType, name: str) -> bool:
|
||||
return next(f for f in fields(cls) if f.name == name).init
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SupportsHash(Protocol):
|
||||
def compute_hash(self) -> str: ...
|
||||
|
||||
Reference in New Issue
Block a user