From 61e632aea15f76fd1c46354b00f9cac62cd28c4e Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 3 Feb 2026 17:40:59 +0000 Subject: [PATCH] Turn `@config` into a `dataclass_transform` (#31541) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/engine/test_arg_utils.py | 26 ++-- tests/test_config.py | 5 +- tests/tools/test_config_validator.py | 26 ++-- tests/v1/e2e/test_spec_decode.py | 2 +- .../test_backend_guidance.py | 7 +- tools/pre_commit/validate_config.py | 26 ++-- vllm/config/__init__.py | 2 + vllm/config/attention.py | 2 - vllm/config/cache.py | 2 - vllm/config/compilation.py | 6 +- vllm/config/device.py | 4 +- vllm/config/ec_transfer.py | 3 - vllm/config/kv_events.py | 2 - vllm/config/kv_transfer.py | 3 - vllm/config/load.py | 2 - vllm/config/lora.py | 4 +- vllm/config/model.py | 4 +- vllm/config/multimodal.py | 1 - vllm/config/observability.py | 2 - vllm/config/parallel.py | 7 -- vllm/config/pooler.py | 3 - vllm/config/profiler.py | 2 - vllm/config/scheduler.py | 2 - vllm/config/speculative.py | 2 - vllm/config/speech_to_text.py | 3 - vllm/config/structured_outputs.py | 2 - vllm/config/utils.py | 111 ++++++++++++------ vllm/config/vllm.py | 19 +-- vllm/entrypoints/openai/cli_args.py | 3 - .../configs/speculators/algos.py | 17 +-- .../configs/speculators/base.py | 34 +++--- vllm/v1/spec_decode/draft_model.py | 10 +- 32 files changed, 153 insertions(+), 191 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 2acb38bc9..d1986e0a4 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -3,11 +3,11 @@ import json from argparse import ArgumentError -from contextlib import nullcontext -from dataclasses import dataclass, field +from contextlib import AbstractContextManager, nullcontext from typing import Annotated, Literal import pytest +from pydantic import Field from vllm.config import AttentionConfig, CompilationConfig, config from vllm.engine.arg_utils import ( @@ -96,7 +96,7 @@ def test_get_type(type_hints, type, expected): ], ) def test_literal_to_kwargs(type_hints, expected): - context = nullcontext() + context: AbstractContextManager[object] = nullcontext() if expected is Exception: context = pytest.raises(expected) with context: @@ -104,14 +104,12 @@ def test_literal_to_kwargs(type_hints, expected): @config -@dataclass class NestedConfig: field: int = 1 """field""" @config -@dataclass class DummyConfig: regular_bool: bool = True """Regular bool with default True""" @@ -119,23 +117,23 @@ class DummyConfig: """Optional bool with default None""" optional_literal: Literal["x", "y"] | None = None """Optional literal with default None""" - tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3)) + tuple_n: tuple[int, ...] 
= Field(default_factory=lambda: (1, 2, 3)) """Tuple with variable length""" - tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2)) + tuple_2: tuple[int, int] = Field(default_factory=lambda: (1, 2)) """Tuple with fixed length""" - list_n: list[int] = field(default_factory=lambda: [1, 2, 3]) + list_n: list[int] = Field(default_factory=lambda: [1, 2, 3]) """List with variable length""" - list_literal: list[Literal[1, 2]] = field(default_factory=list) + list_literal: list[Literal[1, 2]] = Field(default_factory=list) """List with literal choices""" - list_union: list[str | type[object]] = field(default_factory=list) + list_union: list[str | type[object]] = Field(default_factory=list) """List with union type""" - set_n: set[int] = field(default_factory=lambda: {1, 2, 3}) + set_n: set[int] = Field(default_factory=lambda: {1, 2, 3}) """Set with variable length""" literal_literal: Literal[Literal[1], Literal[2]] = 1 """Literal of literals with default 1""" - json_tip: dict = field(default_factory=dict) + json_tip: dict = Field(default_factory=dict) """Dict which will be JSON in CLI""" - nested_config: NestedConfig = field(default_factory=NestedConfig) + nested_config: NestedConfig = Field(default_factory=NestedConfig) """Nested config""" @@ -195,7 +193,7 @@ def test_get_kwargs(): json_tip = "Should either be a valid JSON string or JSON keys" assert json_tip in kwargs["json_tip"]["help"] # nested config should construct the nested config - assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) + assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) # type: ignore[call-arg] @pytest.mark.parametrize( diff --git a/tests/test_config.py b/tests/test_config.py index 1676598b1..f3c3003a0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -66,9 +66,6 @@ class _TestConfigFields: def test_get_field(): - with pytest.raises(ValueError): - get_field(_TestConfigFields, "a") - b = get_field(_TestConfigFields, "b") assert isinstance(b, Field) assert b.default is MISSING @@ -188,7 +185,7 @@ def test_get_pooling_config(): ) def test_get_pooling_config_from_args(): model_id = "sentence-transformers/all-MiniLM-L12-v2" - pooler_config = PoolerConfig(seq_pooling_type="CLS", normalize=True) + pooler_config = PoolerConfig(seq_pooling_type="CLS", use_activation=False) model_config = ModelConfig(model_id, pooler_config=pooler_config) assert asdict(model_config.pooler_config) == asdict(pooler_config) diff --git a/tests/tools/test_config_validator.py b/tests/tools/test_config_validator.py index d6104dc6d..e317bf911 100644 --- a/tests/tools/test_config_validator.py +++ b/tests/tools/test_config_validator.py @@ -7,31 +7,22 @@ import pytest from tools.pre_commit.validate_config import validate_ast -_TestConfig1 = """ +_TestConfig1 = ''' @config class _TestConfig1: - pass -""" - -_TestConfig2 = ''' -@config -@dataclass -class _TestConfig2: a: int """docstring""" ''' -_TestConfig3 = """ +_TestConfig2 = """ @config -@dataclass -class _TestConfig3: +class _TestConfig2: a: int = 1 """ -_TestConfig4 = ''' +_TestConfig3 = ''' @config -@dataclass -class _TestConfig4: +class _TestConfig3: a: Union[Literal[1], Literal[2]] = 1 """docstring""" ''' @@ -40,10 +31,9 @@ class _TestConfig4: @pytest.mark.parametrize( ("test_config", "expected_error"), [ - (_TestConfig1, "must be a dataclass"), - (_TestConfig2, "must have a default"), - (_TestConfig3, "must have a docstring"), - (_TestConfig4, "must use a single Literal"), + (_TestConfig1, "must have a default"), + (_TestConfig2, "must 
have a docstring"), + (_TestConfig3, "must use a single Literal"), ], ) def test_config(test_config, expected_error): diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 02e152914..4905a4120 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -766,8 +766,8 @@ def assert_draft_model_correctness(args: ArgsTest, enforce_eager: bool): "max_model_len": args.max_model_len, "enforce_eager": enforce_eager, "draft_tensor_parallel_size": args.draft_tensor_parallel_size, - "max_num_seqs": 100, # limit cudagraph capture runtime }, + max_num_seqs=100, # limit cudagraph capture runtime max_model_len=args.max_model_len, gpu_memory_utilization=args.gpu_memory_utilization, tensor_parallel_size=args.target_tensor_parallel_size, diff --git a/tests/v1/structured_output/test_backend_guidance.py b/tests/v1/structured_output/test_backend_guidance.py index 4c01560fc..362f75c49 100644 --- a/tests/v1/structured_output/test_backend_guidance.py +++ b/tests/v1/structured_output/test_backend_guidance.py @@ -26,11 +26,8 @@ def test_backend_guidance_rollback_terminated(): # guidance backend. In that case we are in a stopped state, but # it should be reverted in case EOS is not accepted by the target # model. - vllm_config = VllmConfig( - decoding_config=StructuredOutputsConfig( - backend="guidance", - ) - ) + structured_outputs_config = StructuredOutputsConfig(backend="guidance") + vllm_config = VllmConfig(structured_outputs_config=structured_outputs_config) tokenizer = AutoTokenizer.from_pretrained(TOKENIZER) backend = GuidanceBackend( diff --git a/tools/pre_commit/validate_config.py b/tools/pre_commit/validate_config.py index fb6f0e6a9..7da32bc6b 100644 --- a/tools/pre_commit/validate_config.py +++ b/tools/pre_commit/validate_config.py @@ -54,24 +54,18 @@ class ConfigValidator(ast.NodeVisitor): def __init__(self): ... 
     def visit_ClassDef(self, node):
-        # Validate class with both @config and @dataclass decorators
-        decorators = [
-            id
-            for d in node.decorator_list
-            if (
-                isinstance(d, ast.Name)
-                and ((id := d.id) == "config" or id == "dataclass")
-            )
-            or (
-                isinstance(d, ast.Call)
-                and (isinstance(d.func, ast.Name) and (id := d.func.id) == "dataclass")
-            )
-        ]
+        # Validate classes with a @config decorator
+        decorators = set()
+        for decorator in node.decorator_list:
+            if isinstance(decorator, ast.Call):
+                decorator = decorator.func
+            if isinstance(decorator, ast.Name):
+                decorators.add(decorator.id)
 
-        if set(decorators) == {"config", "dataclass"}:
+        if decorators == {"config"}:
             validate_class(node)
-        elif set(decorators) == {"config"}:
-            fail(f"Class {node.name} with config decorator must be a dataclass.", node)
+        elif "config" in decorators:
+            fail(f"config decorator for {node.name} should be used alone", node)
 
         self.generic_visit(node)
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 7f6565053..b2044c6e1 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -36,6 +36,7 @@ from vllm.config.utils import (
     config,
     get_attr_docs,
     is_init_field,
+    replace,
     update_config,
 )
 from vllm.config.vllm import (
@@ -101,6 +102,7 @@ __all__ = [
     "config",
     "get_attr_docs",
     "is_init_field",
+    "replace",
     "update_config",
     # From vllm.config.vllm
     "VllmConfig",
diff --git a/vllm/config/attention.py b/vllm/config/attention.py
index ee072fb1c..9379b2878 100644
--- a/vllm/config/attention.py
+++ b/vllm/config/attention.py
@@ -4,14 +4,12 @@ from typing import Any, Literal
 
 from pydantic import field_validator
-from pydantic.dataclasses import dataclass
 
 from vllm.config.utils import config
 from vllm.v1.attention.backends.registry import AttentionBackendEnum
 
 
 @config
-@dataclass
 class AttentionConfig:
     """Configuration for attention mechanisms in vLLM."""
 
diff --git a/vllm/config/cache.py b/vllm/config/cache.py
index abf10e21d..bf121e544 100644
--- a/vllm/config/cache.py
+++ b/vllm/config/cache.py
@@ -6,7 +6,6 @@ from dataclasses import field
 from typing import TYPE_CHECKING, Any, Literal
 
 from pydantic import Field, SkipValidation, field_validator
-from pydantic.dataclasses import dataclass
 
 from vllm.config.utils import config
 from vllm.logger import init_logger
@@ -37,7 +36,6 @@ KVOffloadingBackend = Literal["native", "lmcache"]
 
 
 @config
-@dataclass
 class CacheConfig:
     """Configuration for the KV cache."""
 
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 7a69629f7..556254a65 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -8,8 +8,7 @@ from dataclasses import field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, ClassVar, Literal
 
-from pydantic import ConfigDict, Field, TypeAdapter, field_validator
-from pydantic.dataclasses import dataclass
+from pydantic import Field, TypeAdapter, field_validator
 
 import vllm.envs as envs
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
@@ -96,7 +95,6 @@ class CUDAGraphMode(enum.Enum):
 
 
 @config
-@dataclass(config=ConfigDict(extra="forbid"))
 class PassConfig:
     """Configuration for custom Inductor passes.
 
@@ -267,7 +265,6 @@ class DynamicShapesType(str, enum.Enum): @config -@dataclass(config=ConfigDict(extra="forbid")) class DynamicShapesConfig: """Configuration to control/debug torch compile dynamic shapes.""" @@ -311,7 +308,6 @@ class DynamicShapesConfig: @config -@dataclass(config=ConfigDict(extra="forbid")) class CompilationConfig: """Configuration for compilation. diff --git a/vllm/config/device.py b/vllm/config/device.py index 85662ddff..c20e4d0f2 100644 --- a/vllm/config/device.py +++ b/vllm/config/device.py @@ -6,7 +6,6 @@ from typing import Any, Literal import torch from pydantic import ConfigDict, SkipValidation -from pydantic.dataclasses import dataclass from vllm.config.utils import config from vllm.utils.hashing import safe_hash @@ -14,8 +13,7 @@ from vllm.utils.hashing import safe_hash Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"] -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +@config(config=ConfigDict(arbitrary_types_allowed=True)) class DeviceConfig: """Configuration for the device to use for vLLM execution.""" diff --git a/vllm/config/ec_transfer.py b/vllm/config/ec_transfer.py index d95236f81..c7f56557f 100644 --- a/vllm/config/ec_transfer.py +++ b/vllm/config/ec_transfer.py @@ -5,8 +5,6 @@ import uuid from dataclasses import field from typing import Any, Literal, get_args -from pydantic.dataclasses import dataclass - from vllm.config.utils import config ECProducer = Literal["ec_producer"] @@ -15,7 +13,6 @@ ECRole = Literal[ECProducer, ECConsumer] @config -@dataclass class ECTransferConfig: """Configuration for distributed EC cache transfer.""" diff --git a/vllm/config/kv_events.py b/vllm/config/kv_events.py index ce46cc03c..94da54c78 100644 --- a/vllm/config/kv_events.py +++ b/vllm/config/kv_events.py @@ -5,13 +5,11 @@ from typing import Literal from pydantic import Field -from pydantic.dataclasses import dataclass from vllm.config.utils import config @config -@dataclass class KVEventsConfig: """Configuration for KV event publishing.""" diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py index 98cea821c..fe3b218fb 100644 --- a/vllm/config/kv_transfer.py +++ b/vllm/config/kv_transfer.py @@ -5,8 +5,6 @@ import uuid from dataclasses import field from typing import Any, Literal, get_args -from pydantic.dataclasses import dataclass - from vllm.config.utils import config from vllm.utils.hashing import safe_hash @@ -16,7 +14,6 @@ KVRole = Literal[KVProducer, KVConsumer] @config -@dataclass class KVTransferConfig: """Configuration for distributed KV cache transfer.""" diff --git a/vllm/config/load.py b/vllm/config/load.py index 579a0bc31..64a269e98 100644 --- a/vllm/config/load.py +++ b/vllm/config/load.py @@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Any from pydantic import Field, field_validator -from pydantic.dataclasses import dataclass from vllm.config.utils import config from vllm.logger import init_logger @@ -21,7 +20,6 @@ logger = init_logger(__name__) @config -@dataclass class LoadConfig: """Configuration for loading the model weights.""" diff --git a/vllm/config/lora.py b/vllm/config/lora.py index f15beffe1..0d310c87e 100644 --- a/vllm/config/lora.py +++ b/vllm/config/lora.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Literal import torch from pydantic import ConfigDict, Field, model_validator -from pydantic.dataclasses import dataclass from typing_extensions import Self from vllm.config.utils import config @@ -26,8 +25,7 @@ MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512] LoRAExtraVocabSize = 
Literal[256, 512] -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +@config(config=ConfigDict(arbitrary_types_allowed=True)) class LoRAConfig: """Configuration for LoRA.""" diff --git a/vllm/config/model.py b/vllm/config/model.py index 48ff44ac9..3bb8e7177 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast, get_args import torch from pydantic import ConfigDict, Field, field_validator, model_validator -from pydantic.dataclasses import dataclass import vllm.envs as envs from vllm.config.model_arch import ( @@ -97,8 +96,7 @@ AttnTypeStr = Literal[ ] -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +@config(config=ConfigDict(arbitrary_types_allowed=True)) class ModelConfig: """Configuration for the model.""" diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index f4e834f64..48eea6f4e 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -51,7 +51,6 @@ DummyOptions: TypeAlias = ( @config -@dataclass class MultiModalConfig: """Controls the behavior of multimodal models.""" diff --git a/vllm/config/observability.py b/vllm/config/observability.py index 9700c9117..387175912 100644 --- a/vllm/config/observability.py +++ b/vllm/config/observability.py @@ -6,7 +6,6 @@ from typing import Any, Literal, cast from packaging.version import parse from pydantic import Field, field_validator, model_validator -from pydantic.dataclasses import dataclass from vllm import version from vllm.config.utils import config @@ -16,7 +15,6 @@ DetailedTraceModules = Literal["model", "worker", "all"] @config -@dataclass class ObservabilityConfig: """Configuration for observability - metrics and tracing.""" diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index fa1aa0312..131db50f1 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -3,12 +3,10 @@ import os from collections.abc import Callable -from dataclasses import replace from typing import TYPE_CHECKING, Any, Literal import torch from pydantic import Field, field_validator, model_validator -from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from typing_extensions import Self @@ -50,7 +48,6 @@ All2AllBackend = Literal[ @config -@dataclass class EPLBConfig: """Configuration for Expert Parallel Load Balancing (EP).""" @@ -94,7 +91,6 @@ class EPLBConfig: @config -@dataclass class ParallelConfig: """Configuration for the distributed execution.""" @@ -715,6 +711,3 @@ class ParallelConfig: ) return self - - def replace(self, **kwargs) -> Self: - return replace(self, **kwargs) diff --git a/vllm/config/pooler.py b/vllm/config/pooler.py index 6d87ec908..75cdc90fe 100644 --- a/vllm/config/pooler.py +++ b/vllm/config/pooler.py @@ -3,8 +3,6 @@ from typing import Any, Literal, get_args -from pydantic.dataclasses import dataclass - from vllm.config.utils import config from vllm.logger import init_logger from vllm.utils.hashing import safe_hash @@ -19,7 +17,6 @@ TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] 
= get_args(TokenPoolingType) @config -@dataclass class PoolerConfig: """Controls the behavior of output pooling in pooling models.""" diff --git a/vllm/config/profiler.py b/vllm/config/profiler.py index 425f3fb6b..b3b8844f7 100644 --- a/vllm/config/profiler.py +++ b/vllm/config/profiler.py @@ -5,7 +5,6 @@ import os from typing import Any, Literal from pydantic import Field, model_validator -from pydantic.dataclasses import dataclass from typing_extensions import Self from vllm.config.utils import config @@ -32,7 +31,6 @@ def _is_uri_path(path: str) -> bool: @config -@dataclass class ProfilerConfig: """Dataclass which contains profiler config for the engine.""" diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 5ff9fc930..5e44eb84f 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -6,7 +6,6 @@ from dataclasses import InitVar from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast from pydantic import Field, field_validator -from pydantic.dataclasses import dataclass from typing_extensions import Self from vllm.config.utils import config @@ -24,7 +23,6 @@ SchedulerPolicy = Literal["fcfs", "priority"] @config -@dataclass class SchedulerConfig: """Scheduler configuration.""" diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 966d168b4..ed3dbefb3 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -5,7 +5,6 @@ import ast from typing import TYPE_CHECKING, Any, Literal, get_args from pydantic import Field, SkipValidation, model_validator -from pydantic.dataclasses import dataclass from typing_extensions import Self from vllm.config.model import ModelConfig @@ -55,7 +54,6 @@ SpeculativeMethod = Literal[ @config -@dataclass class SpeculativeConfig: """Configuration for speculative decoding.""" diff --git a/vllm/config/speech_to_text.py b/vllm/config/speech_to_text.py index fe3532c97..0233d3657 100644 --- a/vllm/config/speech_to_text.py +++ b/vllm/config/speech_to_text.py @@ -2,13 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from pydantic.dataclasses import dataclass - from vllm.config.utils import config @config -@dataclass class SpeechToTextConfig: """Configuration for speech-to-text models.""" diff --git a/vllm/config/structured_outputs.py b/vllm/config/structured_outputs.py index 8c060c816..c4db15989 100644 --- a/vllm/config/structured_outputs.py +++ b/vllm/config/structured_outputs.py @@ -4,7 +4,6 @@ from typing import Any, Literal from pydantic import model_validator -from pydantic.dataclasses import dataclass from typing_extensions import Self from vllm.config.utils import config @@ -16,7 +15,6 @@ StructuredOutputsBackend = Literal[ @config -@dataclass class StructuredOutputsConfig: """Dataclass which contains structured outputs config for the engine.""" diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 9288948c5..e8c866f02 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -10,14 +10,17 @@ import json import pathlib import textwrap from collections.abc import Callable, Mapping, Sequence, Set -from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace +from dataclasses import MISSING, Field, field, fields, is_dataclass from itertools import pairwise -from typing import TYPE_CHECKING, Any, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast import regex as re import torch +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass +from pydantic.fields import Field 
as PydanticField from pydantic.fields import FieldInfo -from typing_extensions import runtime_checkable +from typing_extensions import dataclass_transform, runtime_checkable from vllm.logger import init_logger @@ -29,23 +32,39 @@ else: DataclassInstance = Any ConfigType = type[DataclassInstance] -ConfigT = TypeVar("ConfigT", bound=ConfigType) +ConfigT = TypeVar("ConfigT", bound=DataclassInstance) -def config(cls: ConfigT) -> ConfigT: - """ - A decorator that ensures all fields in a dataclass have default values - and that each field has a docstring. +@dataclass_transform(field_specifiers=(PydanticField,)) +def config( + cls: type[ConfigT] | None = None, + *, + config: ConfigDict | None = None, + **kwargs: Any, +) -> type[ConfigT] | Callable[[type[ConfigT]], type[ConfigT]]: + """Decorator to create a pydantic dataclass with default config. The default config + for the dataclass forbids extra fields. - If a `ConfigT` is used as a CLI argument itself, the `type` keyword argument - provided by `get_kwargs` will be - `pydantic.TypeAdapter(ConfigT).validate_json(cli_arg)` which treats the - `cli_arg` as a JSON string which gets validated by `pydantic`. + All config classes in vLLM should use this decorator. - Config validation is performed by the tools/pre_commit/validate_config.py - script, which is invoked during the pre-commit checks. - """ - return cls + Args: + cls: The class to decorate + config: The pydantic ConfigDict to use. If provided, it will be merged with + the default config. + **kwargs: Additional arguments to pass to pydantic.dataclass.""" + # Extra fields are forbidden by default + merged_config = ConfigDict(extra="forbid") + if config is not None: + merged_config.update(config) + + def decorator(cls): + return dataclass(cls, config=merged_config, **kwargs) + + # Called with arguments: @config(config=...) + if cls is None: + return decorator + # Called without arguments: @config + return decorator(cls) def get_field(cls: ConfigType, name: str) -> Field: @@ -53,24 +72,46 @@ def get_field(cls: ConfigType, name: str) -> Field: default factory fields in `EngineArgs`.""" if not is_dataclass(cls): raise TypeError("The given class is not a dataclass.") - cls_fields = {f.name: f for f in fields(cls)} - if name not in cls_fields: - raise ValueError(f"Field '{name}' not found in {cls.__name__}.") - named_field: Field = cls_fields[name] - if (default_factory := named_field.default_factory) is not MISSING: - return field(default_factory=default_factory) - if (default := named_field.default) is not MISSING: - if isinstance(default, FieldInfo): - # Handle pydantic.Field defaults - if default.default_factory is not None: - return field(default_factory=default.default_factory) - else: - default = default.default - return field(default=default) + try: + named_field = next(f for f in fields(cls) if f.name == name) + except StopIteration as e: + raise ValueError(f"Field '{name}' not found in {cls.__name__}.") from e - raise ValueError( - f"{cls.__name__}.{name} must have a default value or default factory." 
- ) + # The arguments to copy to the new field + default = named_field.default + default_factory = named_field.default_factory + init = named_field.init + + # Handle pydantic.Field + if isinstance(default, FieldInfo): + if default.init is not None: + init = default.init + if default.default_factory is not None: + default_factory = cast(Callable[[], Any], default.default_factory) + default = MISSING + else: + default = default.default + + if default is MISSING and default_factory is MISSING: + logger.warning_once( + "%s.%s has no default or default factory.", cls.__name__, name + ) + return field(default=default, default_factory=default_factory, init=init) + + +def is_init_field(cls: ConfigType, name: str) -> bool: + return get_field(cls, name).init + + +def replace(dataclass_instance: ConfigT, /, **kwargs) -> ConfigT: + """Like [`dataclasses.replace`](https://docs.python.org/3/library/dataclasses.html#dataclasses.replace), + but compatible with Pydantic dataclasses which use `pydantic.fields.Field` instead + of `dataclasses.field`""" + cls = type(dataclass_instance) + dataclass_dict = dataclass_instance.__dict__ + dataclass_dict = {k: v for k, v in dataclass_dict.items() if is_init_field(cls, k)} + dataclass_dict.update(kwargs) + return cls(**dataclass_dict) def getattr_iter( @@ -172,10 +213,6 @@ def get_attr_docs(cls: type[Any]) -> dict[str, str]: return out -def is_init_field(cls: ConfigType, name: str) -> bool: - return next(f for f in fields(cls) if f.name == name).init - - @runtime_checkable class SupportsHash(Protocol): def compute_hash(self) -> str: ... diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index ea1338563..846ed50e0 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -9,7 +9,7 @@ import tempfile import threading import time from contextlib import contextmanager -from dataclasses import is_dataclass, replace +from dataclasses import is_dataclass from datetime import datetime from enum import IntEnum from functools import lru_cache @@ -18,10 +18,8 @@ from typing import TYPE_CHECKING, Any, TypeVar, get_args import torch from pydantic import ConfigDict, Field, model_validator -from pydantic.dataclasses import dataclass import vllm.envs as envs -from vllm.config.speculative import EagleModelTypes from vllm.logger import enable_trace_function_call, init_logger from vllm.transformers_utils.runai_utils import is_runai_obj_uri from vllm.utils import random_uuid @@ -41,9 +39,9 @@ from .observability import ObservabilityConfig from .parallel import ParallelConfig from .profiler import ProfilerConfig from .scheduler import SchedulerConfig -from .speculative import SpeculativeConfig +from .speculative import EagleModelTypes, SpeculativeConfig from .structured_outputs import StructuredOutputsConfig -from .utils import SupportsHash, config +from .utils import SupportsHash, config, replace if TYPE_CHECKING: from transformers import PretrainedConfig @@ -187,8 +185,7 @@ OPTIMIZATION_LEVEL_TO_CONFIG = { } -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) +@config(config=ConfigDict(arbitrary_types_allowed=True)) class VllmConfig: """Dataclass which contains all vllm-related configuration. This simplifies passing around the distinct configurations in the codebase. @@ -1395,14 +1392,6 @@ class VllmConfig: path = self.compilation_config.debug_dump_path / append_path return path - def replace(self, **kwargs): - """ - Replace attributes of the config, and 'recompute' the config. 
- dataclass.replace() calls __init__() and __post_init__(), source: - https://docs.python.org/3/library/dataclasses.html#dataclasses.replace - """ - return replace(self, **kwargs) - def __str__(self): return ( f"model={self.model_config.model!r}, " diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 808c2a908..983040a89 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -13,8 +13,6 @@ from collections.abc import Sequence from dataclasses import field from typing import Any, Literal -from pydantic.dataclasses import dataclass - import vllm.envs as envs from vllm.config import config from vllm.engine.arg_utils import AsyncEngineArgs, optional_type @@ -69,7 +67,6 @@ class LoRAParserAction(argparse.Action): @config -@dataclass class FrontendArgs: """Arguments for the OpenAI-compatible frontend server.""" diff --git a/vllm/transformers_utils/configs/speculators/algos.py b/vllm/transformers_utils/configs/speculators/algos.py index 88bce3d4f..60bb5d588 100644 --- a/vllm/transformers_utils/configs/speculators/algos.py +++ b/vllm/transformers_utils/configs/speculators/algos.py @@ -13,9 +13,10 @@ def register_speculator(name): @register_speculator("eagle3") -def update_eagle3(config_dict: dict, vllm_config: dict) -> None: +def update_eagle3(config_dict: dict, pre_trained_config: dict) -> None: """ - Apply Eagle-3 specific configuration transformations. + Apply Eagle-3 specific configuration transformations to the `dict` used to + construct the Transformers PreTrainedConfig. Eagle-3 specific fields: - draft_vocab_size: Size of the draft model's vocabulary @@ -27,12 +28,14 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None: predictions. This is the standard field used in Eagle3 checkpoints. 
""" - vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size") + pre_trained_config["draft_vocab_size"] = config_dict.get("draft_vocab_size") if config_dict.get("target_hidden_size") is not None: - vllm_config["target_hidden_size"] = config_dict["target_hidden_size"] - vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True) - vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"] + pre_trained_config["target_hidden_size"] = config_dict["target_hidden_size"] + pre_trained_config["norm_before_residual"] = config_dict.get( + "norm_before_residual", True + ) + pre_trained_config["architectures"] = ["Eagle3LlamaForCausalLM"] if config_dict.get("eagle_aux_hidden_state_layer_ids"): - vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[ + pre_trained_config["eagle_aux_hidden_state_layer_ids"] = config_dict[ "eagle_aux_hidden_state_layer_ids" ] diff --git a/vllm/transformers_utils/configs/speculators/base.py b/vllm/transformers_utils/configs/speculators/base.py index bf3a5d413..a57350b09 100644 --- a/vllm/transformers_utils/configs/speculators/base.py +++ b/vllm/transformers_utils/configs/speculators/base.py @@ -24,13 +24,16 @@ class SpeculatorsConfig(PretrainedConfig): """Load speculators Eagle config and convert to vLLM format.""" config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - vllm_config = cls.extract_vllm_speculative_config(config_dict) + vllm_config = cls.extract_transformers_pre_trained_config(config_dict) return cls(**vllm_config) @classmethod - def extract_vllm_speculative_config( + def extract_transformers_pre_trained_config( cls, config_dict: dict[str, Any] ) -> dict[str, Any]: + """ + Extract standard Transformers PreTrainedConfig config from speculators config. + """ speculators_model_type = config_dict.get("speculators_model_type") if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES: raise ValueError( @@ -38,15 +41,23 @@ class SpeculatorsConfig(PretrainedConfig): "Please ensure you're loading a speculators-format model." 
) + # Start with transformer layer configuration if present + pre_trained_config = config_dict.get("transformer_layer_config", {}) + # Apply anything specific to the supported algorithm + algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type] + algo_updater(config_dict=config_dict, pre_trained_config=pre_trained_config) + return pre_trained_config + + @classmethod + def extract_vllm_speculative_config( + cls, config_dict: dict[str, Any] + ) -> dict[str, Any]: + """Extract vLLM speculative config from speculators config.""" # validate fields # TODO: @dsikka - use speculators pydantic model to validate cls.validate_speculators_config(config_dict=config_dict) # Convert from speculators config -> format that can be ingested by vLLM - vllm_config = cls.build_vllm_speculative_config(config_dict=config_dict) - # Apply anything specific to the supported algorithm - algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type] - algo_updater(config_dict=config_dict, vllm_config=vllm_config) - return vllm_config + return cls.build_vllm_speculative_config(config_dict=config_dict) @classmethod def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None: @@ -101,14 +112,7 @@ class SpeculatorsConfig(PretrainedConfig): ) # Build base vLLM speculative configuration - vllm_config = { + return { "method": config_dict.get("speculators_model_type"), "num_speculative_tokens": num_speculative_tokens, - "target_model": spec_config.get("verifier")["name_or_path"], } - - # Merge transformer layer configuration if present - transformer_config = config_dict.get("transformer_layer_config", {}) - vllm_config.update(transformer_config) - - return vllm_config diff --git a/vllm/v1/spec_decode/draft_model.py b/vllm/v1/spec_decode/draft_model.py index 9c6754013..18e98b267 100644 --- a/vllm/v1/spec_decode/draft_model.py +++ b/vllm/v1/spec_decode/draft_model.py @@ -4,7 +4,7 @@ from typing import Any import torch -from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config import VllmConfig, get_layers_from_vllm_config, replace from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention from vllm.model_executor.model_loader import get_model @@ -191,10 +191,12 @@ def create_vllm_config_for_draft_model( old = target_model_vllm_config assert old.speculative_config is not None, "speculative_config is not set" old_spec_config = old.speculative_config - new_parallel_config = old_spec_config.draft_parallel_config.replace( - rank=old.parallel_config.rank + new_parallel_config = replace( + old_spec_config.draft_parallel_config, + rank=old.parallel_config.rank, ) - new: VllmConfig = old.replace( + new: VllmConfig = replace( + old, quant_config=None, # quant_config is recomputed in __init__() model_config=old_spec_config.draft_model_config, parallel_config=new_parallel_config,
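
Additional context on the patterns used in this patch. The new `config`
decorator in vllm/config/utils.py is a PEP 681 `dataclass_transform`: type
checkers treat any `@config`-decorated class as a dataclass whose fields may be
declared with pydantic's `Field()`. A minimal self-contained sketch of that
pattern (illustrative only; `Example` is an invented class, and the real
implementation is the one in the diff above):

from collections.abc import Callable
from typing import Any, TypeVar

from pydantic import ConfigDict
from pydantic.dataclasses import dataclass
from pydantic.fields import Field as PydanticField
from typing_extensions import dataclass_transform

T = TypeVar("T")


# Tell type checkers this decorator produces dataclasses and that
# pydantic's Field() is a valid field specifier in decorated classes.
@dataclass_transform(field_specifiers=(PydanticField,))
def config(
    cls: type[T] | None = None,
    *,
    config: ConfigDict | None = None,
    **kwargs: Any,
) -> type[T] | Callable[[type[T]], type[T]]:
    merged = ConfigDict(extra="forbid")  # extras rejected by default
    if config is not None:
        merged.update(config)

    def wrap(c: type[T]) -> type[T]:
        return dataclass(c, config=merged, **kwargs)

    # Support both bare @config and parametrized @config(config=...).
    return wrap if cls is None else wrap(cls)


@config
class Example:
    a: int = 1
    """Every field carries a docstring, per the pre-commit validator."""


Example(a=2)   # validated by pydantic at construction time
# Example(z=3) would be rejected because extra fields are forbidden.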
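
The module-level `replace()` added to vllm/config/utils.py supersedes the
per-class `.replace()` methods deleted from ParallelConfig and VllmConfig; per
its docstring, it exists because these are now pydantic dataclasses whose
fields are declared with `pydantic.fields.Field` rather than
`dataclasses.field`. A usage sketch mirroring the draft_model.py hunk
(constructing a bare ParallelConfig here is an assumption for illustration):

from vllm.config import ParallelConfig, replace

parallel = ParallelConfig()
# Builds a new instance via cls(**init_fields), so pydantic validation
# and __post_init__ run again on the updated values.
parallel_rank1 = replace(parallel, rank=1)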
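
The rewritten `visit_ClassDef` in tools/pre_commit/validate_config.py hinges on
unwrapping `ast.Call` decorators so that bare `@config` and parametrized
`@config(...)` resolve to the same `ast.Name`. A standalone illustration of
that decorator collection (a re-implementation sketch, not an import of the
actual tool):

import ast

SOURCE = '''
@config
class Good:
    a: int = 1

@config(config=ConfigDict(extra="allow"))
@other_decorator
class Bad:
    b: int = 2
'''


class DecoratorCollector(ast.NodeVisitor):
    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        names = set()
        for dec in node.decorator_list:
            if isinstance(dec, ast.Call):
                dec = dec.func  # @config(...) -> underlying Name node
            if isinstance(dec, ast.Name):
                names.add(dec.id)
        if names == {"config"}:
            print(f"{node.name}: ok")
        elif "config" in names:
            print(f"{node.name}: @config should be used alone")
        self.generic_visit(node)


DecoratorCollector().visit(ast.parse(SOURCE))
# Good: ok
# Bad: @config should be used alone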
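
Finally, the reworked `get_field` has to unwrap `pydantic.fields.FieldInfo`
because, under a pydantic dataclass, a `Field(default_factory=...)` assignment
surfaces as the stdlib field's `default`. A small probe showing the shape it
handles (pydantic v2 behavior as relied on by the diff; `Demo` is an invented
class):

from dataclasses import fields

from pydantic import Field
from pydantic.dataclasses import dataclass


@dataclass
class Demo:
    xs: list[int] = Field(default_factory=list)


f = next(f for f in fields(Demo) if f.name == "xs")
print(type(f.default).__name__)   # FieldInfo -- the factory hides inside it
print(f.default.default_factory)  # <class 'list'>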