[Perf] API-server scaleout with many-to-many server-engine comms (#17546)
@@ -7,6 +7,7 @@ import threading
 import time
 from collections import deque
 from concurrent.futures import Future
+from contextlib import ExitStack
 from inspect import isclass, signature
 from logging import DEBUG
 from typing import Any, Callable, Optional, TypeVar, Union
@@ -22,7 +23,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
-from vllm.utils import make_zmq_socket, resolve_obj_by_qualname, zmq_socket_ctx
+from vllm.utils import make_zmq_socket, resolve_obj_by_qualname
 from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
                                          unify_kv_cache_configs)
 from vllm.v1.core.sched.interface import SchedulerInterface
@@ -33,10 +34,12 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
 from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.utils import EngineHandshakeMetadata, EngineZmqAddresses
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -211,7 +214,7 @@ class EngineCore:
         # Re-raise exception
         raise err
 
-    def step(self) -> tuple[EngineCoreOutputs, bool]:
+    def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
         """Schedule, execute, and make output.
 
         Returns tuple of outputs and a flag indicating whether the model
@@ -221,10 +224,7 @@ class EngineCore:
         # Check for any requests remaining in the scheduler - unfinished,
         # or finished and not yet removed from the batch.
         if not self.scheduler.has_requests():
-            return EngineCoreOutputs(
-                outputs=[],
-                scheduler_stats=self.scheduler.make_stats(),
-            ), False
+            return {}, False
         scheduler_output = self.scheduler.schedule()
         model_output = self.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
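
Note: step() now returns a mapping from client index to EngineCoreOutputs rather than a single EngineCoreOutputs, so each result batch can be routed back to the API-server client that submitted the requests. A minimal sketch of how a caller might consume the new return type (send_to_client is a hypothetical routing helper, not part of this change):

    # Hypothetical consumer of the new step() signature.
    outputs_by_client, model_executed = engine_core.step()
    for client_idx, outputs in outputs_by_client.items():
        send_to_client(client_idx, outputs)  # hypothetical fan-out helper
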
@@ -234,7 +234,7 @@ class EngineCore:
             scheduler_output.total_num_scheduled_tokens > 0)
 
     def step_with_batch_queue(
-            self) -> tuple[Optional[EngineCoreOutputs], bool]:
+            self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
         """Schedule and execute batches with the batch queue.
         Note that if nothing to output in this step, None is returned.
 
@@ -276,8 +276,8 @@ class EngineCore:
             # Blocking until the first result is available.
             model_output = future.result()
             self.batch_queue.task_done()
-            engine_core_outputs = self.scheduler.update_from_output(
-                scheduler_output, model_output)
+            engine_core_outputs = (self.scheduler.update_from_output(
+                scheduler_output, model_output))
 
         return engine_core_outputs, scheduled_batch
 
@@ -362,7 +362,7 @@ class EngineCoreProc(EngineCore):
         self,
         vllm_config: VllmConfig,
         on_head_node: bool,
-        input_address: str,
+        handshake_address: str,
         executor_class: type[Executor],
         log_stats: bool,
         engine_index: int = 0,
@@ -375,65 +375,70 @@ class EngineCoreProc(EngineCore):
         # Create input socket.
         input_ctx = zmq.Context()
         identity = engine_index.to_bytes(length=2, byteorder="little")
-        input_socket = make_zmq_socket(input_ctx,
-                                       input_address,
-                                       zmq.DEALER,
-                                       identity=identity,
-                                       bind=False)
-        try:
+        with make_zmq_socket(input_ctx,
+                             handshake_address,
+                             zmq.DEALER,
+                             identity=identity,
+                             linger=5000,
+                             bind=False) as handshake_socket:
+
             # Register engine with front-end.
-            output_address = self.startup_handshake(
-                input_socket, on_head_node, vllm_config.parallel_config)
+            addresses = self.startup_handshake(handshake_socket, on_head_node,
+                                               vllm_config.parallel_config)
+            self.client_count = len(addresses.outputs)
 
             # Update config which may have changed from the handshake.
             vllm_config.__post_init__()
 
             # Set up data parallel environment.
+            self.has_coordinator = addresses.coordinator_output is not None
             self._init_data_parallel(vllm_config)
 
             # Initialize engine core and model.
             super().__init__(vllm_config, executor_class, log_stats,
                              executor_fail_callback)
 
+            self.engine_index = engine_index
             self.step_fn = (self.step if self.batch_queue is None else
                             self.step_with_batch_queue)
             self.engines_running = False
+            self.last_counts = (0, 0)
 
             # Send ready message.
             num_gpu_blocks = vllm_config.cache_config.num_gpu_blocks
-            input_socket.send(
+            handshake_socket.send(
                 msgspec.msgpack.encode({
                     "status": "READY",
                     "local": on_head_node,
                     "num_gpu_blocks": num_gpu_blocks,
                 }))
 
-            # Background Threads and Queues for IO. These enable us to
-            # overlap ZMQ socket IO with GPU since they release the GIL,
-            # and to overlap some serialization/deserialization with the
-            # model forward pass.
-            # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
-            self.input_queue = input_queue
-            self.output_queue = queue.Queue[Union[EngineCoreOutputs, bytes]]()
-            threading.Thread(target=self.process_input_socket,
-                             args=(input_socket, ),
-                             daemon=True).start()
-            input_socket = None
-            self.output_thread = threading.Thread(
-                target=self.process_output_socket,
-                args=(output_address, engine_index),
-                daemon=True)
-            self.output_thread.start()
-        finally:
-            if input_socket is not None:
-                input_socket.close(linger=0)
+        # Background Threads and Queues for IO. These enable us to
+        # overlap ZMQ socket IO with GPU since they release the GIL,
+        # and to overlap some serialization/deserialization with the
+        # model forward pass.
+        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
+        self.input_queue = input_queue
+        self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
+                                              bytes]]()
+        threading.Thread(target=self.process_input_sockets,
+                         args=(addresses.inputs, addresses.coordinator_input,
+                               identity),
+                         daemon=True).start()
+        self.output_thread = threading.Thread(
+            target=self.process_output_sockets,
+            args=(addresses.outputs, addresses.coordinator_output,
+                  engine_index),
+            daemon=True)
+        self.output_thread.start()
 
     @staticmethod
-    def startup_handshake(input_socket: zmq.Socket, on_head_node: bool,
-                          parallel_config: ParallelConfig) -> str:
+    def startup_handshake(
+            handshake_socket: zmq.Socket, on_head_node: bool,
+            parallel_config: ParallelConfig) -> EngineZmqAddresses:
 
         # Send registration message.
-        input_socket.send(
+        handshake_socket.send(
             msgspec.msgpack.encode({
                 "status": "HELLO",
                 "local": on_head_node,
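
Note: each engine's DEALER sockets carry a two-byte little-endian identity derived from the engine index, which is what lets a front-end ROUTER socket address a specific engine. A self-contained sketch of the identity-addressing pattern, with an illustrative address rather than the real handshake endpoints:

    import zmq

    ROUTER_ADDR = "tcp://127.0.0.1:5555"  # illustrative only

    ctx = zmq.Context()

    # Front-end side: a ROUTER can target a peer by its identity frame.
    router = ctx.socket(zmq.ROUTER)
    router.bind(ROUTER_ADDR)

    # Engine side: DEALER with a 2-byte little-endian identity, as above.
    engine_index = 3
    dealer = ctx.socket(zmq.DEALER)
    dealer.setsockopt(zmq.IDENTITY, engine_index.to_bytes(2, "little"))
    dealer.connect(ROUTER_ADDR)

    dealer.send(b"")  # initial message teaches the ROUTER this identity
    identity, _ = router.recv_multipart()
    router.send_multipart([identity, b"work for engine 3"])  # targeted send
    print(dealer.recv())  # b"work for engine 3"
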
@@ -441,22 +446,20 @@ class EngineCoreProc(EngineCore):
 
         # Receive initialization message.
         logger.info("Waiting for init message from front-end.")
-        if not input_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60 * 1000):
+        if not handshake_socket.poll(timeout=HANDSHAKE_TIMEOUT_MINS * 60_000):
             raise RuntimeError("Did not receive response from front-end "
                                f"process within {HANDSHAKE_TIMEOUT_MINS} "
                                f"minutes")
-        init_bytes = input_socket.recv()
-        init_message = msgspec.msgpack.decode(init_bytes)
+        init_bytes = handshake_socket.recv()
+        init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
+            init_bytes, type=EngineHandshakeMetadata)
         logger.debug("Received init message: %s", init_message)
 
-        output_socket_address = init_message["output_socket_address"]
-        #TBD(nick) maybe replace IP with configured head node address
-
-        received_parallel_config = init_message["parallel_config"]
+        received_parallel_config = init_message.parallel_config
         for key, value in received_parallel_config.items():
             setattr(parallel_config, key, value)
 
-        return output_socket_address
+        return init_message.addresses
 
     @staticmethod
     def run_engine_core(*args,
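
Note: the init message is now decoded into a typed msgspec struct instead of a raw dict, so the handshake payload is validated on receipt and fields are accessed as attributes. A minimal sketch of the pattern with a stand-in struct (field definitions here are illustrative, not the real EngineHandshakeMetadata):

    import msgspec

    class HandshakeMeta(msgspec.Struct):  # stand-in, illustrative fields
        parallel_config: dict
        addresses: list

    raw = msgspec.msgpack.encode({
        "parallel_config": {"data_parallel_size": 2},
        "addresses": ["tcp://10.0.0.1:5570"],
    })
    # type= validates the payload and returns a typed object, not a dict.
    meta = msgspec.msgpack.decode(raw, type=HandshakeMeta)
    assert meta.parallel_config["data_parallel_size"] == 2
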
@@ -528,7 +531,7 @@ class EngineCoreProc(EngineCore):
         """Exits when an engine step needs to be performed."""
 
         waited = False
-        while not self.engines_running and not (self.scheduler.has_requests()):
+        while not self.engines_running and not self.scheduler.has_requests():
             if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                 logger.debug("EngineCore waiting for work.")
                 waited = True
@@ -549,8 +552,8 @@ class EngineCoreProc(EngineCore):
         # Step the engine core.
         outputs, model_executed = self.step_fn()
         # Put EngineCoreOutputs into the output queue.
-        if outputs is not None:
-            self.output_queue.put_nowait(outputs)
+        for output in (outputs.items() if outputs else ()):
+            self.output_queue.put_nowait(output)
 
         return model_executed
 
@@ -563,7 +566,7 @@ class EngineCoreProc(EngineCore):
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
         elif request_type == EngineCoreRequestType.UTILITY:
-            call_id, method_name, args = request
+            client_idx, call_id, method_name, args = request
             output = UtilityOutput(call_id)
             try:
                 method = getattr(self, method_name)
@@ -574,7 +577,7 @@ class EngineCoreProc(EngineCore):
                 output.failure_message = (f"Call to {method_name} method"
                                           f" failed: {str(e)}")
             self.output_queue.put_nowait(
-                EngineCoreOutputs(utility_output=output))
+                (client_idx, EngineCoreOutputs(utility_output=output)))
         elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
             raise RuntimeError("Executor failed.")
         else:
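
Note: utility calls are dispatched by name with getattr, and the request tuple now carries the originating client's index so the reply can be queued for the right output socket. A toy sketch of the call-by-name dispatch (EngineStub and the tuple values are illustrative):

    # Toy dispatch mirroring the UTILITY request handling above.
    class EngineStub:
        def add_lora(self, name):
            return f"loaded {name}"

    def handle_utility(engine, request):
        client_idx, call_id, method_name, args = request
        method = getattr(engine, method_name)  # resolve method by name
        result = method(*args)
        return client_idx, call_id, result     # reply routed by client_idx

    print(handle_utility(EngineStub(), (0, 7, "add_lora", ("adapter-a",))))
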
@@ -607,27 +610,68 @@ class EngineCoreProc(EngineCore):
             logger.fatal("vLLM shutdown signal from EngineCore failed "
                          "to send. Please report this issue.")
 
-    def process_input_socket(self, input_socket: zmq.Socket):
+    def process_input_sockets(self, input_addresses: list[str],
+                              coord_input_address: Optional[str],
+                              identity: bytes):
         """Input socket IO thread."""
 
         # Msgpack serialization decoding.
         add_request_decoder = MsgpackDecoder(EngineCoreRequest)
         generic_decoder = MsgpackDecoder()
 
-        while True:
-            # (RequestType, RequestData)
-            type_frame, *data_frames = input_socket.recv_multipart(copy=False)
-            request_type = EngineCoreRequestType(bytes(type_frame.buffer))
-
-            # Deserialize the request data.
-            decoder = add_request_decoder if (
-                request_type == EngineCoreRequestType.ADD) else generic_decoder
-            request = decoder.decode(data_frames)
-
-            # Push to input queue for core busy loop.
-            self.input_queue.put_nowait((request_type, request))
-
-    def process_output_socket(self, output_path: str, engine_index: int):
+        with ExitStack() as stack, zmq.Context() as ctx:
+            input_sockets = [
+                stack.enter_context(
+                    make_zmq_socket(ctx,
+                                    input_address,
+                                    zmq.DEALER,
+                                    identity=identity,
+                                    bind=False))
+                for input_address in input_addresses
+            ]
+            if coord_input_address is None:
+                coord_socket = None
+            else:
+                coord_socket = stack.enter_context(
+                    make_zmq_socket(ctx,
+                                    coord_input_address,
+                                    zmq.XSUB,
+                                    identity=identity,
+                                    bind=False))
+                # Send subscription message to coordinator.
+                coord_socket.send(b'\x01')
+
+            # Register sockets with poller.
+            poller = zmq.Poller()
+            for input_socket in input_sockets:
+                # Send initial message to each input socket - this is required
+                # before the front-end ROUTER socket can send input messages
+                # back to us.
+                input_socket.send(b'')
+                poller.register(input_socket, zmq.POLLIN)
+            if coord_socket is not None:
+                poller.register(coord_socket, zmq.POLLIN)
+
+            while True:
+                for input_socket, _ in poller.poll():
+                    # (RequestType, RequestData)
+                    type_frame, *data_frames = input_socket.recv_multipart(
+                        copy=False)
+                    request_type = EngineCoreRequestType(
+                        bytes(type_frame.buffer))
+
+                    # Deserialize the request data.
+                    decoder = add_request_decoder if (
+                        request_type
+                        == EngineCoreRequestType.ADD) else generic_decoder
+                    request = decoder.decode(data_frames)
+
+                    # Push to input queue for core busy loop.
+                    self.input_queue.put_nowait((request_type, request))
+
+    def process_output_sockets(self, output_paths: list[str],
+                               coord_output_path: Optional[str],
+                               engine_index: int):
         """Output socket IO thread."""
 
         # Msgpack serialization encoding.
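
Note: the input thread now multiplexes all per-client input sockets plus the optional coordinator socket through a single zmq.Poller, and the XSUB coordinator socket subscribes by sending a raw b'\x01' frame (the wire-level equivalent of setting an empty ZMQ_SUBSCRIBE option on a SUB socket). A self-contained polling sketch with illustrative sockets and ports:

    import zmq

    # Illustrative: multiplex two receiving sockets with one Poller,
    # as the input thread does for its per-client DEALER sockets.
    ctx = zmq.Context()
    poller = zmq.Poller()
    for port in (5601, 5602):
        sock = ctx.socket(zmq.PULL)
        sock.bind(f"tcp://127.0.0.1:{port}")
        poller.register(sock, zmq.POLLIN)

    # poll() returns only the sockets that currently have data ready.
    for sock, _event in poller.poll(timeout=100):
        frames = sock.recv_multipart(copy=False)
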
@@ -641,30 +685,49 @@ class EngineCoreProc(EngineCore):
 
         # We must set linger to ensure the ENGINE_CORE_DEAD
         # message is sent prior to closing the socket.
-        with zmq_socket_ctx(output_path, zmq.constants.PUSH,
-                            linger=4000) as socket:
+        with ExitStack() as stack, zmq.Context() as ctx:
+            sockets = [
+                stack.enter_context(
+                    make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
+                for output_path in output_paths
+            ]
+            coord_socket = stack.enter_context(
+                make_zmq_socket(
+                    ctx, coord_output_path, zmq.PUSH, bind=False,
+                    linger=4000)) if coord_output_path is not None else None
+            max_reuse_bufs = len(sockets) + 1
 
             while True:
-                outputs = self.output_queue.get()
-                if outputs == EngineCoreProc.ENGINE_CORE_DEAD:
-                    socket.send(outputs, copy=False)
+                output = self.output_queue.get()
+                if output == EngineCoreProc.ENGINE_CORE_DEAD:
+                    for socket in sockets:
+                        socket.send(output)
                     break
-                assert not isinstance(outputs, bytes)
+                assert not isinstance(output, bytes)
+                client_index, outputs = output
                 outputs.engine_index = engine_index
 
+                if client_index == -1:
+                    # Don't reuse buffer for coordinator message
+                    # which will be very small.
+                    assert coord_socket is not None
+                    coord_socket.send_multipart(encoder.encode(outputs))
+                    continue
+
                 # Reclaim buffers that zmq is finished with.
                 while pending and pending[-1][0].done:
                     reuse_buffers.append(pending.pop()[2])
 
                 buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                 buffers = encoder.encode_into(outputs, buffer)
-                tracker = socket.send_multipart(buffers,
-                                                copy=False,
-                                                track=True)
+                tracker = sockets[client_index].send_multipart(buffers,
+                                                               copy=False,
+                                                               track=True)
                 if not tracker.done:
                     ref = outputs if len(buffers) > 1 else None
                     pending.appendleft((tracker, ref, buffer))
-                elif len(reuse_buffers) < 2:
-                    # Keep at most 2 buffers to reuse.
+                elif len(reuse_buffers) < max_reuse_bufs:
+                    # Limit the number of buffers to reuse.
                     reuse_buffers.append(buffer)
 
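
Note: outputs are serialized into reusable bytearrays and sent zero-copy with track=True; a buffer is returned to the pool only once pyzmq's MessageTracker confirms the send has finished, and the pool is now sized to the number of client sockets plus one. A minimal standalone sketch of tracked zero-copy sends (the inproc endpoint is illustrative):

    import zmq

    ctx = zmq.Context()
    push = ctx.socket(zmq.PUSH)
    push.bind("inproc://outputs")  # illustrative endpoint
    pull = ctx.socket(zmq.PULL)
    pull.connect("inproc://outputs")

    reuse_buffers = []
    buffer = reuse_buffers.pop() if reuse_buffers else bytearray(b"payload")

    # copy=False sends the buffer zero-copy; track=True returns a
    # MessageTracker telling us when zmq is done with the memory.
    tracker = push.send(buffer, copy=False, track=True)
    if tracker.done:
        reuse_buffers.append(buffer)  # only reuse once the send completed
    print(pull.recv())
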
@@ -676,7 +739,7 @@ class DPEngineCoreProc(EngineCoreProc):
         self,
         vllm_config: VllmConfig,
         on_head_node: bool,
-        input_address: str,
+        handshake_address: str,
         executor_class: type[Executor],
         log_stats: bool,
     ):
@@ -691,10 +754,11 @@ class DPEngineCoreProc(EngineCoreProc):
         # Counts forward-passes of the model so that we can synchronize
         # finished with DP peers every N steps.
         self.counter = 0
+        self.current_wave = 0
 
         # Initialize the engine.
         dp_rank = vllm_config.parallel_config.data_parallel_rank
-        super().__init__(vllm_config, on_head_node, input_address,
+        super().__init__(vllm_config, on_head_node, handshake_address,
                          executor_class, log_stats, dp_rank)
 
     def _init_data_parallel(self, vllm_config: VllmConfig):
@@ -726,7 +790,6 @@ class DPEngineCoreProc(EngineCoreProc):
 
         self.dp_rank = dp_rank
         self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
-        self.current_wave = 0
 
     def shutdown(self):
         super().shutdown()
@@ -734,22 +797,23 @@ class DPEngineCoreProc(EngineCoreProc):
             stateless_destroy_torch_distributed_process_group(dp_group)
 
     def add_request(self, request: EngineCoreRequest):
-        if request.current_wave != self.current_wave:
+        if self.has_coordinator and request.current_wave != self.current_wave:
             if request.current_wave > self.current_wave:
                 self.current_wave = request.current_wave
             elif not self.engines_running:
                 # Request received for an already-completed wave, notify
                 # front-end that we need to start the next one.
                 self.output_queue.put_nowait(
-                    EngineCoreOutputs(start_wave=self.current_wave))
+                    (-1, EngineCoreOutputs(start_wave=self.current_wave)))
 
         super().add_request(request)
 
     def _handle_client_request(self, request_type: EngineCoreRequestType,
                                request: Any) -> None:
         if request_type == EngineCoreRequestType.START_DP_WAVE:
-            new_wave: int = request
-            if new_wave >= self.current_wave:
+            new_wave, exclude_eng_index = request
+            if exclude_eng_index != self.engine_index and (
+                    new_wave >= self.current_wave):
                 self.current_wave = new_wave
                 if not self.engines_running:
                     logger.debug("EngineCore starting idle loop for wave %d.",
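
Note: everything placed on the output queue is now a (client_index, EngineCoreOutputs) tuple, and client_index -1 is the sentinel meaning "deliver to the DP coordinator socket rather than to a client". A toy illustration of the routing convention (the socket objects are stand-ins):

    # Toy router: -1 targets the coordinator, >= 0 indexes a client socket.
    def route_output(item, client_sockets, coord_socket):
        client_index, outputs = item
        if client_index == -1:
            coord_socket.send(outputs)        # small coordinator message
        else:
            client_sockets[client_index].send(outputs)
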
@@ -758,6 +822,18 @@ class DPEngineCoreProc(EngineCoreProc):
         else:
             super()._handle_client_request(request_type, request)
 
+    def _maybe_publish_request_counts(self):
+        if not self.has_coordinator:
+            return
+
+        # Publish our request counts (if they've changed).
+        counts = self.scheduler.get_request_counts()
+        if counts != self.last_counts:
+            self.last_counts = counts
+            stats = SchedulerStats(*counts)
+            self.output_queue.put_nowait(
+                (-1, EngineCoreOutputs(scheduler_stats=stats)))
+
     def run_busy_loop(self):
         """Core busy loop of the EngineCore for data parallel case."""
 
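
Note: request counts are published to the coordinator only when they change, keeping coordinator traffic proportional to load shifts rather than to step rate. A sketch of the publish-on-change pattern (the publish callable is a stand-in for the output queue):

    # Publish the request-count pair only when it differs from the last
    # published value, as _maybe_publish_request_counts does above.
    class CountsPublisher:
        def __init__(self, publish):
            self.last_counts = (0, 0)
            self.publish = publish  # e.g. queue.put_nowait

        def maybe_publish(self, counts):
            if counts != self.last_counts:
                self.last_counts = counts
                self.publish(counts)
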
@@ -768,6 +844,8 @@ class DPEngineCoreProc(EngineCoreProc):
 
         # 2) Step the engine core.
         executed = self._process_engine_step()
+        self._maybe_publish_request_counts()
+
         local_unfinished_reqs = self.scheduler.has_unfinished_requests()
         if not executed:
             if not local_unfinished_reqs and not self.engines_running:
@@ -788,7 +866,8 @@ class DPEngineCoreProc(EngineCoreProc):
                 logger.debug("Wave %d finished, pausing engine loop.",
                              self.current_wave)
                 self.output_queue.put_nowait(
-                    EngineCoreOutputs(wave_complete=self.current_wave))
+                    (-1,
+                     EngineCoreOutputs(wave_complete=self.current_wave)))
                 self.current_wave += 1
 
     def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: