Improve configs - the rest! (#17562)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-09 23:18:44 +01:00
committed by GitHub
parent 7e3571134f
commit 4b2ed7926a
14 changed files with 456 additions and 340 deletions

View File

@@ -8,21 +8,18 @@ from typing import Literal, Optional
import pytest
from vllm.config import config
from vllm.config import CompilationConfig, config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type,
literal_to_kwargs, nullable_kvs,
optional_type)
optional_type, parse_type)
from vllm.utils import FlexibleArgumentParser
@pytest.mark.parametrize(("type", "value", "expected"), [
(int, "42", 42),
(int, "None", None),
(float, "3.14", 3.14),
(float, "None", None),
(str, "Hello World!", "Hello World!"),
(str, "None", None),
(json.loads, '{"foo":1,"bar":2}', {
"foo": 1,
"bar": 2
@@ -31,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser
"foo": 1,
"bar": 2
}),
(json.loads, "None", None),
])
def test_optional_type(type, value, expected):
optional_type_func = optional_type(type)
def test_parse_type(type, value, expected):
parse_type_func = parse_type(type)
context = nullcontext()
if value == "foo=1,bar=2":
context = pytest.warns(DeprecationWarning)
with context:
assert optional_type_func(value) == expected
assert parse_type_func(value) == expected
def test_optional_type():
optional_type_func = optional_type(int)
assert optional_type_func("None") is None
assert optional_type_func("42") == 42
@pytest.mark.parametrize(("type_hint", "type", "expected"), [
@@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected):
@config
@dataclass
class DummyConfigClass:
class NestedConfig:
field: int = 1
"""field"""
@config
@dataclass
class FromCliConfig1:
field: int = 1
"""field"""
@classmethod
def from_cli(cls, cli_value: str):
inst = cls(**json.loads(cli_value))
inst.field += 1
return inst
@config
@dataclass
class FromCliConfig2:
field: int = 1
"""field"""
@classmethod
def from_cli(cls, cli_value: str):
inst = cls(**json.loads(cli_value))
inst.field += 2
return inst
@config
@dataclass
class DummyConfig:
regular_bool: bool = True
"""Regular bool with default True"""
optional_bool: Optional[bool] = None
@@ -108,18 +143,24 @@ class DummyConfigClass:
"""Literal of literals with default 1"""
json_tip: dict = field(default_factory=dict)
"""Dict which will be JSON in CLI"""
nested_config: NestedConfig = field(default_factory=NestedConfig)
"""Nested config"""
from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1)
"""Config with from_cli method"""
from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2)
"""Different config with from_cli method"""
@pytest.mark.parametrize(("type_hint", "expected"), [
(int, False),
(DummyConfigClass, True),
(DummyConfig, True),
])
def test_is_not_builtin(type_hint, expected):
assert is_not_builtin(type_hint) == expected
def test_get_kwargs():
kwargs = get_kwargs(DummyConfigClass)
kwargs = get_kwargs(DummyConfig)
print(kwargs)
# bools should not have their type set
@@ -142,6 +183,11 @@ def test_get_kwargs():
# dict should have json tip in help
json_tip = "\n\nShould be a valid JSON string."
assert kwargs["json_tip"]["help"].endswith(json_tip)
# nested config should should construct the nested config
assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)
# from_cli configs should be constructed with the correct method
assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3
assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
@pytest.mark.parametrize(("arg", "expected"), [
@@ -177,7 +223,7 @@ def test_compilation_config():
# default value
args = parser.parse_args([])
assert args.compilation_config is None
assert args.compilation_config == CompilationConfig()
# set to O3
args = parser.parse_args(["-O3"])
@@ -194,7 +240,7 @@ def test_compilation_config():
# set to string form of a dict
args = parser.parse_args([
"--compilation-config",
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
])
assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
@@ -202,7 +248,7 @@ def test_compilation_config():
# set to string form of a dict
args = parser.parse_args([
"--compilation-config="
"{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
])
assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])