Improve literal dataclass field conversion to argparse argument (#17391)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,14 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import MISSING, Field, asdict, dataclass, field
|
||||
from typing import Literal, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig, PoolerConfig, get_field
|
||||
from vllm.config import ModelConfig, PoolerConfig, config, get_field
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class TestConfig1:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig2:
|
||||
a: int
|
||||
"""docstring"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig3:
|
||||
a: int = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestConfig4:
|
||||
a: Union[Literal[1], Literal[2]] = 1
|
||||
"""docstring"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("test_config", "expected_error"), [
|
||||
(TestConfig1, "must be a dataclass"),
|
||||
(TestConfig2, "must have a default"),
|
||||
(TestConfig3, "must have a docstring"),
|
||||
(TestConfig4, "must use a single Literal"),
|
||||
])
|
||||
def test_config(test_config, expected_error):
|
||||
with pytest.raises(Exception, match=expected_error):
|
||||
config(test_config)
|
||||
|
||||
|
||||
def test_get_field():
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user