[Core] Rework dtype resolution (#18751)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-06-01 11:04:23 +08:00
committed by GitHub
parent 1bc86a3da1
commit 6aa8f9a4e7
13 changed files with 314 additions and 119 deletions

View File

@@ -17,7 +17,8 @@ from vllm_test_utils.monitor import monitor
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, deprecate_kwargs, get_open_port,
bind_kv_cache, common_broadcastable_dtype,
deprecate_kwargs, get_open_port, is_lossless_cast,
make_zmq_path, make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_zmq_path,
supports_kw, swap_dict_values)
@@ -567,12 +568,65 @@ def test_lru_cache():
assert 6 in cache
# yapf: disable
@pytest.mark.parametrize(
("src_dtype", "tgt_dtype", "expected_result"),
[
# Different precision_levels
(torch.bool, torch.int8, True),
(torch.bool, torch.float16, True),
(torch.bool, torch.complex32, True),
(torch.int64, torch.bool, False),
(torch.int64, torch.float16, True),
(torch.int64, torch.complex32, True),
(torch.float64, torch.bool, False),
(torch.float64, torch.int8, False),
(torch.float64, torch.complex32, True),
(torch.complex128, torch.bool, False),
(torch.complex128, torch.int8, False),
(torch.complex128, torch.float16, False),
# precision_level=0
(torch.bool, torch.bool, True),
# precision_level=1
(torch.int8, torch.int16, True),
(torch.int16, torch.int8, False),
(torch.uint8, torch.int8, False),
(torch.int8, torch.uint8, False),
# precision_level=2
(torch.float16, torch.float32, True),
(torch.float32, torch.float16, False),
(torch.bfloat16, torch.float32, True),
(torch.float32, torch.bfloat16, False),
# precision_level=3
(torch.complex32, torch.complex64, True),
(torch.complex64, torch.complex32, False),
],
)
# yapf: enable
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result
# yapf: disable
@pytest.mark.parametrize(
("dtypes", "expected_result"),
[
([torch.bool], torch.bool),
([torch.bool, torch.int8], torch.int8),
([torch.bool, torch.int8, torch.float16], torch.float16),
([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501
],
)
# yapf: enable
def test_common_broadcastable_dtype(dtypes, expected_result):
assert common_broadcastable_dtype(dtypes) == expected_result
def test_placeholder_module_error_handling():
placeholder = PlaceholderModule("placeholder_1234")
def build_ctx():
return pytest.raises(ModuleNotFoundError,
match="No module named")
return pytest.raises(ModuleNotFoundError, match="No module named")
with build_ctx():
int(placeholder)
@@ -608,6 +662,7 @@ def test_placeholder_module_error_handling():
_ = placeholder_attr.module
# yapf: disable
@pytest.mark.parametrize(
"obj,key1,key2",
[
@@ -618,6 +673,7 @@ def test_placeholder_module_error_handling():
# Tests for both keys do not exist
({1: "a", 2: "b"}, 3, 4),
])
# yapf: enable
def test_swap_dict_values(obj, key1, key2):
original_obj = obj.copy()
swap_dict_values(obj, key1, key2)
@@ -631,19 +687,19 @@ def test_swap_dict_values(obj, key1, key2):
assert key1 not in obj
def test_model_specification(parser_with_config,
cli_config_file,
def test_model_specification(parser_with_config, cli_config_file,
cli_config_file_with_model):
# Test model in CLI takes precedence over config
args = parser_with_config.parse_args([
'serve', 'cli-model', '--config', cli_config_file_with_model
])
args = parser_with_config.parse_args(
['serve', 'cli-model', '--config', cli_config_file_with_model])
assert args.model_tag == 'cli-model'
assert args.served_model_name == 'mymodel'
# Test model from config file works
args = parser_with_config.parse_args([
'serve', '--config', cli_config_file_with_model,
'serve',
'--config',
cli_config_file_with_model,
])
assert args.model == 'config-model'
assert args.served_model_name == 'mymodel'
@@ -654,17 +710,19 @@ def test_model_specification(parser_with_config,
# Test using --model option raises error
with pytest.raises(
ValueError,
match=(
"With `vllm serve`, you should provide the model as a positional "
"argument or in a config file instead of via the `--model` option."
),
ValueError,
match=
("With `vllm serve`, you should provide the model as a positional "
"argument or in a config file instead of via the `--model` option."),
):
parser_with_config.parse_args(['serve', '--model', 'my-model'])
# Test other config values are preserved
args = parser_with_config.parse_args([
'serve', 'cli-model', '--config', cli_config_file_with_model,
'serve',
'cli-model',
'--config',
cli_config_file_with_model,
])
assert args.tensor_parallel_size == 2
assert args.trust_remote_code is True
@@ -673,7 +731,7 @@ def test_model_specification(parser_with_config,
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
(None, bool, [1, 2, 3])])
(None, bool, [1, 2, 3])])
@pytest.mark.parametrize("output", [0, 1, 2])
def test_sha256(input: tuple, output: int):
hash = sha256(input)
@@ -682,7 +740,8 @@ def test_sha256(input: tuple, output: int):
assert hash != 0
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big")
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
byteorder="big")
# hashing again, returns the same value
assert hash == sha256(input)
@@ -698,8 +757,7 @@ def test_sha256(input: tuple, output: int):
("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address
("inproc://some_identifier", ("inproc", "some_identifier", "")),
]
)
])
def test_split_zmq_path(path, expected):
assert split_zmq_path(path) == expected
@@ -711,8 +769,7 @@ def test_split_zmq_path(path, expected):
"tcp://127.0.0.1", # Missing port
"tcp://[::1]", # Missing port for IPv6
"tcp://:5555", # Missing host
]
)
])
def test_split_zmq_path_invalid(invalid_path):
with pytest.raises(ValueError):
split_zmq_path(invalid_path)
@@ -734,7 +791,8 @@ def test_make_zmq_socket_ipv6():
zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)
# Verify that the IPV6 option is set
assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
assert zsock.getsockopt(
zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
# Clean up
zsock.close()