Improve configs - the rest! (#17562)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -8,21 +8,18 @@ from typing import Literal, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import config
|
||||
from vllm.config import CompilationConfig, config
|
||||
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
||||
get_type, is_not_builtin, is_type,
|
||||
literal_to_kwargs, nullable_kvs,
|
||||
optional_type)
|
||||
optional_type, parse_type)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type", "value", "expected"), [
|
||||
(int, "42", 42),
|
||||
(int, "None", None),
|
||||
(float, "3.14", 3.14),
|
||||
(float, "None", None),
|
||||
(str, "Hello World!", "Hello World!"),
|
||||
(str, "None", None),
|
||||
(json.loads, '{"foo":1,"bar":2}', {
|
||||
"foo": 1,
|
||||
"bar": 2
|
||||
@@ -31,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser
|
||||
"foo": 1,
|
||||
"bar": 2
|
||||
}),
|
||||
(json.loads, "None", None),
|
||||
])
|
||||
def test_optional_type(type, value, expected):
|
||||
optional_type_func = optional_type(type)
|
||||
def test_parse_type(type, value, expected):
|
||||
parse_type_func = parse_type(type)
|
||||
context = nullcontext()
|
||||
if value == "foo=1,bar=2":
|
||||
context = pytest.warns(DeprecationWarning)
|
||||
with context:
|
||||
assert optional_type_func(value) == expected
|
||||
assert parse_type_func(value) == expected
|
||||
|
||||
|
||||
def test_optional_type():
|
||||
optional_type_func = optional_type(int)
|
||||
assert optional_type_func("None") is None
|
||||
assert optional_type_func("42") == 42
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
|
||||
@@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected):
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class DummyConfigClass:
|
||||
class NestedConfig:
|
||||
field: int = 1
|
||||
"""field"""
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class FromCliConfig1:
|
||||
field: int = 1
|
||||
"""field"""
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, cli_value: str):
|
||||
inst = cls(**json.loads(cli_value))
|
||||
inst.field += 1
|
||||
return inst
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class FromCliConfig2:
|
||||
field: int = 1
|
||||
"""field"""
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, cli_value: str):
|
||||
inst = cls(**json.loads(cli_value))
|
||||
inst.field += 2
|
||||
return inst
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class DummyConfig:
|
||||
regular_bool: bool = True
|
||||
"""Regular bool with default True"""
|
||||
optional_bool: Optional[bool] = None
|
||||
@@ -108,18 +143,24 @@ class DummyConfigClass:
|
||||
"""Literal of literals with default 1"""
|
||||
json_tip: dict = field(default_factory=dict)
|
||||
"""Dict which will be JSON in CLI"""
|
||||
nested_config: NestedConfig = field(default_factory=NestedConfig)
|
||||
"""Nested config"""
|
||||
from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1)
|
||||
"""Config with from_cli method"""
|
||||
from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2)
|
||||
"""Different config with from_cli method"""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("type_hint", "expected"), [
|
||||
(int, False),
|
||||
(DummyConfigClass, True),
|
||||
(DummyConfig, True),
|
||||
])
|
||||
def test_is_not_builtin(type_hint, expected):
|
||||
assert is_not_builtin(type_hint) == expected
|
||||
|
||||
|
||||
def test_get_kwargs():
|
||||
kwargs = get_kwargs(DummyConfigClass)
|
||||
kwargs = get_kwargs(DummyConfig)
|
||||
print(kwargs)
|
||||
|
||||
# bools should not have their type set
|
||||
@@ -142,6 +183,11 @@ def test_get_kwargs():
|
||||
# dict should have json tip in help
|
||||
json_tip = "\n\nShould be a valid JSON string."
|
||||
assert kwargs["json_tip"]["help"].endswith(json_tip)
|
||||
# nested config should should construct the nested config
|
||||
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
|
||||
# from_cli configs should be constructed with the correct method
|
||||
assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3
|
||||
assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("arg", "expected"), [
|
||||
@@ -177,7 +223,7 @@ def test_compilation_config():
|
||||
|
||||
# default value
|
||||
args = parser.parse_args([])
|
||||
assert args.compilation_config is None
|
||||
assert args.compilation_config == CompilationConfig()
|
||||
|
||||
# set to O3
|
||||
args = parser.parse_args(["-O3"])
|
||||
@@ -194,7 +240,7 @@ def test_compilation_config():
|
||||
# set to string form of a dict
|
||||
args = parser.parse_args([
|
||||
"--compilation-config",
|
||||
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
|
||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
|
||||
])
|
||||
assert (args.compilation_config.level == 3 and
|
||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
|
||||
@@ -202,7 +248,7 @@ def test_compilation_config():
|
||||
# set to string form of a dict
|
||||
args = parser.parse_args([
|
||||
"--compilation-config="
|
||||
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
|
||||
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
|
||||
])
|
||||
assert (args.compilation_config.level == 3 and
|
||||
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
|
||||
|
||||
Reference in New Issue
Block a user