Turn @config into a dataclass_transform (#31541)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-02-03 17:40:59 +00:00
committed by GitHub
parent b1bb18de8d
commit 61e632aea1
32 changed files with 153 additions and 191 deletions

View File

@@ -3,11 +3,11 @@
import json
from argparse import ArgumentError
from contextlib import nullcontext
from dataclasses import dataclass, field
from contextlib import AbstractContextManager, nullcontext
from typing import Annotated, Literal
import pytest
from pydantic import Field
from vllm.config import AttentionConfig, CompilationConfig, config
from vllm.engine.arg_utils import (
@@ -96,7 +96,7 @@ def test_get_type(type_hints, type, expected):
],
)
def test_literal_to_kwargs(type_hints, expected):
context = nullcontext()
context: AbstractContextManager[object] = nullcontext()
if expected is Exception:
context = pytest.raises(expected)
with context:
@@ -104,14 +104,12 @@ def test_literal_to_kwargs(type_hints, expected):
@config
@dataclass
class NestedConfig:
field: int = 1
"""field"""
@config
@dataclass
class DummyConfig:
regular_bool: bool = True
"""Regular bool with default True"""
@@ -119,23 +117,23 @@ class DummyConfig:
"""Optional bool with default None"""
optional_literal: Literal["x", "y"] | None = None
"""Optional literal with default None"""
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
tuple_n: tuple[int, ...] = Field(default_factory=lambda: (1, 2, 3))
"""Tuple with variable length"""
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
tuple_2: tuple[int, int] = Field(default_factory=lambda: (1, 2))
"""Tuple with fixed length"""
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
list_n: list[int] = Field(default_factory=lambda: [1, 2, 3])
"""List with variable length"""
list_literal: list[Literal[1, 2]] = field(default_factory=list)
list_literal: list[Literal[1, 2]] = Field(default_factory=list)
"""List with literal choices"""
list_union: list[str | type[object]] = field(default_factory=list)
list_union: list[str | type[object]] = Field(default_factory=list)
"""List with union type"""
set_n: set[int] = field(default_factory=lambda: {1, 2, 3})
set_n: set[int] = Field(default_factory=lambda: {1, 2, 3})
"""Set with variable length"""
literal_literal: Literal[Literal[1], Literal[2]] = 1
"""Literal of literals with default 1"""
json_tip: dict = field(default_factory=dict)
json_tip: dict = Field(default_factory=dict)
"""Dict which will be JSON in CLI"""
nested_config: NestedConfig = field(default_factory=NestedConfig)
nested_config: NestedConfig = Field(default_factory=NestedConfig)
"""Nested config"""
@@ -195,7 +193,7 @@ def test_get_kwargs():
json_tip = "Should either be a valid JSON string or JSON keys"
assert json_tip in kwargs["json_tip"]["help"]
# nested config should construct the nested config
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) # type: ignore[call-arg]
@pytest.mark.parametrize(