Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -9,25 +9,35 @@ from collections.abc import Sequence
from contextlib import AbstractContextManager
from multiprocessing import connection
from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Optional,
TypeVar,
Union,
overload,
)
import torch
from torch.autograd.profiler import record_function
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
kill_process_tree)
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
from vllm.utils import (
get_open_port,
get_open_zmq_ipc_path,
get_tcp_uri,
kill_process_tree,
)
if TYPE_CHECKING:
import numpy as np
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.utils import (CoreEngineActorManager,
CoreEngineProcManager)
from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager
logger = init_logger(__name__)
@@ -35,7 +45,6 @@ T = TypeVar("T")
class ConstantList(Generic[T], Sequence):
# Wrap an existing list. The list object is stored as-is on `self._x`
# (assigned directly, no copy is made here).
def __init__(self, x: list[T]) -> None:
self._x = x
@@ -57,31 +66,23 @@ class ConstantList(Generic[T], Sequence):
# Unconditionally raises TypeError instead of emptying the wrapped list.
def clear(self):
raise TypeError("Cannot clear a constant list")
# NOTE(review): the definition appears twice below (a multi-line layout
# followed by a single-line layout with the same tokens). This looks like
# before/after residue of the formatter change with indentation stripped --
# confirm that only one copy exists in the real file.
def index(self,
item: T,
start: int = 0,
stop: Optional[int] = None) -> int:
return self._x.index(item, start,
stop if stop is not None else len(self._x))
# Return the position of `item` in the wrapped list, delegating to
# `list.index` on `self._x`. The search begins at `start`; when `stop` is
# None, the current length of the wrapped list is passed as the end bound.
def index(self, item: T, start: int = 0, stop: Optional[int] = None) -> int:
return self._x.index(item, start, stop if stop is not None else len(self._x))
# Typing-only overloads: an int index is annotated as returning a single
# element (T), a slice as returning a plain list[T].
# NOTE(review): each overload stub is shown in two layouts (body `...` on its
# own line, then on the `def` line) -- presumably old/new diff lines; confirm.
@overload
def __getitem__(self, item: int) -> T:
...
def __getitem__(self, item: int) -> T: ...
@overload
def __getitem__(self, s: slice, /) -> list[T]:
...
def __getitem__(self, s: slice, /) -> list[T]: ...
# Runtime implementation: forwards the int or slice straight to the wrapped
# list and returns whatever `self._x[item]` yields.
def __getitem__(self, item: Union[int, slice]) -> Union[T, list[T]]:
return self._x[item]
# Typing-only overloads covering int and slice assignment.
# NOTE(review): each overload stub is shown in two layouts (body `...` on its
# own line, then on the `def` line) -- presumably old/new diff lines; confirm.
@overload
def __setitem__(self, item: int, value: T):
...
def __setitem__(self, item: int, value: T): ...
@overload
def __setitem__(self, s: slice, value: T, /):
...
def __setitem__(self, s: slice, value: T, /): ...
# Runtime implementation: item assignment is rejected for every index kind;
# the arguments are never inspected and TypeError is always raised.
def __setitem__(self, item: Union[int, slice], value: Union[T, list[T]]):
raise TypeError("Cannot set item in a constant list")
@@ -113,10 +114,7 @@ class CpuGpuBuffer:
pin_memory: bool,
with_numpy: bool = True,
) -> None:
self.cpu = torch.zeros(*size,
dtype=dtype,
device="cpu",
pin_memory=pin_memory)
self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=pin_memory)
self.gpu = torch.zeros_like(self.cpu, device=device)
self.np: np.ndarray
# To keep type hints simple (avoiding generics and subclasses), we
@@ -126,7 +124,8 @@ class CpuGpuBuffer:
if dtype == torch.bfloat16:
raise ValueError(
"Bfloat16 torch tensors cannot be directly cast to a "
"numpy array, so call CpuGpuBuffer with with_numpy=False")
"numpy array, so call CpuGpuBuffer with with_numpy=False"
)
self.np = self.cpu.numpy()
def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor:
@@ -142,9 +141,7 @@ class CpuGpuBuffer:
return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True)
def get_engine_client_zmq_addr(local_only: bool,
host: str,
port: int = 0) -> str:
def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str:
"""Assign a new ZMQ socket address.
If local_only is True, participants are colocated and so a unique IPC
@@ -153,8 +150,11 @@ def get_engine_client_zmq_addr(local_only: bool,
Otherwise, the provided host and port will be used to construct a TCP
address (port == 0 means assign an available port)."""
return get_open_zmq_ipc_path() if local_only else (get_tcp_uri(
host, port or get_open_port()))
return (
get_open_zmq_ipc_path()
if local_only
else (get_tcp_uri(host, port or get_open_port()))
)
class APIServerProcessManager:
@@ -195,21 +195,23 @@ class APIServerProcessManager:
spawn_context = multiprocessing.get_context("spawn")
self.processes: list[BaseProcess] = []
for i, in_addr, out_addr in zip(range(num_servers), input_addresses,
output_addresses):
for i, in_addr, out_addr in zip(
range(num_servers), input_addresses, output_addresses
):
client_config = {
"input_address": in_addr,
"output_address": out_addr,
"client_count": num_servers,
"client_index": i
"client_index": i,
}
if stats_update_address is not None:
client_config["stats_update_address"] = stats_update_address
proc = spawn_context.Process(target=target_server_fn,
name=f"ApiServer_{i}",
args=(listen_address, sock, args,
client_config))
proc = spawn_context.Process(
target=target_server_fn,
name=f"ApiServer_{i}",
args=(listen_address, sock, args, client_config),
)
self.processes.append(proc)
proc.start()
@@ -224,10 +226,12 @@ class APIServerProcessManager:
def wait_for_completion_or_failure(
api_server_manager: APIServerProcessManager,
engine_manager: Optional[Union["CoreEngineProcManager",
"CoreEngineActorManager"]] = None,
coordinator: Optional["DPCoordinator"] = None) -> None:
api_server_manager: APIServerProcessManager,
engine_manager: Optional[
Union["CoreEngineProcManager", "CoreEngineActorManager"]
] = None,
coordinator: Optional["DPCoordinator"] = None,
) -> None:
"""Wait for all processes to complete or detect if any fail.
Raises an exception if any process exits with a non-zero status.
@@ -240,16 +244,14 @@ def wait_for_completion_or_failure(
coordinator: The coordinator for data parallel.
"""
from vllm.v1.engine.utils import (CoreEngineActorManager,
CoreEngineProcManager)
from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager
try:
logger.info("Waiting for API servers to complete ...")
# Create a mapping of sentinels to their corresponding processes
# for efficient lookup
sentinel_to_proc: dict[Any, BaseProcess] = {
proc.sentinel: proc
for proc in api_server_manager.processes
proc.sentinel: proc for proc in api_server_manager.processes
}
if coordinator:
@@ -265,8 +267,7 @@ def wait_for_completion_or_failure(
# Check if any process terminates
while sentinel_to_proc or actor_run_refs:
# Wait for any process to terminate
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc,
timeout=5)
ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, timeout=5)
# Process any terminated processes
for sentinel in ready_sentinels:
@@ -276,17 +277,18 @@ def wait_for_completion_or_failure(
if proc.exitcode != 0:
raise RuntimeError(
f"Process {proc.name} (PID: {proc.pid}) "
f"died with exit code {proc.exitcode}")
f"died with exit code {proc.exitcode}"
)
if actor_run_refs:
import ray
_, actor_run_refs = ray.wait(actor_run_refs, timeout=5)
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down API servers...")
except Exception as e:
logger.exception("Exception occurred while running API servers: %s",
str(e))
logger.exception("Exception occurred while running API servers: %s", str(e))
raise
finally:
logger.info("Terminating remaining processes ...")
@@ -319,8 +321,9 @@ def shutdown(procs: list[BaseProcess]):
kill_process_tree(pid)
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
length: int) -> torch.Tensor:
def copy_slice(
from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int
) -> torch.Tensor:
"""
Copy the first length elements of a tensor into another tensor in a
non-blocking manner.
@@ -333,8 +336,8 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
def report_usage_stats(
vllm_config,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
vllm_config, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT
) -> None:
"""Report usage statistics if enabled."""
if not is_usage_stats_enabled():
@@ -347,32 +350,21 @@ def report_usage_stats(
usage_context,
extra_kvs={
# Common configuration
"dtype":
str(vllm_config.model_config.dtype),
"tensor_parallel_size":
vllm_config.parallel_config.tensor_parallel_size,
"block_size":
vllm_config.cache_config.block_size,
"gpu_memory_utilization":
vllm_config.cache_config.gpu_memory_utilization,
"kv_cache_memory_bytes":
vllm_config.cache_config.kv_cache_memory_bytes,
"dtype": str(vllm_config.model_config.dtype),
"tensor_parallel_size": vllm_config.parallel_config.tensor_parallel_size,
"block_size": vllm_config.cache_config.block_size,
"gpu_memory_utilization": vllm_config.cache_config.gpu_memory_utilization,
"kv_cache_memory_bytes": vllm_config.cache_config.kv_cache_memory_bytes,
# Quantization
"quantization":
vllm_config.model_config.quantization,
"kv_cache_dtype":
str(vllm_config.cache_config.cache_dtype),
"quantization": vllm_config.model_config.quantization,
"kv_cache_dtype": str(vllm_config.cache_config.cache_dtype),
# Feature flags
"enable_lora":
bool(vllm_config.lora_config),
"enable_prefix_caching":
vllm_config.cache_config.enable_prefix_caching,
"enforce_eager":
vllm_config.model_config.enforce_eager,
"disable_custom_all_reduce":
vllm_config.parallel_config.disable_custom_all_reduce,
})
"enable_lora": bool(vllm_config.lora_config),
"enable_prefix_caching": vllm_config.cache_config.enable_prefix_caching,
"enforce_eager": vllm_config.model_config.enforce_eager,
"disable_custom_all_reduce": vllm_config.parallel_config.disable_custom_all_reduce,
},
)
_PROFILER_FUNC = None
@@ -390,6 +382,7 @@ def record_function_or_nullcontext(name: str) -> AbstractContextManager:
func = record_function
elif envs.VLLM_NVTX_SCOPES_FOR_PROFILING:
import nvtx
func = nvtx.annotate
_PROFILER_FUNC = func