[Misc] Move utils to avoid conflicts with stdlib, and move tests (#27169)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -2,10 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa

import hashlib
import json
import os
import pickle
import tempfile
from pathlib import Path
from unittest.mock import patch
@@ -14,7 +12,6 @@ import pytest
import torch
import yaml
from transformers import AutoTokenizer
from vllm_test_utils.monitor import monitor

from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
@@ -24,13 +21,6 @@ from vllm.utils import (
    bind_kv_cache,
    unique_filepath,
)
from vllm.utils.hashing import sha256
from vllm.utils.torch_utils import (
    common_broadcastable_dtype,
    current_stream,
    is_lossless_cast,
)
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
from ..utils import create_new_process_for_each_test, flat_product

@@ -267,61 +257,6 @@ def test_duplicate_dict_args(caplog_vllm, parser):
    assert "-O.mode" in caplog_vllm.text


@create_new_process_for_each_test()
def test_memory_profiling():
    # Fake out some model loading + inference memory usage to test profiling
    # Memory used by other processes will show up as cuda usage outside of torch
    from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

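    # cudaMalloc through CudaRTLibrary bypasses torch's caching allocator,
    # so these allocations register as "non-torch" memory in the profiling
    # below.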
    lib = CudaRTLibrary()
    # 512 MiB allocation outside of this instance
    handle1 = lib.cudaMalloc(512 * 1024 * 1024)

    baseline_snapshot = MemorySnapshot()

    # load weights

    weights = torch.randn(128, 1024, 1024, device="cuda", dtype=torch.float32)

    weights_memory = 128 * 1024 * 1024 * 4  # 512 MiB
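    # i.e. 128 * 1024 * 1024 float32 elements * 4 bytes each = 512 MiB

    # "Non-torch" memory is whatever CUDA reports as in use beyond what
    # torch's caching allocator has reserved: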
    def measure_current_non_torch():
        free, total = torch.cuda.mem_get_info()
        current_used = total - free
        current_torch = torch.cuda.memory_reserved()
        current_non_torch = current_used - current_torch
        return current_non_torch

    with (
        memory_profiling(
            baseline_snapshot=baseline_snapshot, weights_memory=weights_memory
        ) as result,
        monitor(measure_current_non_torch) as monitored_values,
    ):
        # make a memory spike, 1 GiB
        spike = torch.randn(256, 1024, 1024, device="cuda", dtype=torch.float32)
        del spike

        # Add some extra non-torch memory 256 MiB (simulate NCCL)
        handle2 = lib.cudaMalloc(256 * 1024 * 1024)

    # This is an analytic value, so it is exact: the only non-torch memory
    # increase inside the block is the 256 MiB cudaMalloc above.
    measured_diff = monitored_values.values[-1] - monitored_values.values[0]
    assert measured_diff == 256 * 1024 * 1024

    # Check that the memory usage is within 5% of the expected values.
    # The 5% tolerance is caused by the cuda runtime: we cannot control it
    # at byte granularity, which causes a small error (<10 MiB in practice).
    non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024)  # noqa
    assert abs(non_torch_ratio - 1) <= 0.05
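    # The spike was 256 * 1024 * 1024 float32 elements * 4 bytes = 1 GiB,
    # all allocated through torch, hence the exact peak value below.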
    assert result.torch_peak_increase == 1024 * 1024 * 1024
    del weights
    lib.cudaFree(handle1)
    lib.cudaFree(handle2)


def test_bind_kv_cache():
    from vllm.attention import Attention

@@ -403,56 +338,6 @@ def test_bind_kv_cache_pp():
    assert ctx["layers.0.self_attn"].kv_cache[1] is kv_cache[1][0]


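# Inferred from the cases below (not from the implementation): casts to a
# strictly higher precision level (bool < integral < floating point <
# complex) count as lossless, while casts within a level are lossless only
# to the same dtype or a wider dtype of the same kind.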
@pytest.mark.parametrize(
    ("src_dtype", "tgt_dtype", "expected_result"),
    [
        # Different precision_levels
        (torch.bool, torch.int8, True),
        (torch.bool, torch.float16, True),
        (torch.bool, torch.complex32, True),
        (torch.int64, torch.bool, False),
        (torch.int64, torch.float16, True),
        (torch.int64, torch.complex32, True),
        (torch.float64, torch.bool, False),
        (torch.float64, torch.int8, False),
        (torch.float64, torch.complex32, True),
        (torch.complex128, torch.bool, False),
        (torch.complex128, torch.int8, False),
        (torch.complex128, torch.float16, False),
        # precision_level=0
        (torch.bool, torch.bool, True),
        # precision_level=1
        (torch.int8, torch.int16, True),
        (torch.int16, torch.int8, False),
        (torch.uint8, torch.int8, False),
        (torch.int8, torch.uint8, False),
        # precision_level=2
        (torch.float16, torch.float32, True),
        (torch.float32, torch.float16, False),
        (torch.bfloat16, torch.float32, True),
        (torch.float32, torch.bfloat16, False),
        # precision_level=3
        (torch.complex32, torch.complex64, True),
        (torch.complex64, torch.complex32, False),
    ],
)
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
    assert is_lossless_cast(src_dtype, tgt_dtype) == expected_result


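# Judging from the expected results below, the common broadcastable dtype is
# the one that every listed dtype can be losslessly cast to.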
@pytest.mark.parametrize(
    ("dtypes", "expected_result"),
    [
        ([torch.bool], torch.bool),
        ([torch.bool, torch.int8], torch.int8),
        ([torch.bool, torch.int8, torch.float16], torch.float16),
        ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32),  # noqa: E501
    ],
)
def test_common_broadcastable_dtype(dtypes, expected_result):
    assert common_broadcastable_dtype(dtypes) == expected_result


def test_model_specification(
    parser_with_config, cli_config_file, cli_config_file_with_model
):
@@ -535,23 +420,6 @@ def test_model_specification(
    assert args.port == 12312


@pytest.mark.parametrize("input", [(), ("abc",), (None,), (None, bool, [1, 2, 3])])
|
||||
def test_sha256(input: tuple):
|
||||
digest = sha256(input)
|
||||
assert digest is not None
|
||||
assert isinstance(digest, bytes)
|
||||
assert digest != b""
|
||||
|
||||
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
assert digest == hashlib.sha256(input_bytes).digest()
|
||||
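    # i.e. sha256() is expected to produce the digest of the pickled input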

    # Hashing the same input again returns the same value
    assert digest == sha256(input)

    # Hashing different input returns a different value
    assert digest != sha256(input + (1,))


def test_convert_ids_list_to_tokens():
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    token_ids = tokenizer.encode("Hello, world!")
@@ -561,50 +429,6 @@ def test_convert_ids_list_to_tokens():
    assert tokens == ["Hello", ",", " world", "!"]


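# current_stream() should track the calling thread: a child thread switching
# to its own stream must not change what the main thread observes.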
def test_current_stream_multithread():
    import threading

    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    main_default_stream = torch.cuda.current_stream()
    child_stream = torch.cuda.Stream()

    thread_stream_ready = threading.Event()
    thread_can_exit = threading.Event()

    def child_thread_func():
        with torch.cuda.stream(child_stream):
            thread_stream_ready.set()
            thread_can_exit.wait(timeout=10)

    child_thread = threading.Thread(target=child_thread_func)
    child_thread.start()

    try:
        assert thread_stream_ready.wait(timeout=5), (
            "Child thread failed to enter stream context in time"
        )

        main_current_stream = current_stream()

        assert main_current_stream != child_stream, (
            "Main thread's current_stream was contaminated by child thread"
        )
        assert main_current_stream == main_default_stream, (
            "Main thread's current_stream is not the default stream"
        )

        # Notify child thread it can exit
        thread_can_exit.set()

    finally:
        # Ensure child thread exits properly
        child_thread.join(timeout=5)
        if child_thread.is_alive():
            pytest.fail("Child thread failed to exit properly")


def test_load_config_file(tmp_path):
    # Define the configuration data
    config_data = {