[P/D] Refactor mooncake connector sender thread using async coroutines (#31573)

Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
dtc
2026-01-12 20:35:35 +08:00
committed by GitHub
parent 9dbe1fe960
commit 0565f1fdec

View File

@@ -3,7 +3,6 @@
import asyncio
import threading
import time
import uuid
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
@@ -83,28 +82,10 @@ class RecvReqMeta:
@dataclass
class SendBlockMeta:
local_block_ids: list[int]
ready: threading.Event
ready: asyncio.Event
expire_time: float = float("inf")
@dataclass
class SendReqMeta:
reqs: dict[ReqId, SendBlockMeta]
lock: threading.Lock
@dataclass
class FinishedSendReqSet:
set: set[ReqId]
lock: threading.Lock
@dataclass
class FinishedReceiveReqSet:
set: set[ReqId]
lock: asyncio.Lock
class MooncakeConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
@@ -437,39 +418,50 @@ class MooncakeConnectorWorker:
assert vllm_config.kv_transfer_config
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
self.num_sender_workers = (
vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"num_workers", 10
)
)
# Create more tasks than workers to keep the thread pool saturated.
# Tasks can await async events, so a surplus (2x is a robust heuristic)
# prevents workers from idling.
self.num_sender_tasks = self.num_sender_workers * 2
self.kv_caches_base_addr: list[int] = []
self.device_kv_caches: dict[str, torch.Tensor] = {}
self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock())
self.reqs_need_send: dict[ReqId, SendBlockMeta] = {}
# For kv_both, we will act both prefiller and decoder.
if self.kv_role != "kv_consumer":
# Background thread for sending kvcaches to D.
self._mooncake_sender_t: threading.Thread | None = None
# Background thread for processing new sending requests.
# Background threads for sending kvcaches to D.
self._sender_executor = ThreadPoolExecutor(
max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender"
max_workers=self.num_sender_workers,
thread_name_prefix="vllm-mooncake-sender",
)
logger.debug(
"Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers
"Mooncake Prefiller: use %d workers to send kvcaches",
self.num_sender_workers,
)
# An asyncio queue to buffer incoming requests for the sender
self.sender_worker_queue = asyncio.Queue[tuple[bytes, bytes]]()
self.sender_loop = asyncio.new_event_loop()
# Background thread for processing new sending requests.
self._sender_listener_t = threading.Thread(
target=_async_loop, args=(self.sender_loop,), daemon=True
)
self._sender_listener_t.start()
if self.kv_role != "kv_producer":
self.receiver_loop = asyncio.new_event_loop()
self._mooncake_receiver_t = threading.Thread(
target=self._receiver_loop, args=(self.receiver_loop,), daemon=True
target=_async_loop, args=(self.receiver_loop,), daemon=True
)
self._mooncake_receiver_t.start()
logger.debug("Mooncake Decoder: start receiver thread")
self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet(
set(), threading.Lock()
)
self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet(
set(), asyncio.Lock()
)
self.finished_sending_reqs: set[ReqId] = set()
self.finished_recving_reqs: set[ReqId] = set()
self.block_size = vllm_config.cache_config.block_size
self.model_config = vllm_config.model_config
@@ -500,7 +492,6 @@ class MooncakeConnectorWorker:
attn_backend=backend,
)
self.zmq_ctx = zmq.Context()
self.async_zmq_ctx = zmq.asyncio.Context()
self._encoder = msgspec.msgpack.Encoder()
self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
@@ -510,21 +501,17 @@ class MooncakeConnectorWorker:
def shutdown(self):
"""Cleanup background threads on destruction."""
self.zmq_ctx.term()
self.async_zmq_ctx.term()
if self.kv_role != "kv_consumer":
self._sender_executor.shutdown(wait=False)
if self._mooncake_sender_t:
self._mooncake_sender_t.join()
if self.sender_loop.is_running():
self.sender_loop.call_soon_threadsafe(self.sender_loop.stop)
self._sender_listener_t.join()
if self.kv_role != "kv_producer" and self.receiver_loop.is_running():
self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
self._mooncake_receiver_t.join()
def _receiver_loop(self, loop: asyncio.AbstractEventLoop):
asyncio.set_event_loop(loop)
loop.run_forever()
def _mooncake_sender(
async def _mooncake_sender_listener(
self, ready_event: threading.Event, base_port: int, tp_rank: int
):
"""
@@ -532,72 +519,55 @@ class MooncakeConnectorWorker:
to a thread pool, and sends acknowledgments upon completion.
"""
frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER)
logger.debug("Mooncake sender starting listening on path: %s", frontend_path)
path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
sock = make_zmq_socket(self.async_zmq_ctx, path, zmq.ROUTER)
logger.debug("Mooncake sender starting listening on path: %s", path)
backend_path = make_zmq_path("inproc", str(uuid.uuid4()))
backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL)
poller = zmq.Poller()
poller.register(frontend, zmq.POLLIN)
poller.register(backend, zmq.POLLIN)
# Create async worker tasks that process items from the queue
sender_tasks = [
asyncio.create_task(self._sender_worker(sock))
for _ in range(self.num_sender_tasks)
]
ready_event.set()
try:
while True:
sockets = dict(poller.poll())
if frontend in sockets:
identity, _, metadata_bytes = frontend.recv_multipart()
self._sender_executor.submit(
self._sender_worker,
identity,
metadata_bytes,
backend_path,
)
if backend in sockets:
identity, status = backend.recv_multipart()
frontend.send_multipart((identity, b"", status))
identity, _, metadata_bytes = await sock.recv_multipart()
await self.sender_worker_queue.put((identity, metadata_bytes))
except zmq.ContextTerminated:
logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
except Exception as e:
logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
finally:
frontend.close()
backend.close()
def _sender_worker(
self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str
):
status = TRANS_ERROR
# Clean up worker tasks
for task in sender_tasks:
task.cancel()
await asyncio.gather(*sender_tasks, return_exceptions=True)
sock.close()
async def _sender_worker(self, sock: zmq.asyncio.Socket):
while True:
try:
identity, metadata_bytes = await self.sender_worker_queue.get()
try:
metadata = self._decoder.decode(metadata_bytes)
self.send_kv_to_decode(metadata)
status = TRANS_DONE
await self.send_kv_to_decode(metadata)
await sock.send_multipart((identity, b"", TRANS_DONE))
except Exception as e:
logger.error("Error processing Mooncake handshake: %s", e)
logger.error("Error processing Mooncake xfer request: %s", e)
await sock.send_multipart((identity, b"", TRANS_ERROR))
finally:
pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH)
try:
pusher.send_multipart((identity, status))
except zmq.ZMQError as e:
logger.warning(
"Internal error, maybe the server is shutting down. Error: %s",
e,
)
finally:
pusher.close()
self.sender_worker_queue.task_done()
except asyncio.CancelledError:
break
except Exception as e:
logger.error("Error in _sender_worker: %s", e)
def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
async def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
with self.reqs_need_send.lock:
for req_id in meta.request_ids:
send_meta = self.reqs_need_send.reqs.get(req_id)
send_meta = self.reqs_need_send.get(req_id)
if send_meta is None:
logger.warning("Request %s not found in reqs_need_send", req_id)
return
@@ -605,20 +575,30 @@ class MooncakeConnectorWorker:
send_meta.expire_time = float("inf")
send_reqs.append((req_id, send_meta))
self._send_blocks(send_reqs, meta)
src_ptrs, dst_ptrs, lengths = await self._build_transfer_params(send_reqs, meta)
remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
ret_value = await self.sender_loop.run_in_executor(
self._sender_executor,
self._send_blocks,
remote_session,
src_ptrs,
dst_ptrs,
lengths,
)
if ret_value != 0:
raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
with self.reqs_need_send.lock:
for req_id in meta.request_ids:
del self.reqs_need_send.reqs[req_id]
del self.reqs_need_send[req_id]
with self.finished_sending_reqs.lock:
self.finished_sending_reqs.set.update(meta.request_ids)
self.finished_sending_reqs.update(meta.request_ids)
def _send_blocks(
async def _build_transfer_params(
self,
send_reqs: list[tuple[ReqId, SendBlockMeta]],
agent_meta: MooncakeAgentMetadata,
):
) -> tuple[list[int], list[int], list[int]]:
src_ptrs = []
dst_ptrs = []
lengths = []
@@ -631,7 +611,7 @@ class MooncakeConnectorWorker:
for (req_id, send_meta), remote_block_ids in zip(
send_reqs, agent_meta.block_ids
):
send_meta.ready.wait()
await send_meta.ready.wait()
num_remote_blocks = len(remote_block_ids)
if num_remote_blocks == 0:
@@ -670,18 +650,26 @@ class MooncakeConnectorWorker:
remote_session,
)
return src_ptrs, dst_ptrs, lengths
def _send_blocks(
self,
remote_session: str,
src_ptrs: list[int],
dst_ptrs: list[int],
lengths: list[int],
) -> int:
start_time = time.perf_counter()
ret_value = self.engine.batch_transfer_sync_write(
remote_session, src_ptrs, dst_ptrs, lengths
)
if ret_value != 0:
raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
if ret_value == 0:
logger.debug(
"Sending to %s done, took %s",
remote_session,
time.perf_counter() - start_time,
)
return ret_value
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in mooncake."""
@@ -740,57 +728,28 @@ class MooncakeConnectorWorker:
return
ready_event = threading.Event()
self._mooncake_sender_t = threading.Thread(
target=self._mooncake_sender,
args=(ready_event, self.side_channel_port, self.tp_rank),
daemon=True,
name="mooncake_sender",
asyncio.run_coroutine_threadsafe(
self._mooncake_sender_listener(
ready_event, self.side_channel_port, self.tp_rank
),
self.sender_loop,
)
self._mooncake_sender_t.start()
ready_event.wait() # Wait for listener ZMQ socket to be ready.
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
async with self.finished_recving_reqs.lock:
finished_recving_reqs = self.finished_recving_reqs.set
self.finished_recving_reqs.set = set()
finished_recving_reqs = self.finished_recving_reqs
self.finished_recving_reqs = set()
return finished_recving_reqs
def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
"""
Get requests that are done sending or recving on this specific worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
fut = None
if self.kv_role != "kv_producer":
fut = asyncio.run_coroutine_threadsafe(
self.fetch_finished_recving_reqs(), self.receiver_loop
)
if self.kv_role != "kv_consumer":
with self.finished_sending_reqs.lock:
finished_sending_reqs = self.finished_sending_reqs.set
self.finished_sending_reqs.set = set()
else:
finished_sending_reqs = set()
finished_recving_reqs = fut.result() if fut else set()
if finished_sending_reqs or finished_recving_reqs:
logger.debug(
"Rank %s, get_finished: %s requests done sending "
"and %s requests done recving",
self.tp_rank,
len(finished_sending_reqs),
len(finished_recving_reqs),
)
async def fetch_finished_sending_reqs(self) -> set[ReqId]:
finished_sending_reqs = self.finished_sending_reqs
self.finished_sending_reqs = set()
# Handle timeout to avoid stranding blocks on remote.
now = time.perf_counter()
with self.reqs_need_send.lock:
expired_reqs = [
req_id
for req_id, send_meta in self.reqs_need_send.reqs.items()
for req_id, send_meta in self.reqs_need_send.items()
if send_meta.expire_time < now
]
for req_id in expired_reqs:
@@ -800,10 +759,42 @@ class MooncakeConnectorWorker:
req_id,
envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
)
del self.reqs_need_send.reqs[req_id]
del self.reqs_need_send[req_id]
if expired_reqs:
finished_sending_reqs.update(expired_reqs)
return finished_sending_reqs
def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
"""
Get requests that are done sending or recving on this specific worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
recv_fut = None
send_fut = None
if self.kv_role != "kv_producer":
recv_fut = asyncio.run_coroutine_threadsafe(
self.fetch_finished_recving_reqs(), self.receiver_loop
)
if self.kv_role != "kv_consumer":
send_fut = asyncio.run_coroutine_threadsafe(
self.fetch_finished_sending_reqs(), self.sender_loop
)
finished_recving_reqs = recv_fut.result() if recv_fut else set()
finished_sending_reqs = send_fut.result() if send_fut else set()
if finished_sending_reqs or finished_recving_reqs:
logger.debug(
"Rank %s, get_finished: %s requests done sending "
"and %s requests done recving",
self.tp_rank,
len(finished_sending_reqs),
len(finished_recving_reqs),
)
return finished_sending_reqs or None, finished_recving_reqs or None
async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
@@ -844,8 +835,7 @@ class MooncakeConnectorWorker:
finally:
sock.close()
async with self.finished_recving_reqs.lock:
self.finished_recving_reqs.set.update(req_ids)
self.finished_recving_reqs.update(req_ids)
logger.debug("pulling kv_caches for %s finished", req_ids)
@@ -865,6 +855,24 @@ class MooncakeConnectorWorker:
return kv_pulls
async def record_send_reqs(self, metadata: MooncakeConnectorMetadata):
for req_id, block_ids in metadata.reqs_to_send.items():
if block_ids:
# Already gone through request_finished()
send_meta = self.reqs_need_send[req_id]
send_meta.local_block_ids = block_ids
send_meta.expire_time = (
time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
)
send_meta.ready.set()
else:
# From update_state_after_alloc(),
# but not reach request_finished() yet
self.reqs_need_send[req_id] = SendBlockMeta(
local_block_ids=[],
ready=asyncio.Event(),
)
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
if self.kv_role != "kv_producer":
kv_pulls = self.group_kv_pull(metadata)
@@ -874,22 +882,8 @@ class MooncakeConnectorWorker:
)
if self.kv_role != "kv_consumer":
with self.reqs_need_send.lock:
for req_id, block_ids in metadata.reqs_to_send.items():
if block_ids:
# Already gone through request_finished()
send_meta = self.reqs_need_send.reqs[req_id]
send_meta.local_block_ids = block_ids
send_meta.ready.set()
send_meta.expire_time = (
time.perf_counter()
+ envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
)
else:
# From update_state_after_alloc(),
# but not reach request_finished() yet
self.reqs_need_send.reqs[req_id] = SendBlockMeta(
local_block_ids=[], ready=threading.Event()
asyncio.run_coroutine_threadsafe(
self.record_send_reqs(metadata), self.sender_loop
)
@@ -917,3 +911,8 @@ def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
+ vllm_config.parallel_config.data_parallel_index
* vllm_config.parallel_config.tensor_parallel_size
)
def _async_loop(loop: asyncio.AbstractEventLoop):
    """Thread target: drive *loop* forever in the calling thread.

    Binds the given event loop to the current thread, then blocks in
    ``run_forever()``. The loop exits only when another thread schedules
    ``loop.stop`` via ``loop.call_soon_threadsafe`` (done in shutdown()),
    after which the hosting daemon thread terminates.
    """
    asyncio.set_event_loop(loop)
    loop.run_forever()