[P/D] Refactor mooncake connector sender thread using async coroutines (#31573)

Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
dtc
2026-01-12 20:35:35 +08:00
committed by GitHub
parent 9dbe1fe960
commit 0565f1fdec

View File

@@ -3,7 +3,6 @@
import asyncio import asyncio
import threading import threading
import time import time
import uuid
from collections import defaultdict from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
@@ -83,28 +82,10 @@ class RecvReqMeta:
@dataclass @dataclass
class SendBlockMeta: class SendBlockMeta:
local_block_ids: list[int] local_block_ids: list[int]
ready: threading.Event ready: asyncio.Event
expire_time: float = float("inf") expire_time: float = float("inf")
@dataclass
class SendReqMeta:
reqs: dict[ReqId, SendBlockMeta]
lock: threading.Lock
@dataclass
class FinishedSendReqSet:
set: set[ReqId]
lock: threading.Lock
@dataclass
class FinishedReceiveReqSet:
set: set[ReqId]
lock: asyncio.Lock
class MooncakeConnectorMetadata(KVConnectorMetadata): class MooncakeConnectorMetadata(KVConnectorMetadata):
def __init__(self): def __init__(self):
self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {} self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
@@ -437,39 +418,50 @@ class MooncakeConnectorWorker:
assert vllm_config.kv_transfer_config assert vllm_config.kv_transfer_config
self.kv_role = vllm_config.kv_transfer_config.kv_role self.kv_role = vllm_config.kv_transfer_config.kv_role
self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get( self.num_sender_workers = (
"num_workers", 10 vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"num_workers", 10
)
) )
# Create more tasks than workers to keep the thread pool saturated.
# Tasks can await async events, so a surplus (2x is a robust heuristic)
# prevents workers from idling.
self.num_sender_tasks = self.num_sender_workers * 2
self.kv_caches_base_addr: list[int] = [] self.kv_caches_base_addr: list[int] = []
self.device_kv_caches: dict[str, torch.Tensor] = {} self.device_kv_caches: dict[str, torch.Tensor] = {}
self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock()) self.reqs_need_send: dict[ReqId, SendBlockMeta] = {}
# For kv_both, we will act both prefiller and decoder. # For kv_both, we will act both prefiller and decoder.
if self.kv_role != "kv_consumer": if self.kv_role != "kv_consumer":
# Background thread for sending kvcaches to D. # Background threads for sending kvcaches to D.
self._mooncake_sender_t: threading.Thread | None = None
# Background thread for processing new sending requests.
self._sender_executor = ThreadPoolExecutor( self._sender_executor = ThreadPoolExecutor(
max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender" max_workers=self.num_sender_workers,
thread_name_prefix="vllm-mooncake-sender",
) )
logger.debug( logger.debug(
"Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers "Mooncake Prefiller: use %d workers to send kvcaches",
self.num_sender_workers,
) )
# An asyncio queue to buffer incoming requests for the sender
self.sender_worker_queue = asyncio.Queue[tuple[bytes, bytes]]()
self.sender_loop = asyncio.new_event_loop()
# Background thread for processing new sending requests.
self._sender_listener_t = threading.Thread(
target=_async_loop, args=(self.sender_loop,), daemon=True
)
self._sender_listener_t.start()
if self.kv_role != "kv_producer": if self.kv_role != "kv_producer":
self.receiver_loop = asyncio.new_event_loop() self.receiver_loop = asyncio.new_event_loop()
self._mooncake_receiver_t = threading.Thread( self._mooncake_receiver_t = threading.Thread(
target=self._receiver_loop, args=(self.receiver_loop,), daemon=True target=_async_loop, args=(self.receiver_loop,), daemon=True
) )
self._mooncake_receiver_t.start() self._mooncake_receiver_t.start()
logger.debug("Mooncake Decoder: start receiver thread") logger.debug("Mooncake Decoder: start receiver thread")
self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet( self.finished_sending_reqs: set[ReqId] = set()
set(), threading.Lock() self.finished_recving_reqs: set[ReqId] = set()
)
self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet(
set(), asyncio.Lock()
)
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
@@ -500,7 +492,6 @@ class MooncakeConnectorWorker:
attn_backend=backend, attn_backend=backend,
) )
self.zmq_ctx = zmq.Context()
self.async_zmq_ctx = zmq.asyncio.Context() self.async_zmq_ctx = zmq.asyncio.Context()
self._encoder = msgspec.msgpack.Encoder() self._encoder = msgspec.msgpack.Encoder()
self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata) self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
@@ -510,21 +501,17 @@ class MooncakeConnectorWorker:
def shutdown(self): def shutdown(self):
"""Cleanup background threads on destruction.""" """Cleanup background threads on destruction."""
self.zmq_ctx.term()
self.async_zmq_ctx.term() self.async_zmq_ctx.term()
if self.kv_role != "kv_consumer": if self.kv_role != "kv_consumer":
self._sender_executor.shutdown(wait=False) self._sender_executor.shutdown(wait=False)
if self._mooncake_sender_t: if self.sender_loop.is_running():
self._mooncake_sender_t.join() self.sender_loop.call_soon_threadsafe(self.sender_loop.stop)
self._sender_listener_t.join()
if self.kv_role != "kv_producer" and self.receiver_loop.is_running(): if self.kv_role != "kv_producer" and self.receiver_loop.is_running():
self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop) self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
self._mooncake_receiver_t.join() self._mooncake_receiver_t.join()
def _receiver_loop(self, loop: asyncio.AbstractEventLoop): async def _mooncake_sender_listener(
asyncio.set_event_loop(loop)
loop.run_forever()
def _mooncake_sender(
self, ready_event: threading.Event, base_port: int, tp_rank: int self, ready_event: threading.Event, base_port: int, tp_rank: int
): ):
""" """
@@ -532,93 +519,86 @@ class MooncakeConnectorWorker:
to a thread pool, and sends acknowledgments upon completion. to a thread pool, and sends acknowledgments upon completion.
""" """
frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank) path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER) sock = make_zmq_socket(self.async_zmq_ctx, path, zmq.ROUTER)
logger.debug("Mooncake sender starting listening on path: %s", frontend_path) logger.debug("Mooncake sender starting listening on path: %s", path)
backend_path = make_zmq_path("inproc", str(uuid.uuid4())) # Create async worker tasks that process items from the queue
backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL) sender_tasks = [
asyncio.create_task(self._sender_worker(sock))
poller = zmq.Poller() for _ in range(self.num_sender_tasks)
poller.register(frontend, zmq.POLLIN) ]
poller.register(backend, zmq.POLLIN)
ready_event.set() ready_event.set()
try: try:
while True: while True:
sockets = dict(poller.poll()) identity, _, metadata_bytes = await sock.recv_multipart()
await self.sender_worker_queue.put((identity, metadata_bytes))
if frontend in sockets:
identity, _, metadata_bytes = frontend.recv_multipart()
self._sender_executor.submit(
self._sender_worker,
identity,
metadata_bytes,
backend_path,
)
if backend in sockets:
identity, status = backend.recv_multipart()
frontend.send_multipart((identity, b"", status))
except zmq.ContextTerminated: except zmq.ContextTerminated:
logger.debug("ZMQ context terminated, exiting Mooncake sender thread.") logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
except Exception as e: except Exception as e:
logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e)) logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
finally: finally:
frontend.close() # Clean up worker tasks
backend.close() for task in sender_tasks:
task.cancel()
await asyncio.gather(*sender_tasks, return_exceptions=True)
sock.close()
def _sender_worker( async def _sender_worker(self, sock: zmq.asyncio.Socket):
self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str while True:
):
status = TRANS_ERROR
try:
metadata = self._decoder.decode(metadata_bytes)
self.send_kv_to_decode(metadata)
status = TRANS_DONE
except Exception as e:
logger.error("Error processing Mooncake handshake: %s", e)
finally:
pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH)
try: try:
pusher.send_multipart((identity, status)) identity, metadata_bytes = await self.sender_worker_queue.get()
except zmq.ZMQError as e: try:
logger.warning( metadata = self._decoder.decode(metadata_bytes)
"Internal error, maybe the server is shutting down. Error: %s", await self.send_kv_to_decode(metadata)
e, await sock.send_multipart((identity, b"", TRANS_DONE))
) except Exception as e:
finally: logger.error("Error processing Mooncake xfer request: %s", e)
pusher.close() await sock.send_multipart((identity, b"", TRANS_ERROR))
finally:
self.sender_worker_queue.task_done()
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error in _sender_worker: %s", e)
def send_kv_to_decode(self, meta: MooncakeAgentMetadata): async def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
send_reqs: list[tuple[ReqId, SendBlockMeta]] = [] send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
with self.reqs_need_send.lock: for req_id in meta.request_ids:
for req_id in meta.request_ids: send_meta = self.reqs_need_send.get(req_id)
send_meta = self.reqs_need_send.reqs.get(req_id) if send_meta is None:
if send_meta is None: logger.warning("Request %s not found in reqs_need_send", req_id)
logger.warning("Request %s not found in reqs_need_send", req_id) return
return # Mark it as not expired. We will send it now.
# Mark it as not expired. We will send it now. send_meta.expire_time = float("inf")
send_meta.expire_time = float("inf") send_reqs.append((req_id, send_meta))
send_reqs.append((req_id, send_meta))
self._send_blocks(send_reqs, meta) src_ptrs, dst_ptrs, lengths = await self._build_transfer_params(send_reqs, meta)
remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
ret_value = await self.sender_loop.run_in_executor(
self._sender_executor,
self._send_blocks,
remote_session,
src_ptrs,
dst_ptrs,
lengths,
)
with self.reqs_need_send.lock: if ret_value != 0:
for req_id in meta.request_ids: raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
del self.reqs_need_send.reqs[req_id]
with self.finished_sending_reqs.lock: for req_id in meta.request_ids:
self.finished_sending_reqs.set.update(meta.request_ids) del self.reqs_need_send[req_id]
def _send_blocks( self.finished_sending_reqs.update(meta.request_ids)
async def _build_transfer_params(
self, self,
send_reqs: list[tuple[ReqId, SendBlockMeta]], send_reqs: list[tuple[ReqId, SendBlockMeta]],
agent_meta: MooncakeAgentMetadata, agent_meta: MooncakeAgentMetadata,
): ) -> tuple[list[int], list[int], list[int]]:
src_ptrs = [] src_ptrs = []
dst_ptrs = [] dst_ptrs = []
lengths = [] lengths = []
@@ -631,7 +611,7 @@ class MooncakeConnectorWorker:
for (req_id, send_meta), remote_block_ids in zip( for (req_id, send_meta), remote_block_ids in zip(
send_reqs, agent_meta.block_ids send_reqs, agent_meta.block_ids
): ):
send_meta.ready.wait() await send_meta.ready.wait()
num_remote_blocks = len(remote_block_ids) num_remote_blocks = len(remote_block_ids)
if num_remote_blocks == 0: if num_remote_blocks == 0:
@@ -670,18 +650,26 @@ class MooncakeConnectorWorker:
remote_session, remote_session,
) )
return src_ptrs, dst_ptrs, lengths
def _send_blocks(
self,
remote_session: str,
src_ptrs: list[int],
dst_ptrs: list[int],
lengths: list[int],
) -> int:
start_time = time.perf_counter() start_time = time.perf_counter()
ret_value = self.engine.batch_transfer_sync_write( ret_value = self.engine.batch_transfer_sync_write(
remote_session, src_ptrs, dst_ptrs, lengths remote_session, src_ptrs, dst_ptrs, lengths
) )
if ret_value != 0: if ret_value == 0:
raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}") logger.debug(
"Sending to %s done, took %s",
logger.debug( remote_session,
"Sending to %s done, took %s", time.perf_counter() - start_time,
remote_session, )
time.perf_counter() - start_time, return ret_value
)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in mooncake.""" """Register the KV Cache data in mooncake."""
@@ -740,41 +728,63 @@ class MooncakeConnectorWorker:
return return
ready_event = threading.Event() ready_event = threading.Event()
self._mooncake_sender_t = threading.Thread( asyncio.run_coroutine_threadsafe(
target=self._mooncake_sender, self._mooncake_sender_listener(
args=(ready_event, self.side_channel_port, self.tp_rank), ready_event, self.side_channel_port, self.tp_rank
daemon=True, ),
name="mooncake_sender", self.sender_loop,
) )
self._mooncake_sender_t.start()
ready_event.wait() # Wait for listener ZMQ socket to be ready. ready_event.wait() # Wait for listener ZMQ socket to be ready.
async def fetch_finished_recving_reqs(self) -> set[ReqId]: async def fetch_finished_recving_reqs(self) -> set[ReqId]:
async with self.finished_recving_reqs.lock: finished_recving_reqs = self.finished_recving_reqs
finished_recving_reqs = self.finished_recving_reqs.set self.finished_recving_reqs = set()
self.finished_recving_reqs.set = set()
return finished_recving_reqs return finished_recving_reqs
async def fetch_finished_sending_reqs(self) -> set[ReqId]:
finished_sending_reqs = self.finished_sending_reqs
self.finished_sending_reqs = set()
# Handle timeout to avoid stranding blocks on remote.
now = time.perf_counter()
expired_reqs = [
req_id
for req_id, send_meta in self.reqs_need_send.items()
if send_meta.expire_time < now
]
for req_id in expired_reqs:
logger.warning(
"Request %s timed out after %d seconds without "
"being sent. Freeing its blocks on the producer side.",
req_id,
envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
)
del self.reqs_need_send[req_id]
if expired_reqs:
finished_sending_reqs.update(expired_reqs)
return finished_sending_reqs
def get_finished(self) -> tuple[set[str] | None, set[str] | None]: def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
""" """
Get requests that are done sending or recving on this specific worker. Get requests that are done sending or recving on this specific worker.
The scheduler process (via the MultiprocExecutor) will use this output The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done. to track which workers are done.
""" """
fut = None recv_fut = None
send_fut = None
if self.kv_role != "kv_producer": if self.kv_role != "kv_producer":
fut = asyncio.run_coroutine_threadsafe( recv_fut = asyncio.run_coroutine_threadsafe(
self.fetch_finished_recving_reqs(), self.receiver_loop self.fetch_finished_recving_reqs(), self.receiver_loop
) )
if self.kv_role != "kv_consumer": if self.kv_role != "kv_consumer":
with self.finished_sending_reqs.lock: send_fut = asyncio.run_coroutine_threadsafe(
finished_sending_reqs = self.finished_sending_reqs.set self.fetch_finished_sending_reqs(), self.sender_loop
self.finished_sending_reqs.set = set() )
else:
finished_sending_reqs = set()
finished_recving_reqs = fut.result() if fut else set() finished_recving_reqs = recv_fut.result() if recv_fut else set()
finished_sending_reqs = send_fut.result() if send_fut else set()
if finished_sending_reqs or finished_recving_reqs: if finished_sending_reqs or finished_recving_reqs:
logger.debug( logger.debug(
@@ -785,25 +795,6 @@ class MooncakeConnectorWorker:
len(finished_recving_reqs), len(finished_recving_reqs),
) )
# Handle timeout to avoid stranding blocks on remote.
now = time.perf_counter()
with self.reqs_need_send.lock:
expired_reqs = [
req_id
for req_id, send_meta in self.reqs_need_send.reqs.items()
if send_meta.expire_time < now
]
for req_id in expired_reqs:
logger.warning(
"Request %s timed out after %d seconds without "
"being sent. Freeing its blocks on the producer side.",
req_id,
envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
)
del self.reqs_need_send.reqs[req_id]
if expired_reqs:
finished_sending_reqs.update(expired_reqs)
return finished_sending_reqs or None, finished_recving_reqs or None return finished_sending_reqs or None, finished_recving_reqs or None
async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]): async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
@@ -844,8 +835,7 @@ class MooncakeConnectorWorker:
finally: finally:
sock.close() sock.close()
async with self.finished_recving_reqs.lock: self.finished_recving_reqs.update(req_ids)
self.finished_recving_reqs.set.update(req_ids)
logger.debug("pulling kv_caches for %s finished", req_ids) logger.debug("pulling kv_caches for %s finished", req_ids)
@@ -865,6 +855,24 @@ class MooncakeConnectorWorker:
return kv_pulls return kv_pulls
async def record_send_reqs(self, metadata: MooncakeConnectorMetadata):
for req_id, block_ids in metadata.reqs_to_send.items():
if block_ids:
# Already gone through request_finished()
send_meta = self.reqs_need_send[req_id]
send_meta.local_block_ids = block_ids
send_meta.expire_time = (
time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
)
send_meta.ready.set()
else:
# From update_state_after_alloc(),
# but not reach request_finished() yet
self.reqs_need_send[req_id] = SendBlockMeta(
local_block_ids=[],
ready=asyncio.Event(),
)
def start_load_kv(self, metadata: MooncakeConnectorMetadata): def start_load_kv(self, metadata: MooncakeConnectorMetadata):
if self.kv_role != "kv_producer": if self.kv_role != "kv_producer":
kv_pulls = self.group_kv_pull(metadata) kv_pulls = self.group_kv_pull(metadata)
@@ -874,23 +882,9 @@ class MooncakeConnectorWorker:
) )
if self.kv_role != "kv_consumer": if self.kv_role != "kv_consumer":
with self.reqs_need_send.lock: asyncio.run_coroutine_threadsafe(
for req_id, block_ids in metadata.reqs_to_send.items(): self.record_send_reqs(metadata), self.sender_loop
if block_ids: )
# Already gone through request_finished()
send_meta = self.reqs_need_send.reqs[req_id]
send_meta.local_block_ids = block_ids
send_meta.ready.set()
send_meta.expire_time = (
time.perf_counter()
+ envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
)
else:
# From update_state_after_alloc(),
# but not reach request_finished() yet
self.reqs_need_send.reqs[req_id] = SendBlockMeta(
local_block_ids=[], ready=threading.Event()
)
def group_concurrent_contiguous( def group_concurrent_contiguous(
@@ -917,3 +911,8 @@ def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
+ vllm_config.parallel_config.data_parallel_index + vllm_config.parallel_config.data_parallel_index
* vllm_config.parallel_config.tensor_parallel_size * vllm_config.parallel_config.tensor_parallel_size
) )
def _async_loop(loop: asyncio.AbstractEventLoop):
asyncio.set_event_loop(loop)
loop.run_forever()