Improve literal dataclass field conversion to argparse argument (#17391)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -11,7 +11,8 @@ import pytest
|
||||
from vllm.config import PoolerConfig, config
|
||||
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
||||
get_type, is_not_builtin, is_type,
|
||||
nullable_kvs, optional_type)
|
||||
literal_to_kwargs, nullable_kvs,
|
||||
optional_type)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@@ -71,6 +72,21 @@ def test_get_type(type_hints, type, expected):
|
||||
assert get_type(type_hints, type) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hints", "expected"), [
|
||||
({Literal[1, 2]}, {
|
||||
"type": int,
|
||||
"choices": [1, 2]
|
||||
}),
|
||||
({Literal[1, "a"]}, Exception),
|
||||
])
|
||||
def test_literal_to_kwargs(type_hints, expected):
|
||||
context = nullcontext()
|
||||
if expected is Exception:
|
||||
context = pytest.raises(expected)
|
||||
with context:
|
||||
assert literal_to_kwargs(type_hints) == expected
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class DummyConfigClass:
|
||||
@@ -81,11 +97,15 @@ class DummyConfigClass:
|
||||
optional_literal: Optional[Literal["x", "y"]] = None
|
||||
"""Optional literal with default None"""
|
||||
tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
|
||||
"""Tuple with default (1, 2, 3)"""
|
||||
"""Tuple with variable length"""
|
||||
tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
|
||||
"""Tuple with default (1, 2)"""
|
||||
"""Tuple with fixed length"""
|
||||
list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
|
||||
"""List with default [1, 2, 3]"""
|
||||
"""List with variable length"""
|
||||
list_literal: list[Literal[1, 2]] = field(default_factory=list)
|
||||
"""List with literal choices"""
|
||||
literal_literal: Literal[Literal[1], Literal[2]] = 1
|
||||
"""Literal of literals with default 1"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hint", "expected"), [
|
||||
@@ -111,6 +131,12 @@ def test_get_kwargs():
|
||||
# lists should work
|
||||
assert kwargs["list_n"]["type"] is int
|
||||
assert kwargs["list_n"]["nargs"] == "+"
|
||||
# lists with literals should have the correct choices
|
||||
assert kwargs["list_literal"]["type"] is int
|
||||
assert kwargs["list_literal"]["nargs"] == "+"
|
||||
assert kwargs["list_literal"]["choices"] == [1, 2]
|
||||
# literals of literals should have merged choices
|
||||
assert kwargs["literal_literal"]["choices"] == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("arg", "expected"), [
|
||||
|
||||
Reference in New Issue
Block a user