Turn @config into a dataclass_transform (#31541)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -3,11 +3,11 @@
|
||||
|
||||
import json
|
||||
from argparse import ArgumentError
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.config import AttentionConfig, CompilationConfig, config
|
||||
from vllm.engine.arg_utils import (
|
||||
@@ -96,7 +96,7 @@ def test_get_type(type_hints, type, expected):
|
||||
],
|
||||
)
|
||||
def test_literal_to_kwargs(type_hints, expected):
|
||||
context = nullcontext()
|
||||
context: AbstractContextManager[object] = nullcontext()
|
||||
if expected is Exception:
|
||||
context = pytest.raises(expected)
|
||||
with context:
|
||||
@@ -104,14 +104,12 @@ def test_literal_to_kwargs(type_hints, expected):
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class NestedConfig:
|
||||
field: int = 1
|
||||
"""field"""
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class DummyConfig:
|
||||
regular_bool: bool = True
|
||||
"""Regular bool with default True"""
|
||||
@@ -119,23 +117,23 @@ class DummyConfig:
|
||||
"""Optional bool with default None"""
|
||||
optional_literal: Literal["x", "y"] | None = None
|
||||
"""Optional literal with default None"""
|
||||
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
|
||||
tuple_n: tuple[int, ...] = Field(default_factory=lambda: (1, 2, 3))
|
||||
"""Tuple with variable length"""
|
||||
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
|
||||
tuple_2: tuple[int, int] = Field(default_factory=lambda: (1, 2))
|
||||
"""Tuple with fixed length"""
|
||||
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
|
||||
list_n: list[int] = Field(default_factory=lambda: [1, 2, 3])
|
||||
"""List with variable length"""
|
||||
list_literal: list[Literal[1, 2]] = field(default_factory=list)
|
||||
list_literal: list[Literal[1, 2]] = Field(default_factory=list)
|
||||
"""List with literal choices"""
|
||||
list_union: list[str | type[object]] = field(default_factory=list)
|
||||
list_union: list[str | type[object]] = Field(default_factory=list)
|
||||
"""List with union type"""
|
||||
set_n: set[int] = field(default_factory=lambda: {1, 2, 3})
|
||||
set_n: set[int] = Field(default_factory=lambda: {1, 2, 3})
|
||||
"""Set with variable length"""
|
||||
literal_literal: Literal[Literal[1], Literal[2]] = 1
|
||||
"""Literal of literals with default 1"""
|
||||
json_tip: dict = field(default_factory=dict)
|
||||
json_tip: dict = Field(default_factory=dict)
|
||||
"""Dict which will be JSON in CLI"""
|
||||
nested_config: NestedConfig = field(default_factory=NestedConfig)
|
||||
nested_config: NestedConfig = Field(default_factory=NestedConfig)
|
||||
"""Nested config"""
|
||||
|
||||
|
||||
@@ -195,7 +193,7 @@ def test_get_kwargs():
|
||||
json_tip = "Should either be a valid JSON string or JSON keys"
|
||||
assert json_tip in kwargs["json_tip"]["help"]
|
||||
# nested config should construct the nested config
|
||||
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
|
||||
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) # type: ignore[call-arg]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user