# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import queue
import sys
import uuid
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections.abc import Awaitable, Callable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass
from multiprocessing.queues import Queue
from threading import Thread
from typing import Any, TypeAlias, TypeVar
import msgspec.msgpack
import zmq
import zmq.asyncio
from vllm.config import VllmConfig
from vllm.envs import VLLM_ENGINE_READY_TIMEOUT_S
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.tracing import instrument
from vllm.utils.async_utils import in_loop
from vllm.utils.network_utils import (
close_sockets,
get_open_zmq_inproc_path,
make_zmq_socket,
)
from vllm.v1.engine import (
EEP_NOTIFICATION_CALL_ID,
EEPNotificationType,
EngineCoreOutputs,
EngineCoreRequest,
EngineCoreRequestType,
PauseMode,
ReconfigureDistributedRequest,
ReconfigureRankType,
UtilityOutput,
)
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.engine.tensor_ipc import TensorIpcSender
from vllm.v1.engine.utils import (
CoreEngineActorManager,
CoreEngineProcManager,
get_engine_zmq_addresses,
launch_core_engines,
)
from vllm.v1.executor import Executor
from vllm.v1.pool.late_interaction import get_late_interaction_engine_index
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr
logger = init_logger(__name__)
AnyFuture: TypeAlias = asyncio.Future[Any] | Future[Any]
_R = TypeVar("_R") # Return type for collective_rpc
EngineIdentity = bytes
class EngineCoreClient(ABC):
"""
EngineCoreClient: subclasses handle different methods for pushing
and pulling from the EngineCore for asyncio / multiprocessing.
Subclasses:
    * InprocClient: in-process EngineCore (for V0-style LLMEngine use)
* SyncMPClient: ZMQ + background proc EngineCore (for LLM)
* AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM)
"""
@staticmethod
def make_client(
multiprocess_mode: bool,
asyncio_mode: bool,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
) -> "EngineCoreClient":
# TODO: support this for debugging purposes.
if asyncio_mode and not multiprocess_mode:
raise NotImplementedError(
"Running EngineCore in asyncio without multiprocessing "
"is not currently supported."
)
if multiprocess_mode and asyncio_mode:
return EngineCoreClient.make_async_mp_client(
vllm_config, executor_class, log_stats
)
if multiprocess_mode and not asyncio_mode:
return SyncMPClient(vllm_config, executor_class, log_stats)
return InprocClient(vllm_config, executor_class, log_stats)
@staticmethod
@instrument(span_name="Overall Loading")
def make_async_mp_client(
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: dict[str, str] | None = None,
client_count: int = 1,
client_index: int = 0,
) -> "AsyncMPClient":
parallel_config = vllm_config.parallel_config
client_args = (
vllm_config,
executor_class,
log_stats,
client_addresses,
client_count,
client_index,
)
if parallel_config.data_parallel_size > 1:
if parallel_config.data_parallel_external_lb:
# External load balancer - client per DP rank.
return DPAsyncMPClient(*client_args)
# Internal load balancer - client balances to all DP ranks.
return DPLBAsyncMPClient(*client_args)
return AsyncMPClient(*client_args)
@abstractmethod
def shutdown(self, timeout: float | None = None) -> None: ...
def get_output(self) -> EngineCoreOutputs:
raise NotImplementedError
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
raise NotImplementedError
def add_request(self, request: EngineCoreRequest) -> None:
raise NotImplementedError
def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None:
raise NotImplementedError
def reset_mm_cache(self) -> None:
raise NotImplementedError
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
raise NotImplementedError
def reset_encoder_cache(self) -> None:
raise NotImplementedError
def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
raise NotImplementedError
def wake_up(self, tags: list[str] | None = None) -> None:
raise NotImplementedError
def is_sleeping(self) -> bool:
raise NotImplementedError
def execute_dummy_batch(self) -> None:
raise NotImplementedError
async def execute_dummy_batch_async(self) -> None:
raise NotImplementedError
def abort_requests(self, request_ids: list[str]) -> None:
raise NotImplementedError
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def list_loras(self) -> set[int]:
raise NotImplementedError
def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError
def save_sharded_state(
self, path: str, pattern: str | None = None, max_size: int | None = None
) -> None:
raise NotImplementedError
def collective_rpc(
self,
method: str | Callable[..., _R],
timeout: float | None = None,
args: tuple = (),
kwargs: dict[str, Any] | None = None,
) -> list[_R]:
raise NotImplementedError
def dp_engines_running(self) -> bool:
"""Returns True if data parallel engines are collectively in a
running state."""
raise NotImplementedError
async def scale_elastic_ep(self, new_data_parallel_size: int) -> None:
raise NotImplementedError
async def get_output_async(self) -> EngineCoreOutputs:
raise NotImplementedError
async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
raise NotImplementedError
async def add_request_async(self, request: EngineCoreRequest) -> None:
raise NotImplementedError
async def profile_async(
self, is_start: bool = True, profile_prefix: str | None = None
) -> None:
raise NotImplementedError
async def reset_mm_cache_async(self) -> None:
raise NotImplementedError
async def reset_prefix_cache_async(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
raise NotImplementedError
async def reset_encoder_cache_async(self) -> None:
raise NotImplementedError
async def sleep_async(self, level: int = 1, mode: PauseMode = "abort") -> None:
raise NotImplementedError
async def wake_up_async(self, tags: list[str] | None = None) -> None:
raise NotImplementedError
async def is_sleeping_async(self) -> bool:
raise NotImplementedError
async def abort_requests_async(self, request_ids: list[str]) -> None:
raise NotImplementedError
async def add_lora_async(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
async def remove_lora_async(self, lora_id: int) -> bool:
raise NotImplementedError
async def list_loras_async(self) -> set[int]:
raise NotImplementedError
async def pin_lora_async(self, lora_id: int) -> bool:
raise NotImplementedError
async def save_sharded_state_async(
self, path: str, pattern: str | None = None, max_size: int | None = None
) -> None:
raise NotImplementedError
async def collective_rpc_async(
self,
method: str | Callable[..., _R],
timeout: float | None = None,
args: tuple = (),
kwargs: dict[str, Any] | None = None,
) -> list[_R]:
raise NotImplementedError
class InprocClient(EngineCoreClient):
"""
    InprocClient: client for in-process EngineCore. Intended
    for use in LLMEngine for V0-style add_request() and step().
    The EngineCore is set up in this process (no busy loop).
* pushes EngineCoreRequest directly into the EngineCore
* pulls EngineCoreOutputs by stepping the EngineCore
"""
def __init__(self, *args, **kwargs):
self.engine_core = EngineCore(*args, **kwargs)
def get_output(self) -> EngineCoreOutputs:
outputs, model_executed = self.engine_core.step_fn()
self.engine_core.post_step(model_executed=model_executed)
        return (outputs.get(0) if outputs else None) or EngineCoreOutputs()
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.engine_core.get_supported_tasks()
def add_request(self, request: EngineCoreRequest) -> None:
req, request_wave = self.engine_core.preprocess_add_request(request)
self.engine_core.add_request(req, request_wave)
def abort_requests(self, request_ids: list[str]) -> None:
if len(request_ids) > 0:
self.engine_core.abort_requests(request_ids)
def shutdown(self, timeout: float | None = None) -> None:
self.engine_core.shutdown()
def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None:
self.engine_core.profile(is_start, profile_prefix)
def reset_mm_cache(self) -> None:
self.engine_core.reset_mm_cache()
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.engine_core.reset_prefix_cache(
reset_running_requests, reset_connector
)
def reset_encoder_cache(self) -> None:
self.engine_core.reset_encoder_cache()
def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
if mode == "wait":
raise ValueError("'wait' pause mode is not supported in inproc-engine mode")
result = self.engine_core.sleep(level, mode)
assert result is None
def wake_up(self, tags: list[str] | None = None) -> None:
self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping()
def execute_dummy_batch(self) -> None:
self.engine_core.execute_dummy_batch()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.engine_core.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.engine_core.remove_lora(lora_id)
def list_loras(self) -> set[int]:
return self.engine_core.list_loras()
def pin_lora(self, lora_id: int) -> bool:
return self.engine_core.pin_lora(lora_id)
def save_sharded_state(
self, path: str, pattern: str | None = None, max_size: int | None = None
) -> None:
self.engine_core.save_sharded_state(path, pattern, max_size)
def collective_rpc(
self,
method: str | Callable[..., _R],
timeout: float | None = None,
args: tuple = (),
kwargs: dict[str, Any] | None = None,
) -> list[_R]:
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
def dp_engines_running(self) -> bool:
return False
@dataclass
class BackgroundResources:
"""Used as a finalizer for clean shutdown, avoiding
circular reference back to the client object."""
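    # Minimal sketch of the finalizer pattern this dataclass supports
    # (mirrors MPClient.__init__ below): the resources live in a
    # standalone callable so weakref.finalize() can invoke it once the
    # client is garbage collected, without the finalizer itself keeping
    # the client alive:
    #
    #   resources = BackgroundResources(ctx=sync_ctx)
    #   finalizer = weakref.finalize(client, resources)
    #   ...
    #   finalizer()  # or triggered automatically at gc time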
ctx: zmq.Context
# If CoreEngineProcManager, it manages local engines;
# if CoreEngineActorManager, it manages all engines.
engine_manager: CoreEngineProcManager | CoreEngineActorManager | None = None
coordinator: DPCoordinator | None = None
output_socket: zmq.Socket | zmq.asyncio.Socket | None = None
input_socket: zmq.Socket | zmq.asyncio.Socket | None = None
first_req_send_socket: zmq.asyncio.Socket | None = None
first_req_rcv_socket: zmq.asyncio.Socket | None = None
stats_update_socket: zmq.asyncio.Socket | None = None
output_queue_task: asyncio.Task | None = None
stats_update_task: asyncio.Task | None = None
shutdown_path: str | None = None
# Set if any of the engines are dead. Here so that the output
# processing threads can access it without holding a ref to the client.
engine_dead: bool = False
def __call__(self):
"""Clean up background resources."""
self.engine_dead = True
if self.engine_manager is not None:
self.engine_manager.shutdown()
if self.coordinator is not None:
self.coordinator.shutdown()
if isinstance(self.output_socket, zmq.asyncio.Socket):
# Async case.
loop = self.output_queue_task._loop if self.output_queue_task else None
sockets = (
self.output_socket,
self.input_socket,
self.first_req_send_socket,
self.first_req_rcv_socket,
self.stats_update_socket,
)
tasks = (self.output_queue_task, self.stats_update_task)
def close_sockets_and_tasks():
close_sockets(sockets)
for task in tasks:
if task is not None and not task.done():
with contextlib.suppress(Exception):
task.cancel()
if loop is not None:
if in_loop(loop):
close_sockets_and_tasks()
elif not loop.is_closed():
loop.call_soon_threadsafe(close_sockets_and_tasks)
else:
# Loop has been closed, try to clean up directly.
del tasks
del close_sockets_and_tasks
close_sockets(sockets)
del self.output_queue_task
del self.stats_update_task
else:
# Sync case.
# ZMQ context termination can hang if the sockets
# aren't explicitly closed first.
close_sockets((self.output_socket, self.input_socket))
if self.shutdown_path is not None:
# We must ensure that the sync output socket is
# closed cleanly in its own thread.
with self.ctx.socket(zmq.PAIR) as shutdown_sender:
shutdown_sender.connect(self.shutdown_path)
# Send shutdown signal.
shutdown_sender.send(b"")
def validate_alive(self, frames: Sequence[zmq.Frame]):
if len(frames) == 1 and (frames[0].buffer == EngineCoreProc.ENGINE_CORE_DEAD):
self.engine_dead = True
raise EngineDeadError()
@dataclass
class ElasticScalingCache:
existing_core_engines: list[EngineIdentity]
num_new_core_engines: int
pending_notifications: dict[EEPNotificationType, set[int]]
class MPClient(EngineCoreClient):
"""
MPClient: base client for multi-proc EngineCore.
EngineCore runs in a background process busy loop, getting
new EngineCoreRequests and returning EngineCoreOutputs
* pushes EngineCoreRequests via input_socket
* pulls EngineCoreOutputs via output_socket
* AsyncMPClient subclass for AsyncLLM usage
* SyncMPClient subclass for LLM usage
"""
def __init__(
self,
asyncio_mode: bool,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: dict[str, str] | None = None,
):
self.vllm_config = vllm_config
# ZMQ setup.
sync_ctx = zmq.Context(io_threads=2)
self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx
# This will ensure resources created so far are closed
# when the client is garbage collected, even if an
# exception is raised mid-construction.
self.resources = BackgroundResources(ctx=sync_ctx)
self._finalizer = weakref.finalize(self, self.resources)
success = False
try:
# State used for data parallel.
self.engines_running = False
parallel_config = vllm_config.parallel_config
# Elastic EP can remove a rank and later add it back with the same
# identity. The client input ROUTER needs handover to allow the new
# engine to replace the dead connection.
enable_input_socket_handover = parallel_config.enable_elastic_ep
self.stats_update_address: str | None = None
tensor_queue: Queue | None = None
if client_addresses:
# Engines are managed externally to this client.
input_address = client_addresses["input_address"]
output_address = client_addresses["output_address"]
self.stats_update_address = client_addresses.get("stats_update_address")
                # Tensor queue passed via client_addresses for the
                # multi-API-server case.
tensor_queue = client_addresses.get("tensor_queue") # type: ignore[assignment]
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx,
input_address,
zmq.ROUTER,
bind=True,
router_handover=enable_input_socket_handover,
)
self.resources.output_socket = make_zmq_socket(
self.ctx, output_address, zmq.PULL
)
else:
# Engines are managed by this client.
addresses = get_engine_zmq_addresses(vllm_config)
self.input_socket = self.resources.input_socket = make_zmq_socket(
self.ctx,
addresses.inputs[0],
zmq.ROUTER,
bind=True,
router_handover=enable_input_socket_handover,
)
self.resources.output_socket = make_zmq_socket(
self.ctx, addresses.outputs[0], zmq.PULL
)
with launch_core_engines(
vllm_config, executor_class, log_stats, addresses
) as (engine_manager, coordinator, addresses, tensor_queue):
self.resources.coordinator = coordinator
self.resources.engine_manager = engine_manager
self.stats_update_address = addresses.frontend_stats_publish_address
if coordinator is not None:
assert self.stats_update_address == (
coordinator.get_stats_publish_address()
)
# Serialization setup with tensor queues for multimodal tensor IPC.
tensor_ipc_sender: TensorIpcSender | None = None
model_config = getattr(vllm_config, "model_config", None)
if model_config is not None and model_config.multimodal_config is not None:
mm_tensor_ipc = model_config.multimodal_config.mm_tensor_ipc
if mm_tensor_ipc == "torch_shm" and tensor_queue is not None:
tensor_ipc_sender = TensorIpcSender(tensor_queue)
self.encoder = MsgpackEncoder(oob_tensor_consumer=tensor_ipc_sender)
self.decoder = MsgpackDecoder(EngineCoreOutputs)
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_index
dp_local_size = parallel_config.data_parallel_size_local
offline_mode = parallel_config.data_parallel_rank_local is not None
# Client manages local+remote EngineCores in pure internal LB case.
# Client manages local EngineCores in hybrid and external LB case.
num_ranks = dp_local_size if parallel_config.local_engines_only else dp_size
self.engine_ranks_managed = (
[dp_rank] if offline_mode else list(range(dp_rank, dp_rank + num_ranks))
)
assert parallel_config.data_parallel_size_local <= len(
self.engine_ranks_managed
)
# ZMQ identity of each engine that this client will talk to.
self.core_engines: list[EngineIdentity] = [
rank.to_bytes(2, "little") for rank in self.engine_ranks_managed
]
# Wait for ready messages from each engine on the input socket.
identities = set(self.core_engines)
sync_input_socket = zmq.Socket.shadow(self.input_socket)
while identities:
if not sync_input_socket.poll(
timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000 # convert to ms
):
raise TimeoutError(
f"Timed out waiting for engine core processes to "
f"start. This is often caused by slow weight loading "
f"for large models. Waited "
f"{VLLM_ENGINE_READY_TIMEOUT_S}s (configured by "
f"VLLM_ENGINE_READY_TIMEOUT_S). To increase the "
f"timeout, set the environment variable: "
f"VLLM_ENGINE_READY_TIMEOUT_S=<seconds>"
)
identity, _ = sync_input_socket.recv_multipart()
identities.remove(identity)
self.core_engine: EngineIdentity = self.core_engines[0]
self.utility_results: dict[int, AnyFuture] = {}
# Request objects which may contain pytorch-allocated tensors
# that we need to keep references to until zmq is done with the
# underlying data.
self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]()
# Start monitoring engine core processes for unexpected failures
self.start_engine_core_monitor()
success = True
finally:
if not success:
self._finalizer()
def shutdown(self, timeout: float | None = None) -> None:
"""Shutdown engine manager under timeout and clean up resources."""
if self._finalizer.detach() is not None:
if self.resources.engine_manager is not None:
self.resources.engine_manager.shutdown(timeout=timeout)
self.resources()
def _format_exception(self, e: Exception) -> Exception:
"""If errored, use EngineDeadError so root cause is clear."""
return (
EngineDeadError(suppress_context=True) if self.resources.engine_dead else e
)
def ensure_alive(self):
if self.resources.engine_dead:
raise EngineDeadError()
def add_pending_message(self, tracker: zmq.MessageTracker, msg: Any):
if not tracker.done:
self.pending_messages.appendleft((tracker, msg))
def free_pending_messages(self):
while self.pending_messages and self.pending_messages[-1][0].done:
self.pending_messages.pop()
def dp_engines_running(self) -> bool:
return self.engines_running
def start_engine_core_monitor(self):
"""Start a monitor thread for engine core processes."""
engine_manager = self.resources.engine_manager
if engine_manager is None:
# No engine processes to monitor
return
self_ref = weakref.ref(self)
        # Monitor engine core process liveness. If any die unexpectedly,
        # mark the engine as dead and shut down the client.
def monitor_engine_cores():
engine_manager.monitor_engine_liveness()
_self = self_ref()
if not _self or not _self._finalizer.alive or _self.resources.engine_dead:
return
_self.resources.engine_dead = True
_self.shutdown()
        # Note: MPClient has no failure-callback mechanism like
        # MultiprocExecutor's, but setting the engine_dead flag causes
        # subsequent operations to raise EngineDeadError.
Thread(
target=monitor_engine_cores, daemon=True, name="MPClientEngineMonitor"
).start()
def _process_utility_output(
output: UtilityOutput, utility_results: dict[int, AnyFuture]
):
"""Set the result from a utility method in the waiting future."""
future = utility_results.pop(output.call_id)
failure_message = output.failure_message
try:
if failure_message is not None:
future.set_exception(Exception(failure_message))
else:
assert output.result is not None
future.set_result(output.result.result)
except asyncio.InvalidStateError:
# This can happen if the future is cancelled due to the
# original calling task being cancelled.
if failure_message is not None:
logger.error(
"Cancelled call to utility method failed with error: %s",
failure_message,
)
class SyncMPClient(MPClient):
"""Synchronous client for multi-proc EngineCore."""
@instrument(span_name="SyncMPClient init")
def __init__(
self, vllm_config: VllmConfig, executor_class: type[Executor], log_stats: bool
):
super().__init__(
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
)
self.is_dp = self.vllm_config.parallel_config.data_parallel_size > 1
self.outputs_queue = queue.Queue[EngineCoreOutputs | Exception]()
        # Ensure that the outputs socket processing thread does not hold
        # a ref to the client, which would prevent gc.
ctx = self.ctx
out_socket = self.resources.output_socket
decoder = self.decoder
utility_results = self.utility_results
outputs_queue = self.outputs_queue
shutdown_path = get_open_zmq_inproc_path()
resources = self.resources
resources.shutdown_path = shutdown_path
def process_outputs_socket():
assert isinstance(out_socket, zmq.Socket)
shutdown_socket = ctx.socket(zmq.PAIR)
try:
shutdown_socket.bind(shutdown_path)
poller = zmq.Poller()
poller.register(shutdown_socket, zmq.POLLIN)
poller.register(out_socket, zmq.POLLIN)
while True:
socks = poller.poll()
if not socks:
continue
if len(socks) == 2 or socks[0][0] == shutdown_socket:
# shutdown signal, exit thread.
break
frames = out_socket.recv_multipart(copy=False)
resources.validate_alive(frames)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
_process_utility_output(outputs.utility_output, utility_results)
else:
outputs_queue.put_nowait(outputs)
except Exception as e:
outputs_queue.put_nowait(e)
finally:
# Close sockets.
shutdown_socket.close(linger=0)
out_socket.close(linger=0)
# Process outputs from engine in separate thread.
self.output_queue_thread = Thread(
target=process_outputs_socket,
name="EngineCoreOutputQueueThread",
daemon=True,
)
self.output_queue_thread.start()
# The thread takes on responsibility for closing the socket.
self.resources.output_socket = None
def get_output(self) -> EngineCoreOutputs:
        # If an exception arises in the process_outputs_socket thread,
        # it is forwarded to the outputs_queue so we can raise it
        # here to shut down the server.
outputs = self.outputs_queue.get()
if isinstance(outputs, Exception):
raise self._format_exception(outputs) from None
if outputs.wave_complete is not None:
self.engines_running = False
return outputs
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
self.ensure_alive()
self.free_pending_messages()
# (Identity, RequestType, SerializedRequest)
msg = (self.core_engine, request_type.value, *self.encoder.encode(request))
if len(msg) <= 3:
# No auxiliary buffers => no tensor backing buffers in request.
self.input_socket.send_multipart(msg, copy=False)
return
tracker = self.input_socket.send_multipart(msg, copy=False, track=True)
self.add_pending_message(tracker, request)
def call_utility(self, method: str, *args) -> Any:
call_id = uuid.uuid1().int >> 64
future: Future[Any] = Future()
self.utility_results[call_id] = future
self._send_input(EngineCoreRequestType.UTILITY, (0, call_id, method, args))
return future.result()
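    # Sketch of the utility RPC round trip implemented above: the tuple
    # (0, call_id, method, args) is sent as a UTILITY request; the engine
    # later returns an EngineCoreOutputs whose utility_output carries the
    # same call_id, and _process_utility_output() resolves the Future
    # stored in utility_results. Most control-plane methods below are
    # thin wrappers, e.g.:
    #
    #   def is_sleeping(self) -> bool:
    #       return self.call_utility("is_sleeping")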
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.call_utility("get_supported_tasks")
def add_request(self, request: EngineCoreRequest) -> None:
if self.is_dp:
self.engines_running = True
self._send_input(EngineCoreRequestType.ADD, request)
def abort_requests(self, request_ids: list[str]) -> None:
if request_ids and not self.resources.engine_dead:
self._send_input(EngineCoreRequestType.ABORT, request_ids)
def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None:
self.call_utility("profile", is_start, profile_prefix)
def reset_mm_cache(self) -> None:
self.call_utility("reset_mm_cache")
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.call_utility(
"reset_prefix_cache", reset_running_requests, reset_connector
)
def reset_encoder_cache(self) -> None:
self.call_utility("reset_encoder_cache")
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.call_utility("add_lora", lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.call_utility("remove_lora", lora_id)
def list_loras(self) -> set[int]:
return self.call_utility("list_loras")
def pin_lora(self, lora_id: int) -> bool:
return self.call_utility("pin_lora", lora_id)
def sleep(self, level: int = 1, mode: PauseMode = "abort") -> None:
self.call_utility("sleep", level, mode)
def wake_up(self, tags: list[str] | None = None) -> None:
self.call_utility("wake_up", tags)
def is_sleeping(self) -> bool:
return self.call_utility("is_sleeping")
def execute_dummy_batch(self) -> None:
self.call_utility("execute_dummy_batch")
def collective_rpc(
self,
method: str | Callable[..., _R],
timeout: float | None = None,
args: tuple = (),
kwargs: dict[str, Any] | None = None,
) -> list[_R]:
return self.call_utility("collective_rpc", method, timeout, args, kwargs)
def save_sharded_state(
self, path: str, pattern: str | None = None, max_size: int | None = None
) -> None:
self.call_utility("save_sharded_state", path, pattern, max_size)
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
@instrument(span_name="AsyncMPClient init")
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: dict[str, str] | None = None,
client_count: int = 1,
client_index: int = 0,
):
super().__init__(
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=log_stats,
client_addresses=client_addresses,
)
self.client_count = client_count
self.client_index = client_index
self.outputs_queue = asyncio.Queue[EngineCoreOutputs | Exception]()
try:
# If we are running in an asyncio event loop, start the queue task.
# Otherwise, it will be started lazily. If it is not started here,
# we could miss EXECUTOR_FAILED messages from engine core if they
# occur prior to any requests being sent.
asyncio.get_running_loop()
self._ensure_output_queue_task()
except RuntimeError:
pass
def _ensure_output_queue_task(self):
resources = self.resources
if resources.output_queue_task is not None:
return
# Perform IO in separate task to parallelize as much as possible.
# Avoid task having direct reference back to the client.
decoder = self.decoder
utility_results = self.utility_results
outputs_queue = self.outputs_queue
output_handler: (
Callable[[AsyncMPClient, EngineCoreOutputs], Awaitable[None]] | None
) = getattr(self.__class__, "process_engine_outputs", None)
_self_ref = weakref.ref(self) if output_handler else None
output_socket = resources.output_socket
assert output_socket is not None
notification_callback_handler: (
Callable[[AsyncMPClient, Sequence[Any]], Any] | None
) = getattr(self.__class__, "eep_process_engine_core_notification", None)
async def process_outputs_socket():
try:
while True:
frames = await output_socket.recv_multipart(copy=False)
resources.validate_alive(frames)
outputs: EngineCoreOutputs = decoder.decode(frames)
if outputs.utility_output:
if (
outputs.utility_output.call_id == EEP_NOTIFICATION_CALL_ID
and notification_callback_handler is not None
):
assert _self_ref is not None
_self = _self_ref()
if not _self:
return
if outputs.utility_output.result is None:
continue
notification_data = outputs.utility_output.result.result
assert isinstance(notification_data, Sequence)
assert len(notification_data) == 2
asyncio.create_task(
notification_callback_handler(_self, notification_data)
)
else:
_process_utility_output(
outputs.utility_output, utility_results
)
continue
if output_handler is not None:
assert _self_ref is not None
_self = _self_ref()
if not _self:
# Client has been garbage collected, abort.
return
await output_handler(_self, outputs)
if outputs.outputs or outputs.scheduler_stats:
outputs_queue.put_nowait(outputs)
except Exception as e:
outputs_queue.put_nowait(e)
except asyncio.CancelledError:
outputs_queue.put_nowait(EngineDeadError())
resources.output_queue_task = asyncio.create_task(
process_outputs_socket(), name="EngineCoreOutputQueueTask"
)
async def get_output_async(self) -> EngineCoreOutputs:
self._ensure_output_queue_task()
# If an exception arises in process_outputs_socket task,
# it is forwarded to the outputs_queue so we can raise it
# from this (run_output_handler) task to shut down the server.
assert self.outputs_queue is not None
outputs = await self.outputs_queue.get()
if isinstance(outputs, Exception):
raise self._format_exception(outputs) from None
return outputs
def _send_input(
self,
request_type: EngineCoreRequestType,
request: Any,
engine: EngineIdentity | None = None,
) -> Awaitable[Any]:
if engine is None:
engine = self.core_engine
message = (request_type.value, *self.encoder.encode(request))
return self._send_input_message(message, engine, request)
def _send_input_message(
self, message: tuple[bytestr, ...], engine: EngineIdentity, objects: Any
) -> Awaitable[Any]:
"""
objects is a reference to retain until zmq is finished with the
buffers, in case they were extracted from tensors in the request.
"""
self.ensure_alive()
self.free_pending_messages()
msg = (engine,) + message
if not objects or len(msg) <= 3:
# No auxiliary buffers => no tensor backing buffers in request.
return self.input_socket.send_multipart(msg, copy=False)
future: asyncio.Future[zmq.MessageTracker]
future = self.input_socket.send_multipart(msg, copy=False, track=True)
def add_pending(f: asyncio.Future[zmq.MessageTracker]):
with contextlib.suppress(BaseException):
self.add_pending_message(f.result(), objects)
future.add_done_callback(add_pending)
return future
async def call_utility_async(self, method: str, *args) -> Any:
return await self._call_utility_async(method, *args, engine=self.core_engine)
async def _call_utility_async(
self, method: str, *args, engine: EngineIdentity
) -> Any:
call_id = uuid.uuid1().int >> 64
future = asyncio.get_running_loop().create_future()
self.utility_results[call_id] = future
message = (
EngineCoreRequestType.UTILITY.value,
*self.encoder.encode((self.client_index, call_id, method, args)),
)
await self._send_input_message(message, engine, args)
self._ensure_output_queue_task()
return await future
async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
return await self.call_utility_async("get_supported_tasks")
async def add_request_async(self, request: EngineCoreRequest) -> None:
request.client_index = self.client_index
await self._send_input(EngineCoreRequestType.ADD, request)
self._ensure_output_queue_task()
async def abort_requests_async(self, request_ids: list[str]) -> None:
if request_ids and not self.resources.engine_dead:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
async def pause_scheduler_async(
self, mode: PauseMode = "abort", clear_cache: bool = True
) -> None:
await self.call_utility_async("pause_scheduler", mode, clear_cache)
async def resume_scheduler_async(self) -> None:
await self.call_utility_async("resume_scheduler")
async def is_scheduler_paused_async(self) -> bool:
return await self.call_utility_async("is_scheduler_paused")
async def profile_async(
self, is_start: bool = True, profile_prefix: str | None = None
) -> None:
await self.call_utility_async("profile", is_start, profile_prefix)
async def reset_mm_cache_async(self) -> None:
await self.call_utility_async("reset_mm_cache")
async def reset_prefix_cache_async(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return await self.call_utility_async(
"reset_prefix_cache", reset_running_requests, reset_connector
)
async def reset_encoder_cache_async(self) -> None:
await self.call_utility_async("reset_encoder_cache")
async def sleep_async(self, level: int = 1, mode: PauseMode = "abort") -> None:
await self.call_utility_async("sleep", level, mode)
async def wake_up_async(self, tags: list[str] | None = None) -> None:
await self.call_utility_async("wake_up", tags)
async def is_sleeping_async(self) -> bool:
return await self.call_utility_async("is_sleeping")
async def execute_dummy_batch_async(self) -> None:
await self.call_utility_async("execute_dummy_batch")
async def add_lora_async(self, lora_request: LoRARequest) -> bool:
return await self.call_utility_async("add_lora", lora_request)
async def remove_lora_async(self, lora_id: int) -> bool:
return await self.call_utility_async("remove_lora", lora_id)
async def list_loras_async(self) -> set[int]:
return await self.call_utility_async("list_loras")
async def pin_lora_async(self, lora_id: int) -> bool:
return await self.call_utility_async("pin_lora", lora_id)
async def save_sharded_state_async(
self, path: str, pattern: str | None = None, max_size: int | None = None
) -> None:
await self.call_utility_async("save_sharded_state", path, pattern, max_size)
async def collective_rpc_async(
self,
method: str | Callable[..., _R],
timeout: float | None = None,
args: tuple = (),
kwargs: dict[str, Any] | None = None,
) -> list[_R]:
return await self.call_utility_async(
"collective_rpc", method, timeout, args, kwargs
)
class DPAsyncMPClient(AsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore. Assumes external load-balancing by default."""
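    # Message sketch (mirrors add_request_async below): when the engines
    # are idle, the client nudges the coordinator before a request lands,
    # so the other DP ranks can be woken up:
    #
    #   req_msg = msgspec.msgpack.encode(("FIRST_REQ", chosen_engine))
    #   await self.first_req_send_socket.send(req_msg)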
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: dict[str, str] | None = None,
client_count: int = 1,
client_index: int = 0,
):
self.current_wave = 0
super().__init__(
vllm_config,
executor_class,
log_stats,
client_addresses,
client_count,
client_index,
)
        # List of [waiting, running] pairs, one per engine.
        # Used only by the DPLBAsyncMPClient subclass.
self.lb_engines: list[list[int]] = [[0, 0] for _ in self.core_engines]
self.eep_scaling_cache: ElasticScalingCache | None = None
self.first_req_sock_addr = get_open_zmq_inproc_path()
self.first_req_send_socket = self.resources.first_req_send_socket = (
make_zmq_socket(self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=True)
)
try:
# If we are running in an asyncio event loop, start the stats task.
# Otherwise, it will be started lazily.
asyncio.get_running_loop()
self._ensure_stats_update_task()
except RuntimeError:
pass
def _ensure_stats_update_task(self):
resources = self.resources
if resources.stats_update_task is not None:
return
assert self.stats_update_address is not None
stats_addr: str = self.stats_update_address
assert len(self.engine_ranks_managed) > 0
async def run_engine_stats_update_task():
with (
make_zmq_socket(self.ctx, stats_addr, zmq.XSUB, linger=0) as socket,
make_zmq_socket(
self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0
) as first_req_rcv_socket,
):
assert isinstance(socket, zmq.asyncio.Socket)
assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket)
self.resources.stats_update_socket = socket
self.resources.first_req_rcv_socket = first_req_rcv_socket
# Send subscription message.
await socket.send(b"\x01")
poller = zmq.asyncio.Poller()
poller.register(socket, zmq.POLLIN)
poller.register(first_req_rcv_socket, zmq.POLLIN)
while True:
events = await poller.poll()
                    if (not self.engines_running and len(events) == 2) or (
                        events[0][0] == first_req_rcv_socket
                    ):
                        # Check whether this is a regular first-request
                        # notification or a scale-up notification.
buf = first_req_rcv_socket.recv(flags=zmq.NOBLOCK).result()
decoded = msgspec.msgpack.decode(buf)
if (
isinstance(decoded, (list, tuple))
and len(decoded) == 2
and decoded[0] == "SCALE_ELASTIC_EP"
):
# Extract new engine count from the decoded message
new_engine_count = decoded[1]
# Update engine_ranks_managed and count_slice
parallel_config = self.vllm_config.parallel_config
dp_size = parallel_config.data_parallel_size
dp_rank = parallel_config.data_parallel_rank
assert dp_rank == 0
assert dp_size == new_engine_count
assert not (
parallel_config.data_parallel_hybrid_lb
or parallel_config.data_parallel_external_lb
)
num_ranks = dp_size
self.engine_ranks_managed = list(
range(dp_rank, dp_rank + num_ranks)
)
if len(self.lb_engines) < new_engine_count:
self.lb_engines = self.lb_engines + [
[0, 0]
for _ in range(
new_engine_count - len(self.lb_engines)
)
]
else:
self.lb_engines = self.lb_engines[:new_engine_count]
# Send scale up notification to coordinator
scale_msg = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_engine_count)
)
await socket.send(scale_msg)
continue
                        # A request is being sent while the engines are
                        # paused; notify the coordinator so it can wake
                        # the others up (to run the dummy EP loop).
                        assert decoded[0] == "FIRST_REQ"
target_eng_index = decoded[1]
self.engines_running = True
msg = msgspec.msgpack.encode(
(target_eng_index, self.current_wave)
)
await socket.send(msg)
buf = None
while True:
# Drain all stats events (we only care about latest).
future: asyncio.Future[bytes] = socket.recv(flags=zmq.NOBLOCK)
if isinstance(future.exception(), zmq.Again):
break
buf = future.result()
if buf is None:
continue
# Update local load-balancing state.
counts, wave, running = msgspec.msgpack.decode(buf)
self.current_wave = wave
self.engines_running = running
if counts is not None:
# Running and waiting counts are global from the
# Coordinator including all EngineCores. Slice to get
# just the cores managed by this client.
ranks = self.engine_ranks_managed
count_slice = slice(ranks[0], ranks[-1] + 1)
sliced_counts = counts[count_slice]
self.lb_engines = sliced_counts
logger.debug(
"Received counts: %s (%s)", sliced_counts, count_slice
)
resources.stats_update_task = asyncio.create_task(
run_engine_stats_update_task()
)
async def add_request_async(self, request: EngineCoreRequest) -> None:
self._ensure_stats_update_task()
request.current_wave = self.current_wave
request.client_index = self.client_index
chosen_engine = self.get_core_engine_for_request(request)
to_await = self._send_input(EngineCoreRequestType.ADD, request, chosen_engine)
if not self.engines_running:
# Notify coordinator that we're sending a request
req_msg = msgspec.msgpack.encode(("FIRST_REQ", chosen_engine))
await self.first_req_send_socket.send(req_msg)
await to_await
self._ensure_output_queue_task()
def get_core_engine_for_request(self, request: EngineCoreRequest):
return self.core_engine
class DPLBAsyncMPClient(DPAsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
EngineCore. Load-balances between multiple engine processes."""
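    # Routing sketch: each request is pinned to the engine chosen by
    # get_core_engine_for_request() and remembered in reqs_in_flight so
    # that a later abort can be routed to the same engine identity:
    #
    #   if engine := self.reqs_in_flight.get(request_id):
    #       await self._abort_requests([request_id], engine)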
def __init__(
self,
vllm_config: VllmConfig,
executor_class: type[Executor],
log_stats: bool,
client_addresses: dict[str, str] | None = None,
client_count: int = 1,
client_index: int = 0,
):
self.client_count = client_count
# To route aborts to the correct engine.
self.reqs_in_flight: dict[str, EngineIdentity] = {}
super().__init__(
vllm_config,
executor_class,
log_stats,
client_addresses,
client_count,
client_index,
)
assert len(self.core_engines) > 1
self.eng_start_index = (
len(self.core_engines) * self.client_index
) // client_count
def get_core_engine_for_request(self, request: EngineCoreRequest) -> EngineIdentity:
# Engines are in rank order.
if (eng_index := request.data_parallel_rank) is None and (
eng_index := get_late_interaction_engine_index(
request.pooling_params, len(self.core_engines)
)
) is None:
current_counts = self.lb_engines
# TODO use P2C alg for larger DP sizes
num_engines = len(current_counts)
min_score = sys.maxsize
eng_index = 0
for i in range(num_engines):
# Start from client_index to help with balancing when engines
# are empty.
idx = (self.eng_start_index + i) % num_engines
waiting, running = current_counts[idx]
score = waiting * 4 + running
if score < min_score:
min_score = score
eng_index = idx
# Increment local waiting count for better balancing between stats
# updates from the coordinator (which happen every 100ms).
current_counts[eng_index][0] += self.client_count
chosen_engine = self.core_engines[eng_index]
# Record which engine is chosen for this request, to handle aborts.
self.reqs_in_flight[request.request_id] = chosen_engine
return chosen_engine
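    # Worked example of the scoring above (illustrative counts): with
    # per-engine [waiting, running] of [2, 8], [0, 10], and [1, 3], the
    # scores are 2*4+8=16, 0*4+10=10, and 1*4+3=7, so the third engine is
    # chosen. Waiting requests are weighted 4x as a heuristic penalty for
    # queue backlog; the scan starts at eng_start_index so that multiple
    # clients spread load when all engines are empty.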
async def call_utility_async(self, method: str, *args) -> Any:
# Only the result from the first engine is returned.
return (
await asyncio.gather(
*[
self._call_utility_async(method, *args, engine=engine)
for engine in self.core_engines
]
)
)[0]
@staticmethod
async def process_engine_outputs(
self: "DPLBAsyncMPClient", outputs: EngineCoreOutputs
):
if outputs.finished_requests and self.reqs_in_flight:
for req_id in outputs.finished_requests:
self.reqs_in_flight.pop(req_id, None)
@staticmethod
async def eep_process_engine_core_notification(
self: "DPLBAsyncMPClient", notification_data: tuple[str, int]
):
cache = self.eep_scaling_cache
notification_type_str, dp_rank = notification_data
try:
notification_type = EEPNotificationType(notification_type_str)
except ValueError as e:
raise ValueError(
f"Unknown EEP notification type: {notification_type_str}"
) from e
if notification_type == EEPNotificationType.RECONFIGURE_FINISHED:
from vllm.v1.engine import UtilityResult
# NOTE(yongji): process a dummy UtilityOutput to resolve the future
# awaited in _eep_wait_for_setup_switch_complete(), signaling that
# all engine cores have completed reconfiguration.
dummy_output = UtilityOutput(
call_id=EEP_NOTIFICATION_CALL_ID, result=UtilityResult(None)
)
_process_utility_output(dummy_output, self.utility_results)
return
assert cache is not None
if notification_type not in cache.pending_notifications:
cache.pending_notifications[notification_type] = set()
if dp_rank in cache.pending_notifications[notification_type]:
raise ValueError(
f"Duplicate notification {notification_type} from dp_rank {dp_rank}"
)
cache.pending_notifications[notification_type].add(dp_rank)
if len(cache.pending_notifications[notification_type]) >= abs(
cache.num_new_core_engines
):
if notification_type == EEPNotificationType.SHUTDOWN_COMPLETE:
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
assert cache.num_new_core_engines < 0
old_dp_size = len(cache.existing_core_engines)
new_dp_size = old_dp_size + cache.num_new_core_engines
self.resources.engine_manager.scale_down_elastic_ep(
old_dp_size, new_dp_size
)
else:
await asyncio.gather(
*[
self._call_utility_async(
"eep_handle_engine_core_notification",
notification_type,
engine=engine,
)
for engine in cache.existing_core_engines
]
)
cache.pending_notifications[notification_type] = set()
if notification_type in [
EEPNotificationType.SHUTDOWN_COMPLETE,
EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY,
]:
self.eep_scaling_cache = None
async def abort_requests_async(self, request_ids: list[str]) -> None:
if not request_ids or self.resources.engine_dead:
return
if len(request_ids) == 1:
# Fast-path common case.
if engine := self.reqs_in_flight.get(request_ids[0]):
await self._abort_requests(request_ids, engine)
return
by_engine = defaultdict[EngineIdentity, list[str]](list)
for req_id in request_ids:
if engine := self.reqs_in_flight.get(req_id):
by_engine[engine].append(req_id)
for engine, req_ids in by_engine.items():
await self._abort_requests(req_ids, engine)
async def _abort_requests(
self, request_ids: list[str], engine: EngineIdentity
) -> None:
await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine)
async def scale_elastic_ep(self, new_data_parallel_size: int) -> None:
"""Scale elastic EP data parallel size"""
cur_data_parallel_size = len(self.core_engines)
assert new_data_parallel_size != cur_data_parallel_size, (
f"new_data_parallel_size {new_data_parallel_size} must be "
f"different from cur_data_parallel_size {cur_data_parallel_size}"
)
assert self.vllm_config.parallel_config.data_parallel_backend == "ray", (
"Only ray DP backend supports scaling elastic EP"
)
scale_up = new_data_parallel_size > cur_data_parallel_size
if scale_up:
await self._scale_up_elastic_ep(
cur_data_parallel_size, new_data_parallel_size
)
else:
await self._scale_down_elastic_ep(
cur_data_parallel_size, new_data_parallel_size
)
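    # Hedged usage sketch (ray DP backend only, per the assertion above):
    #
    #   await client.scale_elastic_ep(new_data_parallel_size=4)
    #
    # Scaling up reconfigures existing engines and launches new ones;
    # scaling down sends SHUTDOWN_CURRENT_RANK to the ranks being removed.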
async def _eep_wait_for_setup_switch_complete(self) -> None:
"""
Wait for core engines to switch to the new setup.
In eep_process_engine_core_notification(), a dummy UtilityOutput with
EEP_NOTIFICATION_CALL_ID will be set when RECONFIGURE_FINISHED
notification is received from engine 0. We create a future with
that call_id and wait for it to be resolved.
"""
future = asyncio.get_running_loop().create_future()
self.utility_results[EEP_NOTIFICATION_CALL_ID] = future
self._ensure_output_queue_task()
await future
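    # Handshake sketch: eep_process_engine_core_notification() resolves
    # the future created above by pushing a dummy UtilityOutput through
    # the normal utility path:
    #
    #   dummy = UtilityOutput(
    #       call_id=EEP_NOTIFICATION_CALL_ID, result=UtilityResult(None)
    #   )
    #   _process_utility_output(dummy, self.utility_results)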
def _setup_elastic_ep_reconfig_bootstrap(self) -> tuple[str, int]:
from vllm.distributed.utils import create_tcp_store
from vllm.utils.network_utils import get_open_ports_list
parallel_config = self.vllm_config.parallel_config
parallel_config._data_parallel_master_port_list = get_open_ports_list(5)
parallel_config.data_parallel_master_port = (
parallel_config._data_parallel_master_port_list.pop()
)
ip = parallel_config.data_parallel_master_ip
store = create_tcp_store(
ip,
0,
is_master=True,
world_size=-1,
wait_for_workers=False,
)
parallel_config._coord_store_port = store.port
self._coord_store = store
return ip, store.port
async def _scale_up_elastic_ep(
self, cur_data_parallel_size: int, new_data_parallel_size: int
) -> None:
"""Scale up the data parallel size by creating new engine cores
and reconfiguring existing ones."""
cur_data_parallel_size = len(self.core_engines)
self.eep_scaling_cache = ElasticScalingCache(
existing_core_engines=self.core_engines.copy(),
num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
pending_notifications=dict(),
)
parallel_config = self.vllm_config.parallel_config
ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap()
# Phase 1: Send reconfig messages to existing engines
reconfig_futures = []
for engine in self.core_engines:
reconfig_request = ReconfigureDistributedRequest(
new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=ip,
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
coord_store_port=coord_store_port,
)
coro = self._call_utility_async(
"reinitialize_distributed", reconfig_request, engine=engine
)
reconfig_futures.append(asyncio.create_task(coro))
# Phase 2: Create new engines
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
parallel_config.eplb_config.num_redundant_experts = 0
start_new_worker_future = asyncio.to_thread(
self.resources.engine_manager.scale_up_elastic_ep,
self.vllm_config,
new_data_parallel_size,
)
wait_future = self._eep_wait_for_setup_switch_complete()
# Phase 3: Wait for new engines to be created
# and reconfig messages to be received
await asyncio.gather(start_new_worker_future, *reconfig_futures)
logger.info("[Elastic EP] Successfully started new engines")
# Create new CoreEngine objects for the new engines
new_engine_identities = set()
for i in range(cur_data_parallel_size, new_data_parallel_size):
new_engine = i.to_bytes(2, "little")
self.core_engines.append(new_engine)
            # NOTE(yongji): we don't update lb_engines here; we let
            # run_engine_stats_update_task update it.
new_engine_identities.add(new_engine)
# Wait for ready messages from new engines on the input socket
sync_input_socket = zmq.Socket.shadow(self.input_socket)
while new_engine_identities:
if not sync_input_socket.poll(
timeout=VLLM_ENGINE_READY_TIMEOUT_S * 1000 # convert to ms
):
raise TimeoutError(
f"Timed out waiting for new engine core processes to "
f"start. Waited "
f"{VLLM_ENGINE_READY_TIMEOUT_S}s (configured by "
f"VLLM_ENGINE_READY_TIMEOUT_S). To increase the "
f"timeout, set the environment variable: "
f"VLLM_ENGINE_READY_TIMEOUT_S=<seconds>"
)
identity, _ = sync_input_socket.recv_multipart()
new_engine_identities.discard(identity)
# NOTE(yongji): Before we schedule any requests on the new workers,
# we should wait for them to switch to the new setup.
await wait_future
# Update the parallel config
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
# Notify coordinator about scale up through existing
# stats_update_task connection
self._ensure_stats_update_task()
scale_up_marker = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_data_parallel_size)
)
await self.first_req_send_socket.send(scale_up_marker)
logger.info(
"[Elastic EP] Scale up completed, new data parallel size: %s",
new_data_parallel_size,
)
async def _scale_down_elastic_ep(
self, cur_data_parallel_size: int, new_data_parallel_size: int
) -> None:
"""Scale down the data parallel size by shutting down and
reconfiguring existing engine cores."""
cur_data_parallel_size = len(self.core_engines)
self.eep_scaling_cache = ElasticScalingCache(
existing_core_engines=self.core_engines.copy(),
num_new_core_engines=new_data_parallel_size - cur_data_parallel_size,
pending_notifications=dict(),
)
parallel_config = self.vllm_config.parallel_config
ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap()
removed_dp_size = cur_data_parallel_size - new_data_parallel_size
assert isinstance(self.resources.engine_manager, CoreEngineActorManager)
self.resources.engine_manager.remove_run_refs_for_scale_down(removed_dp_size)
reconfig_futures = []
for cur_dp_rank, engine in enumerate(self.core_engines):
reconfig_request = ReconfigureDistributedRequest(
new_data_parallel_size=new_data_parallel_size,
new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_rank_local=ReconfigureRankType.KEEP_CURRENT_RANK,
new_data_parallel_master_ip=ip,
new_data_parallel_master_port=parallel_config.data_parallel_master_port,
new_data_parallel_master_port_list=parallel_config._data_parallel_master_port_list,
coord_store_port=coord_store_port,
)
if cur_dp_rank >= new_data_parallel_size:
reconfig_request.new_data_parallel_rank = (
ReconfigureRankType.SHUTDOWN_CURRENT_RANK
)
coro = self._call_utility_async(
"reinitialize_distributed", reconfig_request, engine=engine
)
reconfig_futures.append(asyncio.create_task(coro))
# NOTE(yongji): Immediately stop sending requests to the removing engines.
self.core_engines = self.core_engines[:new_data_parallel_size]
self.lb_engines = self.lb_engines[:new_data_parallel_size]
wait_future = self._eep_wait_for_setup_switch_complete()
await asyncio.gather(*reconfig_futures)
self.vllm_config.parallel_config.data_parallel_size = new_data_parallel_size
self._ensure_stats_update_task()
scale_down_marker = msgspec.msgpack.encode(
("SCALE_ELASTIC_EP", new_data_parallel_size)
)
await self.first_req_send_socket.send(scale_down_marker)
        # NOTE(yongji): Unlike scaling up, we don't strictly need to wait
        # for the setup switch to complete here; this wait may be removed
        # in the future.
await wait_future
logger.info(
"[Elastic EP] Scale down completed, new data parallel size: %s",
new_data_parallel_size,
)