[Core] Rework dtype resolution (#18751)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -17,7 +17,8 @@ from vllm_test_utils.monitor import monitor
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
|
||||
MemorySnapshot, PlaceholderModule, StoreBoolean,
|
||||
bind_kv_cache, deprecate_kwargs, get_open_port,
|
||||
bind_kv_cache, common_broadcastable_dtype,
|
||||
deprecate_kwargs, get_open_port, is_lossless_cast,
|
||||
make_zmq_path, make_zmq_socket, memory_profiling,
|
||||
merge_async_iterators, sha256, split_zmq_path,
|
||||
supports_kw, swap_dict_values)
|
||||
@@ -567,12 +568,65 @@ def test_lru_cache():
|
||||
assert 6 in cache
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("src_dtype", "tgt_dtype", "expected_result"),
|
||||
[
|
||||
# Different precision_levels
|
||||
(torch.bool, torch.int8, True),
|
||||
(torch.bool, torch.float16, True),
|
||||
(torch.bool, torch.complex32, True),
|
||||
(torch.int64, torch.bool, False),
|
||||
(torch.int64, torch.float16, True),
|
||||
(torch.int64, torch.complex32, True),
|
||||
(torch.float64, torch.bool, False),
|
||||
(torch.float64, torch.int8, False),
|
||||
(torch.float64, torch.complex32, True),
|
||||
(torch.complex128, torch.bool, False),
|
||||
(torch.complex128, torch.int8, False),
|
||||
(torch.complex128, torch.float16, False),
|
||||
# precision_level=0
|
||||
(torch.bool, torch.bool, True),
|
||||
# precision_level=1
|
||||
(torch.int8, torch.int16, True),
|
||||
(torch.int16, torch.int8, False),
|
||||
(torch.uint8, torch.int8, False),
|
||||
(torch.int8, torch.uint8, False),
|
||||
# precision_level=2
|
||||
(torch.float16, torch.float32, True),
|
||||
(torch.float32, torch.float16, False),
|
||||
(torch.bfloat16, torch.float32, True),
|
||||
(torch.float32, torch.bfloat16, False),
|
||||
# precision_level=3
|
||||
(torch.complex32, torch.complex64, True),
|
||||
(torch.complex64, torch.complex32, False),
|
||||
],
|
||||
)
|
||||
# yapf: enable
|
||||
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
|
||||
assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("dtypes", "expected_result"),
|
||||
[
|
||||
([torch.bool], torch.bool),
|
||||
([torch.bool, torch.int8], torch.int8),
|
||||
([torch.bool, torch.int8, torch.float16], torch.float16),
|
||||
([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32), # noqa: E501
|
||||
],
|
||||
)
|
||||
# yapf: enable
|
||||
def test_common_broadcastable_dtype(dtypes, expected_result):
|
||||
assert common_broadcastable_dtype(dtypes) == expected_result
|
||||
|
||||
|
||||
def test_placeholder_module_error_handling():
|
||||
placeholder = PlaceholderModule("placeholder_1234")
|
||||
|
||||
def build_ctx():
|
||||
return pytest.raises(ModuleNotFoundError,
|
||||
match="No module named")
|
||||
return pytest.raises(ModuleNotFoundError, match="No module named")
|
||||
|
||||
with build_ctx():
|
||||
int(placeholder)
|
||||
@@ -608,6 +662,7 @@ def test_placeholder_module_error_handling():
|
||||
_ = placeholder_attr.module
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
"obj,key1,key2",
|
||||
[
|
||||
@@ -618,6 +673,7 @@ def test_placeholder_module_error_handling():
|
||||
# Tests for both keys do not exist
|
||||
({1: "a", 2: "b"}, 3, 4),
|
||||
])
|
||||
# yapf: enable
|
||||
def test_swap_dict_values(obj, key1, key2):
|
||||
original_obj = obj.copy()
|
||||
swap_dict_values(obj, key1, key2)
|
||||
@@ -631,19 +687,19 @@ def test_swap_dict_values(obj, key1, key2):
|
||||
assert key1 not in obj
|
||||
|
||||
|
||||
def test_model_specification(parser_with_config,
|
||||
cli_config_file,
|
||||
def test_model_specification(parser_with_config, cli_config_file,
|
||||
cli_config_file_with_model):
|
||||
# Test model in CLI takes precedence over config
|
||||
args = parser_with_config.parse_args([
|
||||
'serve', 'cli-model', '--config', cli_config_file_with_model
|
||||
])
|
||||
args = parser_with_config.parse_args(
|
||||
['serve', 'cli-model', '--config', cli_config_file_with_model])
|
||||
assert args.model_tag == 'cli-model'
|
||||
assert args.served_model_name == 'mymodel'
|
||||
|
||||
# Test model from config file works
|
||||
args = parser_with_config.parse_args([
|
||||
'serve', '--config', cli_config_file_with_model,
|
||||
'serve',
|
||||
'--config',
|
||||
cli_config_file_with_model,
|
||||
])
|
||||
assert args.model == 'config-model'
|
||||
assert args.served_model_name == 'mymodel'
|
||||
@@ -654,17 +710,19 @@ def test_model_specification(parser_with_config,
|
||||
|
||||
# Test using --model option raises error
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=(
|
||||
"With `vllm serve`, you should provide the model as a positional "
|
||||
"argument or in a config file instead of via the `--model` option."
|
||||
),
|
||||
ValueError,
|
||||
match=
|
||||
("With `vllm serve`, you should provide the model as a positional "
|
||||
"argument or in a config file instead of via the `--model` option."),
|
||||
):
|
||||
parser_with_config.parse_args(['serve', '--model', 'my-model'])
|
||||
|
||||
# Test other config values are preserved
|
||||
args = parser_with_config.parse_args([
|
||||
'serve', 'cli-model', '--config', cli_config_file_with_model,
|
||||
'serve',
|
||||
'cli-model',
|
||||
'--config',
|
||||
cli_config_file_with_model,
|
||||
])
|
||||
assert args.tensor_parallel_size == 2
|
||||
assert args.trust_remote_code is True
|
||||
@@ -673,7 +731,7 @@ def test_model_specification(parser_with_config,
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
|
||||
(None, bool, [1, 2, 3])])
|
||||
(None, bool, [1, 2, 3])])
|
||||
@pytest.mark.parametrize("output", [0, 1, 2])
|
||||
def test_sha256(input: tuple, output: int):
|
||||
hash = sha256(input)
|
||||
@@ -682,7 +740,8 @@ def test_sha256(input: tuple, output: int):
|
||||
assert hash != 0
|
||||
|
||||
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big")
|
||||
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(),
|
||||
byteorder="big")
|
||||
|
||||
# hashing again, returns the same value
|
||||
assert hash == sha256(input)
|
||||
@@ -698,8 +757,7 @@ def test_sha256(input: tuple, output: int):
|
||||
("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
|
||||
("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address
|
||||
("inproc://some_identifier", ("inproc", "some_identifier", "")),
|
||||
]
|
||||
)
|
||||
])
|
||||
def test_split_zmq_path(path, expected):
|
||||
assert split_zmq_path(path) == expected
|
||||
|
||||
@@ -711,8 +769,7 @@ def test_split_zmq_path(path, expected):
|
||||
"tcp://127.0.0.1", # Missing port
|
||||
"tcp://[::1]", # Missing port for IPv6
|
||||
"tcp://:5555", # Missing host
|
||||
]
|
||||
)
|
||||
])
|
||||
def test_split_zmq_path_invalid(invalid_path):
|
||||
with pytest.raises(ValueError):
|
||||
split_zmq_path(invalid_path)
|
||||
@@ -734,7 +791,8 @@ def test_make_zmq_socket_ipv6():
|
||||
zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)
|
||||
|
||||
# Verify that the IPV6 option is set
|
||||
assert zsock.getsockopt(zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
|
||||
assert zsock.getsockopt(
|
||||
zmq.IPV6) == 1, "IPV6 option should be enabled for IPv6 addresses"
|
||||
|
||||
# Clean up
|
||||
zsock.close()
|
||||
|
||||
Reference in New Issue
Block a user