Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
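This commit switches the codebase from yapf (formatting) plus isort (import sorting) to ruff, which handles both jobs. A migration like this is applied mechanically, typically by running `ruff format` and `ruff check --fix` with ruff's isort-compatible `I` rules enabled; the exact configuration used for vLLM is not shown on this page. The hunks below are all instances of the same transformation, sketched here with a hypothetical, self-contained snippet (not taken from the diff):

```python
# Hypothetical example showing the mechanical yapf -> ruff style change.

def describe(name: str, count: int = 0, enabled: bool = True) -> str:
    return f"{name}: count={count}, enabled={enabled}"

# yapf aligned continuation arguments under the opening parenthesis:
msg = describe("server",
               count=4,
               enabled=True)

# ruff's black-style formatter joins the call onto one line when it fits
# the configured line length:
msg = describe("server", count=4, enabled=True)

# ...and otherwise explodes it to one argument per line, indented once,
# closing with a trailing comma:
msg = describe(
    "server",
    count=4,
    enabled=True,
)
print(msg)
```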
vllm/v1/utils.py
@@ -9,25 +9,35 @@ from collections.abc import Sequence
 from contextlib import AbstractContextManager
 from multiprocessing import connection
 from multiprocessing.process import BaseProcess
-from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
-                    Union, overload)
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Generic,
+    Optional,
+    TypeVar,
+    Union,
+    overload,
+)
 
 import torch
 from torch.autograd.profiler import record_function
 
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
-                                  usage_message)
-from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
-                        kill_process_tree)
+from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
+from vllm.utils import (
+    get_open_port,
+    get_open_zmq_ipc_path,
+    get_tcp_uri,
+    kill_process_tree,
+)
 
 if TYPE_CHECKING:
     import numpy as np
 
     from vllm.v1.engine.coordinator import DPCoordinator
-    from vllm.v1.engine.utils import (CoreEngineActorManager,
-                                      CoreEngineProcManager)
+    from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager
 
 
 logger = init_logger(__name__)
@@ -35,7 +45,6 @@ T = TypeVar("T")
 
 
 class ConstantList(Generic[T], Sequence):
-
     def __init__(self, x: list[T]) -> None:
         self._x = x
 
@@ -57,31 +66,23 @@ class ConstantList(Generic[T], Sequence):
     def clear(self):
         raise TypeError("Cannot clear a constant list")
 
-    def index(self,
-              item: T,
-              start: int = 0,
-              stop: Optional[int] = None) -> int:
-        return self._x.index(item, start,
-                             stop if stop is not None else len(self._x))
+    def index(self, item: T, start: int = 0, stop: Optional[int] = None) -> int:
+        return self._x.index(item, start, stop if stop is not None else len(self._x))
 
     @overload
-    def __getitem__(self, item: int) -> T:
-        ...
+    def __getitem__(self, item: int) -> T: ...
 
     @overload
-    def __getitem__(self, s: slice, /) -> list[T]:
-        ...
+    def __getitem__(self, s: slice, /) -> list[T]: ...
 
     def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
         return self._x[item]
 
     @overload
-    def __setitem__(self, item: int, value: T):
-        ...
+    def __setitem__(self, item: int, value: T): ...
 
     @overload
-    def __setitem__(self, s: slice, value: T, /):
-        ...
+    def __setitem__(self, s: slice, value: T, /): ...
 
     def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
         raise TypeError("Cannot set item in a constant list")
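For context on the class in this hunk: `ConstantList` is a read-only wrapper over a plain list; reads delegate to the underlying list and mutations raise `TypeError`. A minimal standalone sketch of that behavior (`__len__` is added here for completeness, it is not part of this hunk):

```python
from collections.abc import Sequence
from typing import Generic, Optional, TypeVar, Union

T = TypeVar("T")


class ConstantList(Generic[T], Sequence):
    def __init__(self, x: list[T]) -> None:
        self._x = x

    def index(self, item: T, start: int = 0, stop: Optional[int] = None) -> int:
        return self._x.index(item, start, stop if stop is not None else len(self._x))

    def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
        return self._x[item]

    def __setitem__(self, item, value):
        raise TypeError("Cannot set item in a constant list")

    def __len__(self) -> int:
        return len(self._x)


c = ConstantList([1, 2, 3])
print(c[0], c.index(2))  # reads work: prints "1 1"
try:
    c[0] = 5  # writes raise
except TypeError as e:
    print(e)
```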
@@ -113,10 +114,7 @@ class CpuGpuBuffer:
         pin_memory: bool,
         with_numpy: bool = True,
     ) -> None:
-        self.cpu = torch.zeros(*size,
-                               dtype=dtype,
-                               device="cpu",
-                               pin_memory=pin_memory)
+        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory)
         self.gpu = torch.zeros_like(self.cpu, device=device)
         self.np: np.ndarray
         # To keep type hints simple (avoiding generics and subclasses), we
@@ -126,7 +124,8 @@ class CpuGpuBuffer:
         if dtype == torch.bfloat16:
             raise ValueError(
                 "Bfloat16 torch tensors cannot be directly cast to a "
-                "numpy array, so call CpuGpuBuffer with with_numpy=False")
+                "numpy array, so call CpuGpuBuffer with with_numpy=False"
+            )
         self.np = self.cpu.numpy()
 
     def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
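Two hunks here touch `CpuGpuBuffer`. The bfloat16 guard exists because NumPy has no bfloat16 dtype, so `Tensor.numpy()` fails for bfloat16 tensors. A simplified, hypothetical sketch of the paired-buffer idea (it assumes a CUDA device is available for pinned memory; this is not the full vLLM class):

```python
import torch


class CpuGpuBuffer:
    """Simplified sketch: a pinned CPU staging tensor paired with a GPU twin."""

    def __init__(self, *size: int, dtype: torch.dtype, device: str, pin_memory: bool):
        self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory)
        self.gpu = torch.zeros_like(self.cpu, device=device)
        if dtype != torch.bfloat16:
            # NumPy has no bfloat16 dtype, so .numpy() would raise for bfloat16.
            self.np = self.cpu.numpy()

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Pinned host memory allows a true non-blocking host-to-device copy.
        return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
```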
@@ -142,9 +141,7 @@ class CpuGpuBuffer:
         return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)
 
 
-def get_engine_client_zmq_addr(local_only: bool,
-                               host: str,
-                               port: int = 0) -> str:
+def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str:
     """Assign a new ZMQ socket address.
 
     If local_only is True, participants are colocated and so a unique IPC
@@ -153,8 +150,11 @@ def get_engine_client_zmq_addr(local_only: bool,
     Otherwise, the provided host and port will be used to construct a TCP
     address (port == 0 means assign an available port)."""
 
-    return get_open_zmq_ipc_path() if local_only else (get_tcp_uri(
-        host, port or get_open_port()))
+    return (
+        get_open_zmq_ipc_path()
+        if local_only
+        else (get_tcp_uri(host, port or get_open_port()))
+    )
 
 
 class APIServerProcessManager:
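The rewrapped return above chooses between an IPC address for colocated participants and a TCP address otherwise. A self-contained illustration with hypothetical stand-ins for the `vllm.utils` helpers (the literal address formats are assumptions based on the helpers' names):

```python
import socket
import tempfile
import uuid


def get_open_zmq_ipc_path() -> str:
    # Stand-in: a unique IPC endpoint under the temp directory.
    return f"ipc://{tempfile.gettempdir()}/{uuid.uuid4()}"


def get_tcp_uri(host: str, port: int) -> str:
    # Stand-in: a ZMQ-style TCP endpoint.
    return f"tcp://{host}:{port}"


def get_open_port() -> int:
    # Ask the OS for a currently free TCP port.
    with socket.socket() as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str:
    return (
        get_open_zmq_ipc_path()
        if local_only
        else get_tcp_uri(host, port or get_open_port())
    )


print(get_engine_client_zmq_addr(True, "localhost"))  # ipc://... (colocated)
print(get_engine_client_zmq_addr(False, "10.0.0.1"))  # tcp://10.0.0.1:<port>
```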
@@ -195,21 +195,23 @@
         spawn_context = multiprocessing.get_context("spawn")
         self.processes: list[BaseProcess] = []
 
-        for i, in_addr, out_addr in zip(range(num_servers), input_addresses,
-                                        output_addresses):
+        for i, in_addr, out_addr in zip(
+            range(num_servers), input_addresses, output_addresses
+        ):
             client_config = {
                 "input_address": in_addr,
                 "output_address": out_addr,
                 "client_count": num_servers,
-                "client_index": i
+                "client_index": i,
             }
             if stats_update_address is not None:
                 client_config["stats_update_address"] = stats_update_address
 
-            proc = spawn_context.Process(target=target_server_fn,
-                                         name=f"ApiServer_{i}",
-                                         args=(listen_address, sock, args,
-                                               client_config))
+            proc = spawn_context.Process(
+                target=target_server_fn,
+                name=f"ApiServer_{i}",
+                args=(listen_address, sock, args, client_config),
+            )
             self.processes.append(proc)
             proc.start()
@@ -224,10 +226,12 @@
 
 
 def wait_for_completion_or_failure(
-        api_server_manager: APIServerProcessManager,
-        engine_manager: Optional[Union["CoreEngineProcManager",
-                                       "CoreEngineActorManager"]] = None,
-        coordinator: Optional["DPCoordinator"] = None) -> None:
+    api_server_manager: APIServerProcessManager,
+    engine_manager: Optional[
+        Union["CoreEngineProcManager", "CoreEngineActorManager"]
+    ] = None,
+    coordinator: Optional["DPCoordinator"] = None,
+) -> None:
     """Wait for all processes to complete or detect if any fail.
 
     Raises an exception if any process exits with a non-zero status.
@@ -240,16 +244,14 @@ def wait_for_completion_or_failure(
         coordinator: The coordinator for data parallel.
     """
 
-    from vllm.v1.engine.utils import (CoreEngineActorManager,
-                                      CoreEngineProcManager)
+    from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager
 
     try:
         logger.info("Waiting for API servers to complete ...")
         # Create a mapping of sentinels to their corresponding processes
         # for efficient lookup
         sentinel_to_proc: dict[Any, BaseProcess] = {
-            proc.sentinel: proc
-            for proc in api_server_manager.processes
+            proc.sentinel: proc for proc in api_server_manager.processes
         }
 
         if coordinator:
@@ -265,8 +267,7 @@
         # Check if any process terminates
         while sentinel_to_proc or actor_run_refs:
             # Wait for any process to terminate
-            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
-                                                         timeout=5)
+            ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, timeout=5)
 
             # Process any terminated processes
             for sentinel in ready_sentinels:
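A note on the `connection.wait` call rewrapped above: every `multiprocessing.Process` exposes a `sentinel` handle that becomes ready when the process exits, and `multiprocessing.connection.wait` blocks on many of them at once (a dict iterates over its keys, so passing `sentinel_to_proc` directly works). A minimal self-contained sketch of the pattern:

```python
import time
from multiprocessing import Process, connection


def worker(seconds: float) -> None:
    time.sleep(seconds)


if __name__ == "__main__":
    procs = [Process(target=worker, args=(s,)) for s in (0.1, 0.2)]
    for p in procs:
        p.start()

    # Map each sentinel back to its process for efficient lookup.
    sentinel_to_proc = {p.sentinel: p for p in procs}
    while sentinel_to_proc:
        # Blocks until at least one sentinel is ready or the timeout expires.
        for sentinel in connection.wait(sentinel_to_proc, timeout=5):
            proc = sentinel_to_proc.pop(sentinel)
            proc.join()
            print(f"{proc.name} exited with code {proc.exitcode}")
```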
@@ -276,17 +277,18 @@
                 if proc.exitcode != 0:
                     raise RuntimeError(
                         f"Process {proc.name} (PID: {proc.pid}) "
-                        f"died with exit code {proc.exitcode}")
+                        f"died with exit code {proc.exitcode}"
+                    )
 
         if actor_run_refs:
             import ray
+
             _, actor_run_refs = ray.wait(actor_run_refs, timeout=5)
 
     except KeyboardInterrupt:
         logger.info("Received KeyboardInterrupt, shutting down API servers...")
     except Exception as e:
-        logger.exception("Exception occurred while running API servers: %s",
-                         str(e))
+        logger.exception("Exception occurred while running API servers: %s", str(e))
         raise
     finally:
         logger.info("Terminating remaining processes ...")
@@ -319,8 +321,9 @@ def shutdown(procs: list[BaseProcess]):
                 kill_process_tree(pid)
 
 
-def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
-               length: int) -> torch.Tensor:
+def copy_slice(
+    from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int
+) -> torch.Tensor:
     """
     Copy the first length elements of a tensor into another tensor in a
     non-blocking manner.
@@ -333,8 +336,8 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
 
 
 def report_usage_stats(
-        vllm_config,
-        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
+    vllm_config, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT
+) -> None:
     """Report usage statistics if enabled."""
 
     if not is_usage_stats_enabled():
@@ -347,32 +350,21 @@ def report_usage_stats(
         usage_context,
         extra_kvs={
             # Common configuration
-            "dtype":
-            str(vllm_config.model_config.dtype),
-            "tensor_parallel_size":
-            vllm_config.parallel_config.tensor_parallel_size,
-            "block_size":
-            vllm_config.cache_config.block_size,
-            "gpu_memory_utilization":
-            vllm_config.cache_config.gpu_memory_utilization,
-            "kv_cache_memory_bytes":
-            vllm_config.cache_config.kv_cache_memory_bytes,
+            "dtype": str(vllm_config.model_config.dtype),
+            "tensor_parallel_size": vllm_config.parallel_config.tensor_parallel_size,
+            "block_size": vllm_config.cache_config.block_size,
+            "gpu_memory_utilization": vllm_config.cache_config.gpu_memory_utilization,
+            "kv_cache_memory_bytes": vllm_config.cache_config.kv_cache_memory_bytes,
             # Quantization
-            "quantization":
-            vllm_config.model_config.quantization,
-            "kv_cache_dtype":
-            str(vllm_config.cache_config.cache_dtype),
-
+            "quantization": vllm_config.model_config.quantization,
+            "kv_cache_dtype": str(vllm_config.cache_config.cache_dtype),
             # Feature flags
-            "enable_lora":
-            bool(vllm_config.lora_config),
-            "enable_prefix_caching":
-            vllm_config.cache_config.enable_prefix_caching,
-            "enforce_eager":
-            vllm_config.model_config.enforce_eager,
-            "disable_custom_all_reduce":
-            vllm_config.parallel_config.disable_custom_all_reduce,
-        })
+            "enable_lora": bool(vllm_config.lora_config),
+            "enable_prefix_caching": vllm_config.cache_config.enable_prefix_caching,
+            "enforce_eager": vllm_config.model_config.enforce_eager,
+            "disable_custom_all_reduce": vllm_config.parallel_config.disable_custom_all_reduce,
+        },
+    )
 
 
 _PROFILER_FUNC = None
@@ -390,6 +382,7 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager:
         func = record_function
     elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING:
         import nvtx
+
         func = nvtx.annotate
 
     _PROFILER_FUNC = func
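For context on the final hunk: `record_function_or_nullcontext` lazily selects a profiling scope, `torch.autograd.profiler.record_function` or `nvtx.annotate`, and caches the choice in `_PROFILER_FUNC`. A simplified sketch of that dispatch; the flag names and the `nullcontext` fallback are assumptions based on the function's name and what this hunk shows, not vLLM's exact logic:

```python
from contextlib import AbstractContextManager, nullcontext

from torch.autograd.profiler import record_function

# Stand-ins for the vllm.envs flags read by the real function (assumed names).
VLLM_TORCH_PROFILER_ENABLED = True
VLLM_NVTX_SCOPES_FOR_PROFILING = False

_PROFILER_FUNC = None


def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    global _PROFILER_FUNC
    if _PROFILER_FUNC is None:
        # Default to a no-op scope when no profiler is active.
        func = nullcontext
        if VLLM_TORCH_PROFILER_ENABLED:
            func = record_function
        elif VLLM_NVTX_SCOPES_FOR_PROFILING:
            import nvtx

            func = nvtx.annotate
        _PROFILER_FUNC = func  # cache the choice for subsequent calls
    return _PROFILER_FUNC(name)


with record_function_or_nullcontext("forward_pass"):
    pass  # profiled (or no-op) region
```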