[Chore] Separate out vllm.utils.func (#26904)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
97
tests/utils_/test_func_utils.py
Normal file
97
tests/utils_/test_func_utils.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.utils.func import deprecate_kwargs, supports_kw
|
||||
|
||||
from ..utils import error_on_warning
|
||||
|
||||
|
||||
def test_deprecate_kwargs_always():
|
||||
@deprecate_kwargs("old_arg", is_deprecated=True)
|
||||
def dummy(*, old_arg: object = None, new_arg: object = None):
|
||||
pass
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'old_arg'"):
|
||||
dummy(old_arg=1)
|
||||
|
||||
with error_on_warning(DeprecationWarning):
|
||||
dummy(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_never():
|
||||
@deprecate_kwargs("old_arg", is_deprecated=False)
|
||||
def dummy(*, old_arg: object = None, new_arg: object = None):
|
||||
pass
|
||||
|
||||
with error_on_warning(DeprecationWarning):
|
||||
dummy(old_arg=1)
|
||||
|
||||
with error_on_warning(DeprecationWarning):
|
||||
dummy(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_dynamic():
|
||||
is_deprecated = True
|
||||
|
||||
@deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
|
||||
def dummy(*, old_arg: object = None, new_arg: object = None):
|
||||
pass
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'old_arg'"):
|
||||
dummy(old_arg=1)
|
||||
|
||||
with error_on_warning(DeprecationWarning):
|
||||
dummy(new_arg=1)
|
||||
|
||||
is_deprecated = False
|
||||
|
||||
with error_on_warning(DeprecationWarning):
|
||||
dummy(old_arg=1)
|
||||
|
||||
with error_on_warning(DeprecationWarning):
|
||||
dummy(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_additional_message():
|
||||
@deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
|
||||
def dummy(*, old_arg: object = None, new_arg: object = None):
|
||||
pass
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="abcd"):
|
||||
dummy(old_arg=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("callable", "kw_name", "requires_kw_only", "allow_var_kwargs", "is_supported"),
|
||||
[
|
||||
# Tests for positional argument support
|
||||
(lambda foo: None, "foo", True, True, False),
|
||||
(lambda foo: None, "foo", False, True, True),
|
||||
# Tests for positional or keyword / keyword only
|
||||
(lambda foo=100: None, "foo", True, True, False),
|
||||
(lambda *, foo: None, "foo", False, True, True),
|
||||
# Tests to make sure the names of variadic params are NOT supported
|
||||
(lambda *args: None, "args", False, True, False),
|
||||
(lambda **kwargs: None, "kwargs", False, True, False),
|
||||
# Tests for if we allow var kwargs to add support
|
||||
(lambda foo: None, "something_else", False, True, False),
|
||||
(lambda foo, **kwargs: None, "something_else", False, True, True),
|
||||
(lambda foo, **kwargs: None, "kwargs", True, True, False),
|
||||
(lambda foo, **kwargs: None, "foo", True, True, False),
|
||||
],
|
||||
)
|
||||
def test_supports_kw(
|
||||
callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
|
||||
):
|
||||
assert (
|
||||
supports_kw(
|
||||
callable=callable,
|
||||
kw_name=kw_name,
|
||||
requires_kw_only=requires_kw_only,
|
||||
allow_var_kwargs=allow_var_kwargs,
|
||||
)
|
||||
== is_supported
|
||||
)
|
||||
32
tests/utils_/test_jsontree.py
Normal file
32
tests/utils_/test_jsontree.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.utils.jsontree import json_count_leaves
|
||||
|
||||
|
||||
def test_json_count_leaves():
|
||||
"""Test json_count_leaves function from jsontree utility."""
|
||||
|
||||
# Single leaf values
|
||||
assert json_count_leaves(42) == 1
|
||||
assert json_count_leaves("hello") == 1
|
||||
assert json_count_leaves(None) == 1
|
||||
|
||||
# Empty containers
|
||||
assert json_count_leaves([]) == 0
|
||||
assert json_count_leaves({}) == 0
|
||||
assert json_count_leaves(()) == 0
|
||||
|
||||
# Flat structures
|
||||
assert json_count_leaves([1, 2, 3]) == 3
|
||||
assert json_count_leaves({"a": 1, "b": 2}) == 2
|
||||
assert json_count_leaves((1, 2, 3)) == 3
|
||||
|
||||
# Nested structures
|
||||
nested_dict = {"a": 1, "b": {"c": 2, "d": 3}}
|
||||
assert json_count_leaves(nested_dict) == 3
|
||||
|
||||
nested_list = [1, [2, 3], 4]
|
||||
assert json_count_leaves(nested_list) == 4
|
||||
|
||||
mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4}
|
||||
assert json_count_leaves(mixed_nested) == 4
|
||||
@@ -30,7 +30,6 @@ from vllm.utils import (
|
||||
bind_kv_cache,
|
||||
common_broadcastable_dtype,
|
||||
current_stream,
|
||||
deprecate_kwargs,
|
||||
get_open_port,
|
||||
get_tcp_uri,
|
||||
is_lossless_cast,
|
||||
@@ -42,12 +41,11 @@ from vllm.utils import (
|
||||
sha256,
|
||||
split_host_port,
|
||||
split_zmq_path,
|
||||
supports_kw,
|
||||
swap_dict_values,
|
||||
unique_filepath,
|
||||
)
|
||||
|
||||
from ..utils import create_new_process_for_each_test, error_on_warning
|
||||
from ..utils import create_new_process_for_each_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -83,61 +81,6 @@ async def test_merge_async_iterators():
|
||||
raise AssertionError() from e
|
||||
|
||||
|
||||
def test_deprecate_kwargs_always():
|
||||
@deprecate_kwargs("old_arg", is_deprecated=True)
|
||||
def dummy(*, old_arg: object = None, new_arg: object = None):
|
||||
pass
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'old_arg'"):
|
||||
dummy(old_arg=1)
|
||||
|
||||
with error_on_warning(DeprecationWarning):
|
||||
dummy(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_never():
|
||||
@deprecate_kwargs("old_arg", is_deprecated=False)
|
||||
def dummy(*, old_arg: object = None, new_arg: object = None):
|
||||
pass
|
||||
|
||||
with error_on_warning(DeprecationWarning):
|
||||
dummy(old_arg=1)
|
||||
|
||||
with error_on_warning(DeprecationWarning):
|
||||
dummy(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_dynamic():
|
||||
is_deprecated = True
|
||||
|
||||
@deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
|
||||
def dummy(*, old_arg: object = None, new_arg: object = None):
|
||||
pass
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'old_arg'"):
|
||||
dummy(old_arg=1)
|
||||
|
||||
with error_on_warning(DeprecationWarning):
|
||||
dummy(new_arg=1)
|
||||
|
||||
is_deprecated = False
|
||||
|
||||
with error_on_warning(DeprecationWarning):
|
||||
dummy(old_arg=1)
|
||||
|
||||
with error_on_warning(DeprecationWarning):
|
||||
dummy(new_arg=1)
|
||||
|
||||
|
||||
def test_deprecate_kwargs_additional_message():
|
||||
@deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
|
||||
def dummy(*, old_arg: object = None, new_arg: object = None):
|
||||
pass
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="abcd"):
|
||||
dummy(old_arg=1)
|
||||
|
||||
|
||||
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_PORT", "5678")
|
||||
@@ -383,39 +326,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
|
||||
assert "-O.mode" in caplog_vllm.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
|
||||
[
|
||||
# Tests for positional argument support
|
||||
(lambda foo: None, "foo", True, True, False),
|
||||
(lambda foo: None, "foo", False, True, True),
|
||||
# Tests for positional or keyword / keyword only
|
||||
(lambda foo=100: None, "foo", True, True, False),
|
||||
(lambda *, foo: None, "foo", False, True, True),
|
||||
# Tests to make sure the names of variadic params are NOT supported
|
||||
(lambda *args: None, "args", False, True, False),
|
||||
(lambda **kwargs: None, "kwargs", False, True, False),
|
||||
# Tests for if we allow var kwargs to add support
|
||||
(lambda foo: None, "something_else", False, True, False),
|
||||
(lambda foo, **kwargs: None, "something_else", False, True, True),
|
||||
(lambda foo, **kwargs: None, "kwargs", True, True, False),
|
||||
(lambda foo, **kwargs: None, "foo", True, True, False),
|
||||
],
|
||||
)
|
||||
def test_supports_kw(
|
||||
callable, kw_name, requires_kw_only, allow_var_kwargs, is_supported
|
||||
):
|
||||
assert (
|
||||
supports_kw(
|
||||
callable=callable,
|
||||
kw_name=kw_name,
|
||||
requires_kw_only=requires_kw_only,
|
||||
allow_var_kwargs=allow_var_kwargs,
|
||||
)
|
||||
== is_supported
|
||||
)
|
||||
|
||||
|
||||
@create_new_process_for_each_test()
|
||||
def test_memory_profiling():
|
||||
# Fake out some model loading + inference memory usage to test profiling
|
||||
@@ -863,36 +773,6 @@ def test_join_host_port():
|
||||
assert join_host_port("::1", 5555) == "[::1]:5555"
|
||||
|
||||
|
||||
def test_json_count_leaves():
|
||||
"""Test json_count_leaves function from jsontree utility."""
|
||||
from vllm.utils.jsontree import json_count_leaves
|
||||
|
||||
# Single leaf values
|
||||
assert json_count_leaves(42) == 1
|
||||
assert json_count_leaves("hello") == 1
|
||||
assert json_count_leaves(None) == 1
|
||||
|
||||
# Empty containers
|
||||
assert json_count_leaves([]) == 0
|
||||
assert json_count_leaves({}) == 0
|
||||
assert json_count_leaves(()) == 0
|
||||
|
||||
# Flat structures
|
||||
assert json_count_leaves([1, 2, 3]) == 3
|
||||
assert json_count_leaves({"a": 1, "b": 2}) == 2
|
||||
assert json_count_leaves((1, 2, 3)) == 3
|
||||
|
||||
# Nested structures
|
||||
nested_dict = {"a": 1, "b": {"c": 2, "d": 3}}
|
||||
assert json_count_leaves(nested_dict) == 3
|
||||
|
||||
nested_list = [1, [2, 3], 4]
|
||||
assert json_count_leaves(nested_list) == 4
|
||||
|
||||
mixed_nested = {"list": [1, 2], "dict": {"x": 3}, "value": 4}
|
||||
assert json_count_leaves(mixed_nested) == 4
|
||||
|
||||
|
||||
def test_convert_ids_list_to_tokens():
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
|
||||
token_ids = tokenizer.encode("Hello, world!")
|
||||
|
||||
Reference in New Issue
Block a user