[Chore] Separate out vllm.utils.func (#26904)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-15 21:03:58 +08:00
committed by GitHub
parent f57438338d
commit 136a17fe6e
18 changed files with 407 additions and 371 deletions

View File

@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
import pytest
from vllm.utils.func import deprecate_kwargs, supports_kw
from ..utils import error_on_warning
def test_deprecate_kwargs_always():
@deprecate_kwargs("old_arg", is_deprecated=True)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
def test_deprecate_kwargs_never():
@deprecate_kwargs("old_arg", is_deprecated=False)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with error_on_warning(DeprecationWarning):
dummy(old_arg=1)
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
def test_deprecate_kwargs_dynamic():
is_deprecated = True
@deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
is_deprecated = False
with error_on_warning(DeprecationWarning):
dummy(old_arg=1)
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
def test_deprecate_kwargs_additional_message():
@deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="abcd"):
dummy(old_arg=1)
@pytest.mark.parametrize(
("callable", "kw_name", "requires_kw_only", "allow_var_kwargs", "is_supported"),
[
# Tests for positional argument support
(lambda foo: None, "foo", True, True, False),
(lambda foo: None, "foo", False, True, True),
# Tests for positional or keyword / keyword only
(lambda foo=100: None, "foo", True, True, False),
(lambda *, foo: None, "foo", False, True, True),
# Tests to make sure the names of variadic params are NOT supported
(lambda *args: None, "args", False, True, False),
(lambda **kwargs: None, "kwargs", False, True, False),
# Tests for if we allow var kwargs to add support
(lambda foo: None, "something_else", False, True, False),
(lambda foo, **kwargs: None, "something_else", False, True, True),
(lambda foo, **kwargs: None, "kwargs", True, True, False),
(lambda foo, **kwargs: None, "foo", True, True, False),
],
)
def test_supports_kw(
callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
):
assert (
supports_kw(
callable=callable,
kw_name=kw_name,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
== is_supported
)

View File

@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.utils.jsontree import json_count_leaves
def test_json_count_leaves():
"""Test json_count_leaves function from jsontree utility."""
# Single leaf values
assert json_count_leaves(42) == 1
assert json_count_leaves("hello") == 1
assert json_count_leaves(None) == 1
# Empty containers
assert json_count_leaves([]) == 0
assert json_count_leaves({}) == 0
assert json_count_leaves(()) == 0
# Flat structures
assert json_count_leaves([1, 2, 3]) == 3
assert json_count_leaves({"a": 1, "b": 2}) == 2
assert json_count_leaves((1, 2, 3)) == 3
# Nested structures
nested_dict = {"a": 1, "b": {"c": 2, "d": 3}}
assert json_count_leaves(nested_dict) == 3
nested_list = [1, [2, 3], 4]
assert json_count_leaves(nested_list) == 4
mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4}
assert json_count_leaves(mixed_nested) == 4

View File

@@ -30,7 +30,6 @@ from vllm.utils import (
bind_kv_cache,
common_broadcastable_dtype,
current_stream,
deprecate_kwargs,
get_open_port,
get_tcp_uri,
is_lossless_cast,
@@ -42,12 +41,11 @@ from vllm.utils import (
sha256,
split_host_port,
split_zmq_path,
supports_kw,
swap_dict_values,
unique_filepath,
)
from ..utils import create_new_process_for_each_test, error_on_warning
from ..utils import create_new_process_for_each_test
@pytest.mark.asyncio
@@ -83,61 +81,6 @@ async def test_merge_async_iterators():
raise AssertionError() from e
def test_deprecate_kwargs_always():
@deprecate_kwargs("old_arg", is_deprecated=True)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
def test_deprecate_kwargs_never():
@deprecate_kwargs("old_arg", is_deprecated=False)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with error_on_warning(DeprecationWarning):
dummy(old_arg=1)
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
def test_deprecate_kwargs_dynamic():
is_deprecated = True
@deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
is_deprecated = False
with error_on_warning(DeprecationWarning):
dummy(old_arg=1)
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
def test_deprecate_kwargs_additional_message():
@deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="abcd"):
dummy(old_arg=1)
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_PORT", "5678")
@@ -383,39 +326,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
assert "-O.mode" in caplog_vllm.text
@pytest.mark.parametrize(
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
[
# Tests for positional argument support
(lambda foo: None, "foo", True, True, False),
(lambda foo: None, "foo", False, True, True),
# Tests for positional or keyword / keyword only
(lambda foo=100: None, "foo", True, True, False),
(lambda *, foo: None, "foo", False, True, True),
# Tests to make sure the names of variadic params are NOT supported
(lambda *args: None, "args", False, True, False),
(lambda **kwargs: None, "kwargs", False, True, False),
# Tests for if we allow var kwargs to add support
(lambda foo: None, "something_else", False, True, False),
(lambda foo, **kwargs: None, "something_else", False, True, True),
(lambda foo, **kwargs: None, "kwargs", True, True, False),
(lambda foo, **kwargs: None, "foo", True, True, False),
],
)
def test_supports_kw(
callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
):
assert (
supports_kw(
callable=callable,
kw_name=kw_name,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs,
)
== is_supported
)
@create_new_process_for_each_test()
def test_memory_profiling():
# Fake out some model loading + inference memory usage to test profiling
@@ -863,36 +773,6 @@ def test_join_host_port():
assert join_host_port("::1", 5555) == "[::1]:5555"
def test_json_count_leaves():
"""Test json_count_leaves function from jsontree utility."""
from vllm.utils.jsontree import json_count_leaves
# Single leaf values
assert json_count_leaves(42) == 1
assert json_count_leaves("hello") == 1
assert json_count_leaves(None) == 1
# Empty containers
assert json_count_leaves([]) == 0
assert json_count_leaves({}) == 0
assert json_count_leaves(()) == 0
# Flat structures
assert json_count_leaves([1, 2, 3]) == 3
assert json_count_leaves({"a": 1, "b": 2}) == 2
assert json_count_leaves((1, 2, 3)) == 3
# Nested structures
nested_dict = {"a": 1, "b": {"c": 2, "d": 3}}
assert json_count_leaves(nested_dict) == 3
nested_list = [1, [2, 3], 4]
assert json_count_leaves(nested_list) == 4
mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4}
assert json_count_leaves(mixed_nested) == 4
def test_convert_ids_list_to_tokens():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
token_ids = tokenizer.encode("Hello, world!")