[P/D] Refactor mooncake connector sender thread using async coroutines (#31573)
Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
@@ -83,28 +82,10 @@ class RecvReqMeta:
|
||||
@dataclass
|
||||
class SendBlockMeta:
|
||||
local_block_ids: list[int]
|
||||
ready: threading.Event
|
||||
ready: asyncio.Event
|
||||
expire_time: float = float("inf")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendReqMeta:
|
||||
reqs: dict[ReqId, SendBlockMeta]
|
||||
lock: threading.Lock
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinishedSendReqSet:
|
||||
set: set[ReqId]
|
||||
lock: threading.Lock
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinishedReceiveReqSet:
|
||||
set: set[ReqId]
|
||||
lock: asyncio.Lock
|
||||
|
||||
|
||||
class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||
def __init__(self):
|
||||
self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
|
||||
@@ -437,39 +418,50 @@ class MooncakeConnectorWorker:
|
||||
|
||||
assert vllm_config.kv_transfer_config
|
||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||
self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"num_workers", 10
|
||||
self.num_sender_workers = (
|
||||
vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"num_workers", 10
|
||||
)
|
||||
)
|
||||
# Create more tasks than workers to keep the thread pool saturated.
|
||||
# Tasks can await async events, so a surplus (2x is a robust heuristic)
|
||||
# prevents workers from idling.
|
||||
self.num_sender_tasks = self.num_sender_workers * 2
|
||||
|
||||
self.kv_caches_base_addr: list[int] = []
|
||||
self.device_kv_caches: dict[str, torch.Tensor] = {}
|
||||
self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock())
|
||||
self.reqs_need_send: dict[ReqId, SendBlockMeta] = {}
|
||||
|
||||
# For kv_both, we will act both prefiller and decoder.
|
||||
if self.kv_role != "kv_consumer":
|
||||
# Background thread for sending kvcaches to D.
|
||||
self._mooncake_sender_t: threading.Thread | None = None
|
||||
# Background thread for processing new sending requests.
|
||||
# Background threads for sending kvcaches to D.
|
||||
self._sender_executor = ThreadPoolExecutor(
|
||||
max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender"
|
||||
max_workers=self.num_sender_workers,
|
||||
thread_name_prefix="vllm-mooncake-sender",
|
||||
)
|
||||
logger.debug(
|
||||
"Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers
|
||||
"Mooncake Prefiller: use %d workers to send kvcaches",
|
||||
self.num_sender_workers,
|
||||
)
|
||||
# An asyncio queue to buffer incoming requests for the sender
|
||||
self.sender_worker_queue = asyncio.Queue[tuple[bytes, bytes]]()
|
||||
self.sender_loop = asyncio.new_event_loop()
|
||||
# Background thread for processing new sending requests.
|
||||
self._sender_listener_t = threading.Thread(
|
||||
target=_async_loop, args=(self.sender_loop,), daemon=True
|
||||
)
|
||||
self._sender_listener_t.start()
|
||||
|
||||
if self.kv_role != "kv_producer":
|
||||
self.receiver_loop = asyncio.new_event_loop()
|
||||
self._mooncake_receiver_t = threading.Thread(
|
||||
target=self._receiver_loop, args=(self.receiver_loop,), daemon=True
|
||||
target=_async_loop, args=(self.receiver_loop,), daemon=True
|
||||
)
|
||||
self._mooncake_receiver_t.start()
|
||||
logger.debug("Mooncake Decoder: start receiver thread")
|
||||
|
||||
self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet(
|
||||
set(), threading.Lock()
|
||||
)
|
||||
self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet(
|
||||
set(), asyncio.Lock()
|
||||
)
|
||||
self.finished_sending_reqs: set[ReqId] = set()
|
||||
self.finished_recving_reqs: set[ReqId] = set()
|
||||
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.model_config = vllm_config.model_config
|
||||
@@ -500,7 +492,6 @@ class MooncakeConnectorWorker:
|
||||
attn_backend=backend,
|
||||
)
|
||||
|
||||
self.zmq_ctx = zmq.Context()
|
||||
self.async_zmq_ctx = zmq.asyncio.Context()
|
||||
self._encoder = msgspec.msgpack.Encoder()
|
||||
self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
|
||||
@@ -510,21 +501,17 @@ class MooncakeConnectorWorker:
|
||||
|
||||
def shutdown(self):
|
||||
"""Cleanup background threads on destruction."""
|
||||
self.zmq_ctx.term()
|
||||
self.async_zmq_ctx.term()
|
||||
if self.kv_role != "kv_consumer":
|
||||
self._sender_executor.shutdown(wait=False)
|
||||
if self._mooncake_sender_t:
|
||||
self._mooncake_sender_t.join()
|
||||
if self.sender_loop.is_running():
|
||||
self.sender_loop.call_soon_threadsafe(self.sender_loop.stop)
|
||||
self._sender_listener_t.join()
|
||||
if self.kv_role != "kv_producer" and self.receiver_loop.is_running():
|
||||
self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
|
||||
self._mooncake_receiver_t.join()
|
||||
|
||||
def _receiver_loop(self, loop: asyncio.AbstractEventLoop):
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
def _mooncake_sender(
|
||||
async def _mooncake_sender_listener(
|
||||
self, ready_event: threading.Event, base_port: int, tp_rank: int
|
||||
):
|
||||
"""
|
||||
@@ -532,93 +519,86 @@ class MooncakeConnectorWorker:
|
||||
to a thread pool, and sends acknowledgments upon completion.
|
||||
"""
|
||||
|
||||
frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
|
||||
frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER)
|
||||
logger.debug("Mooncake sender starting listening on path: %s", frontend_path)
|
||||
path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
|
||||
sock = make_zmq_socket(self.async_zmq_ctx, path, zmq.ROUTER)
|
||||
logger.debug("Mooncake sender starting listening on path: %s", path)
|
||||
|
||||
backend_path = make_zmq_path("inproc", str(uuid.uuid4()))
|
||||
backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL)
|
||||
|
||||
poller = zmq.Poller()
|
||||
poller.register(frontend, zmq.POLLIN)
|
||||
poller.register(backend, zmq.POLLIN)
|
||||
# Create async worker tasks that process items from the queue
|
||||
sender_tasks = [
|
||||
asyncio.create_task(self._sender_worker(sock))
|
||||
for _ in range(self.num_sender_tasks)
|
||||
]
|
||||
|
||||
ready_event.set()
|
||||
|
||||
try:
|
||||
while True:
|
||||
sockets = dict(poller.poll())
|
||||
|
||||
if frontend in sockets:
|
||||
identity, _, metadata_bytes = frontend.recv_multipart()
|
||||
self._sender_executor.submit(
|
||||
self._sender_worker,
|
||||
identity,
|
||||
metadata_bytes,
|
||||
backend_path,
|
||||
)
|
||||
|
||||
if backend in sockets:
|
||||
identity, status = backend.recv_multipart()
|
||||
frontend.send_multipart((identity, b"", status))
|
||||
|
||||
identity, _, metadata_bytes = await sock.recv_multipart()
|
||||
await self.sender_worker_queue.put((identity, metadata_bytes))
|
||||
except zmq.ContextTerminated:
|
||||
logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
|
||||
except Exception as e:
|
||||
logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
|
||||
finally:
|
||||
frontend.close()
|
||||
backend.close()
|
||||
# Clean up worker tasks
|
||||
for task in sender_tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*sender_tasks, return_exceptions=True)
|
||||
sock.close()
|
||||
|
||||
def _sender_worker(
|
||||
self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str
|
||||
):
|
||||
status = TRANS_ERROR
|
||||
|
||||
try:
|
||||
metadata = self._decoder.decode(metadata_bytes)
|
||||
self.send_kv_to_decode(metadata)
|
||||
status = TRANS_DONE
|
||||
except Exception as e:
|
||||
logger.error("Error processing Mooncake handshake: %s", e)
|
||||
finally:
|
||||
pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH)
|
||||
async def _sender_worker(self, sock: zmq.asyncio.Socket):
|
||||
while True:
|
||||
try:
|
||||
pusher.send_multipart((identity, status))
|
||||
except zmq.ZMQError as e:
|
||||
logger.warning(
|
||||
"Internal error, maybe the server is shutting down. Error: %s",
|
||||
e,
|
||||
)
|
||||
finally:
|
||||
pusher.close()
|
||||
identity, metadata_bytes = await self.sender_worker_queue.get()
|
||||
try:
|
||||
metadata = self._decoder.decode(metadata_bytes)
|
||||
await self.send_kv_to_decode(metadata)
|
||||
await sock.send_multipart((identity, b"", TRANS_DONE))
|
||||
except Exception as e:
|
||||
logger.error("Error processing Mooncake xfer request: %s", e)
|
||||
await sock.send_multipart((identity, b"", TRANS_ERROR))
|
||||
finally:
|
||||
self.sender_worker_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Error in _sender_worker: %s", e)
|
||||
|
||||
def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
|
||||
async def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
|
||||
send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
|
||||
with self.reqs_need_send.lock:
|
||||
for req_id in meta.request_ids:
|
||||
send_meta = self.reqs_need_send.reqs.get(req_id)
|
||||
if send_meta is None:
|
||||
logger.warning("Request %s not found in reqs_need_send", req_id)
|
||||
return
|
||||
# Mark it as not expired. We will send it now.
|
||||
send_meta.expire_time = float("inf")
|
||||
send_reqs.append((req_id, send_meta))
|
||||
for req_id in meta.request_ids:
|
||||
send_meta = self.reqs_need_send.get(req_id)
|
||||
if send_meta is None:
|
||||
logger.warning("Request %s not found in reqs_need_send", req_id)
|
||||
return
|
||||
# Mark it as not expired. We will send it now.
|
||||
send_meta.expire_time = float("inf")
|
||||
send_reqs.append((req_id, send_meta))
|
||||
|
||||
self._send_blocks(send_reqs, meta)
|
||||
src_ptrs, dst_ptrs, lengths = await self._build_transfer_params(send_reqs, meta)
|
||||
remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
|
||||
ret_value = await self.sender_loop.run_in_executor(
|
||||
self._sender_executor,
|
||||
self._send_blocks,
|
||||
remote_session,
|
||||
src_ptrs,
|
||||
dst_ptrs,
|
||||
lengths,
|
||||
)
|
||||
|
||||
with self.reqs_need_send.lock:
|
||||
for req_id in meta.request_ids:
|
||||
del self.reqs_need_send.reqs[req_id]
|
||||
if ret_value != 0:
|
||||
raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
|
||||
|
||||
with self.finished_sending_reqs.lock:
|
||||
self.finished_sending_reqs.set.update(meta.request_ids)
|
||||
for req_id in meta.request_ids:
|
||||
del self.reqs_need_send[req_id]
|
||||
|
||||
def _send_blocks(
|
||||
self.finished_sending_reqs.update(meta.request_ids)
|
||||
|
||||
async def _build_transfer_params(
|
||||
self,
|
||||
send_reqs: list[tuple[ReqId, SendBlockMeta]],
|
||||
agent_meta: MooncakeAgentMetadata,
|
||||
):
|
||||
) -> tuple[list[int], list[int], list[int]]:
|
||||
src_ptrs = []
|
||||
dst_ptrs = []
|
||||
lengths = []
|
||||
@@ -631,7 +611,7 @@ class MooncakeConnectorWorker:
|
||||
for (req_id, send_meta), remote_block_ids in zip(
|
||||
send_reqs, agent_meta.block_ids
|
||||
):
|
||||
send_meta.ready.wait()
|
||||
await send_meta.ready.wait()
|
||||
|
||||
num_remote_blocks = len(remote_block_ids)
|
||||
if num_remote_blocks == 0:
|
||||
@@ -670,18 +650,26 @@ class MooncakeConnectorWorker:
|
||||
remote_session,
|
||||
)
|
||||
|
||||
return src_ptrs, dst_ptrs, lengths
|
||||
|
||||
def _send_blocks(
|
||||
self,
|
||||
remote_session: str,
|
||||
src_ptrs: list[int],
|
||||
dst_ptrs: list[int],
|
||||
lengths: list[int],
|
||||
) -> int:
|
||||
start_time = time.perf_counter()
|
||||
ret_value = self.engine.batch_transfer_sync_write(
|
||||
remote_session, src_ptrs, dst_ptrs, lengths
|
||||
)
|
||||
if ret_value != 0:
|
||||
raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
|
||||
|
||||
logger.debug(
|
||||
"Sending to %s done, took %s",
|
||||
remote_session,
|
||||
time.perf_counter() - start_time,
|
||||
)
|
||||
if ret_value == 0:
|
||||
logger.debug(
|
||||
"Sending to %s done, took %s",
|
||||
remote_session,
|
||||
time.perf_counter() - start_time,
|
||||
)
|
||||
return ret_value
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""Register the KV Cache data in mooncake."""
|
||||
@@ -740,41 +728,63 @@ class MooncakeConnectorWorker:
|
||||
return
|
||||
|
||||
ready_event = threading.Event()
|
||||
self._mooncake_sender_t = threading.Thread(
|
||||
target=self._mooncake_sender,
|
||||
args=(ready_event, self.side_channel_port, self.tp_rank),
|
||||
daemon=True,
|
||||
name="mooncake_sender",
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._mooncake_sender_listener(
|
||||
ready_event, self.side_channel_port, self.tp_rank
|
||||
),
|
||||
self.sender_loop,
|
||||
)
|
||||
self._mooncake_sender_t.start()
|
||||
ready_event.wait() # Wait for listener ZMQ socket to be ready.
|
||||
|
||||
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
|
||||
async with self.finished_recving_reqs.lock:
|
||||
finished_recving_reqs = self.finished_recving_reqs.set
|
||||
self.finished_recving_reqs.set = set()
|
||||
finished_recving_reqs = self.finished_recving_reqs
|
||||
self.finished_recving_reqs = set()
|
||||
return finished_recving_reqs
|
||||
|
||||
async def fetch_finished_sending_reqs(self) -> set[ReqId]:
|
||||
finished_sending_reqs = self.finished_sending_reqs
|
||||
self.finished_sending_reqs = set()
|
||||
|
||||
# Handle timeout to avoid stranding blocks on remote.
|
||||
now = time.perf_counter()
|
||||
expired_reqs = [
|
||||
req_id
|
||||
for req_id, send_meta in self.reqs_need_send.items()
|
||||
if send_meta.expire_time < now
|
||||
]
|
||||
for req_id in expired_reqs:
|
||||
logger.warning(
|
||||
"Request %s timed out after %d seconds without "
|
||||
"being sent. Freeing its blocks on the producer side.",
|
||||
req_id,
|
||||
envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
del self.reqs_need_send[req_id]
|
||||
if expired_reqs:
|
||||
finished_sending_reqs.update(expired_reqs)
|
||||
|
||||
return finished_sending_reqs
|
||||
|
||||
def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Get requests that are done sending or recving on this specific worker.
|
||||
The scheduler process (via the MultiprocExecutor) will use this output
|
||||
to track which workers are done.
|
||||
"""
|
||||
fut = None
|
||||
recv_fut = None
|
||||
send_fut = None
|
||||
if self.kv_role != "kv_producer":
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
recv_fut = asyncio.run_coroutine_threadsafe(
|
||||
self.fetch_finished_recving_reqs(), self.receiver_loop
|
||||
)
|
||||
|
||||
if self.kv_role != "kv_consumer":
|
||||
with self.finished_sending_reqs.lock:
|
||||
finished_sending_reqs = self.finished_sending_reqs.set
|
||||
self.finished_sending_reqs.set = set()
|
||||
else:
|
||||
finished_sending_reqs = set()
|
||||
send_fut = asyncio.run_coroutine_threadsafe(
|
||||
self.fetch_finished_sending_reqs(), self.sender_loop
|
||||
)
|
||||
|
||||
finished_recving_reqs = fut.result() if fut else set()
|
||||
finished_recving_reqs = recv_fut.result() if recv_fut else set()
|
||||
finished_sending_reqs = send_fut.result() if send_fut else set()
|
||||
|
||||
if finished_sending_reqs or finished_recving_reqs:
|
||||
logger.debug(
|
||||
@@ -785,25 +795,6 @@ class MooncakeConnectorWorker:
|
||||
len(finished_recving_reqs),
|
||||
)
|
||||
|
||||
# Handle timeout to avoid stranding blocks on remote.
|
||||
now = time.perf_counter()
|
||||
with self.reqs_need_send.lock:
|
||||
expired_reqs = [
|
||||
req_id
|
||||
for req_id, send_meta in self.reqs_need_send.reqs.items()
|
||||
if send_meta.expire_time < now
|
||||
]
|
||||
for req_id in expired_reqs:
|
||||
logger.warning(
|
||||
"Request %s timed out after %d seconds without "
|
||||
"being sent. Freeing its blocks on the producer side.",
|
||||
req_id,
|
||||
envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
|
||||
)
|
||||
del self.reqs_need_send.reqs[req_id]
|
||||
if expired_reqs:
|
||||
finished_sending_reqs.update(expired_reqs)
|
||||
|
||||
return finished_sending_reqs or None, finished_recving_reqs or None
|
||||
|
||||
async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
|
||||
@@ -844,8 +835,7 @@ class MooncakeConnectorWorker:
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
async with self.finished_recving_reqs.lock:
|
||||
self.finished_recving_reqs.set.update(req_ids)
|
||||
self.finished_recving_reqs.update(req_ids)
|
||||
|
||||
logger.debug("pulling kv_caches for %s finished", req_ids)
|
||||
|
||||
@@ -865,6 +855,24 @@ class MooncakeConnectorWorker:
|
||||
|
||||
return kv_pulls
|
||||
|
||||
async def record_send_reqs(self, metadata: MooncakeConnectorMetadata):
|
||||
for req_id, block_ids in metadata.reqs_to_send.items():
|
||||
if block_ids:
|
||||
# Already gone through request_finished()
|
||||
send_meta = self.reqs_need_send[req_id]
|
||||
send_meta.local_block_ids = block_ids
|
||||
send_meta.expire_time = (
|
||||
time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
|
||||
)
|
||||
send_meta.ready.set()
|
||||
else:
|
||||
# From update_state_after_alloc(),
|
||||
# but not reach request_finished() yet
|
||||
self.reqs_need_send[req_id] = SendBlockMeta(
|
||||
local_block_ids=[],
|
||||
ready=asyncio.Event(),
|
||||
)
|
||||
|
||||
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
|
||||
if self.kv_role != "kv_producer":
|
||||
kv_pulls = self.group_kv_pull(metadata)
|
||||
@@ -874,23 +882,9 @@ class MooncakeConnectorWorker:
|
||||
)
|
||||
|
||||
if self.kv_role != "kv_consumer":
|
||||
with self.reqs_need_send.lock:
|
||||
for req_id, block_ids in metadata.reqs_to_send.items():
|
||||
if block_ids:
|
||||
# Already gone through request_finished()
|
||||
send_meta = self.reqs_need_send.reqs[req_id]
|
||||
send_meta.local_block_ids = block_ids
|
||||
send_meta.ready.set()
|
||||
send_meta.expire_time = (
|
||||
time.perf_counter()
|
||||
+ envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
|
||||
)
|
||||
else:
|
||||
# From update_state_after_alloc(),
|
||||
# but not reach request_finished() yet
|
||||
self.reqs_need_send.reqs[req_id] = SendBlockMeta(
|
||||
local_block_ids=[], ready=threading.Event()
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.record_send_reqs(metadata), self.sender_loop
|
||||
)
|
||||
|
||||
|
||||
def group_concurrent_contiguous(
|
||||
@@ -917,3 +911,8 @@ def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
|
||||
+ vllm_config.parallel_config.data_parallel_index
|
||||
* vllm_config.parallel_config.tensor_parallel_size
|
||||
)
|
||||
|
||||
|
||||
def _async_loop(loop: asyncio.AbstractEventLoop):
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_forever()
|
||||
|
||||
Reference in New Issue
Block a user