[P/D] Refactor mooncake connector sender thread using async coroutines (#31573)
Signed-off-by: Tianchen Ding <dtcccc@linux.alibaba.com> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
@@ -3,7 +3,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -83,28 +82,10 @@ class RecvReqMeta:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class SendBlockMeta:
|
class SendBlockMeta:
|
||||||
local_block_ids: list[int]
|
local_block_ids: list[int]
|
||||||
ready: threading.Event
|
ready: asyncio.Event
|
||||||
expire_time: float = float("inf")
|
expire_time: float = float("inf")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SendReqMeta:
|
|
||||||
reqs: dict[ReqId, SendBlockMeta]
|
|
||||||
lock: threading.Lock
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FinishedSendReqSet:
|
|
||||||
set: set[ReqId]
|
|
||||||
lock: threading.Lock
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FinishedReceiveReqSet:
|
|
||||||
set: set[ReqId]
|
|
||||||
lock: asyncio.Lock
|
|
||||||
|
|
||||||
|
|
||||||
class MooncakeConnectorMetadata(KVConnectorMetadata):
|
class MooncakeConnectorMetadata(KVConnectorMetadata):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
|
self.reqs_to_recv: dict[ReqId, RecvReqMeta] = {}
|
||||||
@@ -437,39 +418,50 @@ class MooncakeConnectorWorker:
|
|||||||
|
|
||||||
assert vllm_config.kv_transfer_config
|
assert vllm_config.kv_transfer_config
|
||||||
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
self.kv_role = vllm_config.kv_transfer_config.kv_role
|
||||||
self.num_workers = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
self.num_sender_workers = (
|
||||||
"num_workers", 10
|
vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||||
|
"num_workers", 10
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
# Create more tasks than workers to keep the thread pool saturated.
|
||||||
|
# Tasks can await async events, so a surplus (2x is a robust heuristic)
|
||||||
|
# prevents workers from idling.
|
||||||
|
self.num_sender_tasks = self.num_sender_workers * 2
|
||||||
|
|
||||||
self.kv_caches_base_addr: list[int] = []
|
self.kv_caches_base_addr: list[int] = []
|
||||||
self.device_kv_caches: dict[str, torch.Tensor] = {}
|
self.device_kv_caches: dict[str, torch.Tensor] = {}
|
||||||
self.reqs_need_send: SendReqMeta = SendReqMeta(reqs={}, lock=threading.Lock())
|
self.reqs_need_send: dict[ReqId, SendBlockMeta] = {}
|
||||||
|
|
||||||
# For kv_both, we will act both prefiller and decoder.
|
# For kv_both, we will act both prefiller and decoder.
|
||||||
if self.kv_role != "kv_consumer":
|
if self.kv_role != "kv_consumer":
|
||||||
# Background thread for sending kvcaches to D.
|
# Background threads for sending kvcaches to D.
|
||||||
self._mooncake_sender_t: threading.Thread | None = None
|
|
||||||
# Background thread for processing new sending requests.
|
|
||||||
self._sender_executor = ThreadPoolExecutor(
|
self._sender_executor = ThreadPoolExecutor(
|
||||||
max_workers=self.num_workers, thread_name_prefix="vllm-mooncake-sender"
|
max_workers=self.num_sender_workers,
|
||||||
|
thread_name_prefix="vllm-mooncake-sender",
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Mooncake Prefiller: use %d workers to send kvcaches", self.num_workers
|
"Mooncake Prefiller: use %d workers to send kvcaches",
|
||||||
|
self.num_sender_workers,
|
||||||
)
|
)
|
||||||
|
# An asyncio queue to buffer incoming requests for the sender
|
||||||
|
self.sender_worker_queue = asyncio.Queue[tuple[bytes, bytes]]()
|
||||||
|
self.sender_loop = asyncio.new_event_loop()
|
||||||
|
# Background thread for processing new sending requests.
|
||||||
|
self._sender_listener_t = threading.Thread(
|
||||||
|
target=_async_loop, args=(self.sender_loop,), daemon=True
|
||||||
|
)
|
||||||
|
self._sender_listener_t.start()
|
||||||
|
|
||||||
if self.kv_role != "kv_producer":
|
if self.kv_role != "kv_producer":
|
||||||
self.receiver_loop = asyncio.new_event_loop()
|
self.receiver_loop = asyncio.new_event_loop()
|
||||||
self._mooncake_receiver_t = threading.Thread(
|
self._mooncake_receiver_t = threading.Thread(
|
||||||
target=self._receiver_loop, args=(self.receiver_loop,), daemon=True
|
target=_async_loop, args=(self.receiver_loop,), daemon=True
|
||||||
)
|
)
|
||||||
self._mooncake_receiver_t.start()
|
self._mooncake_receiver_t.start()
|
||||||
logger.debug("Mooncake Decoder: start receiver thread")
|
logger.debug("Mooncake Decoder: start receiver thread")
|
||||||
|
|
||||||
self.finished_sending_reqs: FinishedSendReqSet = FinishedSendReqSet(
|
self.finished_sending_reqs: set[ReqId] = set()
|
||||||
set(), threading.Lock()
|
self.finished_recving_reqs: set[ReqId] = set()
|
||||||
)
|
|
||||||
self.finished_recving_reqs: FinishedReceiveReqSet = FinishedReceiveReqSet(
|
|
||||||
set(), asyncio.Lock()
|
|
||||||
)
|
|
||||||
|
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
self.model_config = vllm_config.model_config
|
self.model_config = vllm_config.model_config
|
||||||
@@ -500,7 +492,6 @@ class MooncakeConnectorWorker:
|
|||||||
attn_backend=backend,
|
attn_backend=backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.zmq_ctx = zmq.Context()
|
|
||||||
self.async_zmq_ctx = zmq.asyncio.Context()
|
self.async_zmq_ctx = zmq.asyncio.Context()
|
||||||
self._encoder = msgspec.msgpack.Encoder()
|
self._encoder = msgspec.msgpack.Encoder()
|
||||||
self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
|
self._decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
|
||||||
@@ -510,21 +501,17 @@ class MooncakeConnectorWorker:
|
|||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
"""Cleanup background threads on destruction."""
|
"""Cleanup background threads on destruction."""
|
||||||
self.zmq_ctx.term()
|
|
||||||
self.async_zmq_ctx.term()
|
self.async_zmq_ctx.term()
|
||||||
if self.kv_role != "kv_consumer":
|
if self.kv_role != "kv_consumer":
|
||||||
self._sender_executor.shutdown(wait=False)
|
self._sender_executor.shutdown(wait=False)
|
||||||
if self._mooncake_sender_t:
|
if self.sender_loop.is_running():
|
||||||
self._mooncake_sender_t.join()
|
self.sender_loop.call_soon_threadsafe(self.sender_loop.stop)
|
||||||
|
self._sender_listener_t.join()
|
||||||
if self.kv_role != "kv_producer" and self.receiver_loop.is_running():
|
if self.kv_role != "kv_producer" and self.receiver_loop.is_running():
|
||||||
self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
|
self.receiver_loop.call_soon_threadsafe(self.receiver_loop.stop)
|
||||||
self._mooncake_receiver_t.join()
|
self._mooncake_receiver_t.join()
|
||||||
|
|
||||||
def _receiver_loop(self, loop: asyncio.AbstractEventLoop):
|
async def _mooncake_sender_listener(
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
loop.run_forever()
|
|
||||||
|
|
||||||
def _mooncake_sender(
|
|
||||||
self, ready_event: threading.Event, base_port: int, tp_rank: int
|
self, ready_event: threading.Event, base_port: int, tp_rank: int
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -532,93 +519,86 @@ class MooncakeConnectorWorker:
|
|||||||
to a thread pool, and sends acknowledgments upon completion.
|
to a thread pool, and sends acknowledgments upon completion.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
frontend_path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
|
path = make_zmq_path("tcp", self.hostname, base_port + tp_rank)
|
||||||
frontend = make_zmq_socket(self.zmq_ctx, frontend_path, zmq.ROUTER)
|
sock = make_zmq_socket(self.async_zmq_ctx, path, zmq.ROUTER)
|
||||||
logger.debug("Mooncake sender starting listening on path: %s", frontend_path)
|
logger.debug("Mooncake sender starting listening on path: %s", path)
|
||||||
|
|
||||||
backend_path = make_zmq_path("inproc", str(uuid.uuid4()))
|
# Create async worker tasks that process items from the queue
|
||||||
backend = make_zmq_socket(self.zmq_ctx, backend_path, zmq.PULL)
|
sender_tasks = [
|
||||||
|
asyncio.create_task(self._sender_worker(sock))
|
||||||
poller = zmq.Poller()
|
for _ in range(self.num_sender_tasks)
|
||||||
poller.register(frontend, zmq.POLLIN)
|
]
|
||||||
poller.register(backend, zmq.POLLIN)
|
|
||||||
|
|
||||||
ready_event.set()
|
ready_event.set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
sockets = dict(poller.poll())
|
identity, _, metadata_bytes = await sock.recv_multipart()
|
||||||
|
await self.sender_worker_queue.put((identity, metadata_bytes))
|
||||||
if frontend in sockets:
|
|
||||||
identity, _, metadata_bytes = frontend.recv_multipart()
|
|
||||||
self._sender_executor.submit(
|
|
||||||
self._sender_worker,
|
|
||||||
identity,
|
|
||||||
metadata_bytes,
|
|
||||||
backend_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
if backend in sockets:
|
|
||||||
identity, status = backend.recv_multipart()
|
|
||||||
frontend.send_multipart((identity, b"", status))
|
|
||||||
|
|
||||||
except zmq.ContextTerminated:
|
except zmq.ContextTerminated:
|
||||||
logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
|
logger.debug("ZMQ context terminated, exiting Mooncake sender thread.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
|
logger.error("Error in Mooncake sender thread: %s. Exiting thread.", str(e))
|
||||||
finally:
|
finally:
|
||||||
frontend.close()
|
# Clean up worker tasks
|
||||||
backend.close()
|
for task in sender_tasks:
|
||||||
|
task.cancel()
|
||||||
|
await asyncio.gather(*sender_tasks, return_exceptions=True)
|
||||||
|
sock.close()
|
||||||
|
|
||||||
def _sender_worker(
|
async def _sender_worker(self, sock: zmq.asyncio.Socket):
|
||||||
self, identity: bytes, metadata_bytes: bytes, worker_channel_path: str
|
while True:
|
||||||
):
|
|
||||||
status = TRANS_ERROR
|
|
||||||
|
|
||||||
try:
|
|
||||||
metadata = self._decoder.decode(metadata_bytes)
|
|
||||||
self.send_kv_to_decode(metadata)
|
|
||||||
status = TRANS_DONE
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error processing Mooncake handshake: %s", e)
|
|
||||||
finally:
|
|
||||||
pusher = make_zmq_socket(self.zmq_ctx, worker_channel_path, zmq.PUSH)
|
|
||||||
try:
|
try:
|
||||||
pusher.send_multipart((identity, status))
|
identity, metadata_bytes = await self.sender_worker_queue.get()
|
||||||
except zmq.ZMQError as e:
|
try:
|
||||||
logger.warning(
|
metadata = self._decoder.decode(metadata_bytes)
|
||||||
"Internal error, maybe the server is shutting down. Error: %s",
|
await self.send_kv_to_decode(metadata)
|
||||||
e,
|
await sock.send_multipart((identity, b"", TRANS_DONE))
|
||||||
)
|
except Exception as e:
|
||||||
finally:
|
logger.error("Error processing Mooncake xfer request: %s", e)
|
||||||
pusher.close()
|
await sock.send_multipart((identity, b"", TRANS_ERROR))
|
||||||
|
finally:
|
||||||
|
self.sender_worker_queue.task_done()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error in _sender_worker: %s", e)
|
||||||
|
|
||||||
def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
|
async def send_kv_to_decode(self, meta: MooncakeAgentMetadata):
|
||||||
send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
|
send_reqs: list[tuple[ReqId, SendBlockMeta]] = []
|
||||||
with self.reqs_need_send.lock:
|
for req_id in meta.request_ids:
|
||||||
for req_id in meta.request_ids:
|
send_meta = self.reqs_need_send.get(req_id)
|
||||||
send_meta = self.reqs_need_send.reqs.get(req_id)
|
if send_meta is None:
|
||||||
if send_meta is None:
|
logger.warning("Request %s not found in reqs_need_send", req_id)
|
||||||
logger.warning("Request %s not found in reqs_need_send", req_id)
|
return
|
||||||
return
|
# Mark it as not expired. We will send it now.
|
||||||
# Mark it as not expired. We will send it now.
|
send_meta.expire_time = float("inf")
|
||||||
send_meta.expire_time = float("inf")
|
send_reqs.append((req_id, send_meta))
|
||||||
send_reqs.append((req_id, send_meta))
|
|
||||||
|
|
||||||
self._send_blocks(send_reqs, meta)
|
src_ptrs, dst_ptrs, lengths = await self._build_transfer_params(send_reqs, meta)
|
||||||
|
remote_session = f"{meta.remote_hostname}:{meta.remote_port}"
|
||||||
|
ret_value = await self.sender_loop.run_in_executor(
|
||||||
|
self._sender_executor,
|
||||||
|
self._send_blocks,
|
||||||
|
remote_session,
|
||||||
|
src_ptrs,
|
||||||
|
dst_ptrs,
|
||||||
|
lengths,
|
||||||
|
)
|
||||||
|
|
||||||
with self.reqs_need_send.lock:
|
if ret_value != 0:
|
||||||
for req_id in meta.request_ids:
|
raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
|
||||||
del self.reqs_need_send.reqs[req_id]
|
|
||||||
|
|
||||||
with self.finished_sending_reqs.lock:
|
for req_id in meta.request_ids:
|
||||||
self.finished_sending_reqs.set.update(meta.request_ids)
|
del self.reqs_need_send[req_id]
|
||||||
|
|
||||||
def _send_blocks(
|
self.finished_sending_reqs.update(meta.request_ids)
|
||||||
|
|
||||||
|
async def _build_transfer_params(
|
||||||
self,
|
self,
|
||||||
send_reqs: list[tuple[ReqId, SendBlockMeta]],
|
send_reqs: list[tuple[ReqId, SendBlockMeta]],
|
||||||
agent_meta: MooncakeAgentMetadata,
|
agent_meta: MooncakeAgentMetadata,
|
||||||
):
|
) -> tuple[list[int], list[int], list[int]]:
|
||||||
src_ptrs = []
|
src_ptrs = []
|
||||||
dst_ptrs = []
|
dst_ptrs = []
|
||||||
lengths = []
|
lengths = []
|
||||||
@@ -631,7 +611,7 @@ class MooncakeConnectorWorker:
|
|||||||
for (req_id, send_meta), remote_block_ids in zip(
|
for (req_id, send_meta), remote_block_ids in zip(
|
||||||
send_reqs, agent_meta.block_ids
|
send_reqs, agent_meta.block_ids
|
||||||
):
|
):
|
||||||
send_meta.ready.wait()
|
await send_meta.ready.wait()
|
||||||
|
|
||||||
num_remote_blocks = len(remote_block_ids)
|
num_remote_blocks = len(remote_block_ids)
|
||||||
if num_remote_blocks == 0:
|
if num_remote_blocks == 0:
|
||||||
@@ -670,18 +650,26 @@ class MooncakeConnectorWorker:
|
|||||||
remote_session,
|
remote_session,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return src_ptrs, dst_ptrs, lengths
|
||||||
|
|
||||||
|
def _send_blocks(
|
||||||
|
self,
|
||||||
|
remote_session: str,
|
||||||
|
src_ptrs: list[int],
|
||||||
|
dst_ptrs: list[int],
|
||||||
|
lengths: list[int],
|
||||||
|
) -> int:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
ret_value = self.engine.batch_transfer_sync_write(
|
ret_value = self.engine.batch_transfer_sync_write(
|
||||||
remote_session, src_ptrs, dst_ptrs, lengths
|
remote_session, src_ptrs, dst_ptrs, lengths
|
||||||
)
|
)
|
||||||
if ret_value != 0:
|
if ret_value == 0:
|
||||||
raise RuntimeError(f"Error in batch_transfer_sync_write: {ret_value}")
|
logger.debug(
|
||||||
|
"Sending to %s done, took %s",
|
||||||
logger.debug(
|
remote_session,
|
||||||
"Sending to %s done, took %s",
|
time.perf_counter() - start_time,
|
||||||
remote_session,
|
)
|
||||||
time.perf_counter() - start_time,
|
return ret_value
|
||||||
)
|
|
||||||
|
|
||||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||||
"""Register the KV Cache data in mooncake."""
|
"""Register the KV Cache data in mooncake."""
|
||||||
@@ -740,41 +728,63 @@ class MooncakeConnectorWorker:
|
|||||||
return
|
return
|
||||||
|
|
||||||
ready_event = threading.Event()
|
ready_event = threading.Event()
|
||||||
self._mooncake_sender_t = threading.Thread(
|
asyncio.run_coroutine_threadsafe(
|
||||||
target=self._mooncake_sender,
|
self._mooncake_sender_listener(
|
||||||
args=(ready_event, self.side_channel_port, self.tp_rank),
|
ready_event, self.side_channel_port, self.tp_rank
|
||||||
daemon=True,
|
),
|
||||||
name="mooncake_sender",
|
self.sender_loop,
|
||||||
)
|
)
|
||||||
self._mooncake_sender_t.start()
|
|
||||||
ready_event.wait() # Wait for listener ZMQ socket to be ready.
|
ready_event.wait() # Wait for listener ZMQ socket to be ready.
|
||||||
|
|
||||||
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
|
async def fetch_finished_recving_reqs(self) -> set[ReqId]:
|
||||||
async with self.finished_recving_reqs.lock:
|
finished_recving_reqs = self.finished_recving_reqs
|
||||||
finished_recving_reqs = self.finished_recving_reqs.set
|
self.finished_recving_reqs = set()
|
||||||
self.finished_recving_reqs.set = set()
|
|
||||||
return finished_recving_reqs
|
return finished_recving_reqs
|
||||||
|
|
||||||
|
async def fetch_finished_sending_reqs(self) -> set[ReqId]:
|
||||||
|
finished_sending_reqs = self.finished_sending_reqs
|
||||||
|
self.finished_sending_reqs = set()
|
||||||
|
|
||||||
|
# Handle timeout to avoid stranding blocks on remote.
|
||||||
|
now = time.perf_counter()
|
||||||
|
expired_reqs = [
|
||||||
|
req_id
|
||||||
|
for req_id, send_meta in self.reqs_need_send.items()
|
||||||
|
if send_meta.expire_time < now
|
||||||
|
]
|
||||||
|
for req_id in expired_reqs:
|
||||||
|
logger.warning(
|
||||||
|
"Request %s timed out after %d seconds without "
|
||||||
|
"being sent. Freeing its blocks on the producer side.",
|
||||||
|
req_id,
|
||||||
|
envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
|
||||||
|
)
|
||||||
|
del self.reqs_need_send[req_id]
|
||||||
|
if expired_reqs:
|
||||||
|
finished_sending_reqs.update(expired_reqs)
|
||||||
|
|
||||||
|
return finished_sending_reqs
|
||||||
|
|
||||||
def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
|
def get_finished(self) -> tuple[set[str] | None, set[str] | None]:
|
||||||
"""
|
"""
|
||||||
Get requests that are done sending or recving on this specific worker.
|
Get requests that are done sending or recving on this specific worker.
|
||||||
The scheduler process (via the MultiprocExecutor) will use this output
|
The scheduler process (via the MultiprocExecutor) will use this output
|
||||||
to track which workers are done.
|
to track which workers are done.
|
||||||
"""
|
"""
|
||||||
fut = None
|
recv_fut = None
|
||||||
|
send_fut = None
|
||||||
if self.kv_role != "kv_producer":
|
if self.kv_role != "kv_producer":
|
||||||
fut = asyncio.run_coroutine_threadsafe(
|
recv_fut = asyncio.run_coroutine_threadsafe(
|
||||||
self.fetch_finished_recving_reqs(), self.receiver_loop
|
self.fetch_finished_recving_reqs(), self.receiver_loop
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.kv_role != "kv_consumer":
|
if self.kv_role != "kv_consumer":
|
||||||
with self.finished_sending_reqs.lock:
|
send_fut = asyncio.run_coroutine_threadsafe(
|
||||||
finished_sending_reqs = self.finished_sending_reqs.set
|
self.fetch_finished_sending_reqs(), self.sender_loop
|
||||||
self.finished_sending_reqs.set = set()
|
)
|
||||||
else:
|
|
||||||
finished_sending_reqs = set()
|
|
||||||
|
|
||||||
finished_recving_reqs = fut.result() if fut else set()
|
finished_recving_reqs = recv_fut.result() if recv_fut else set()
|
||||||
|
finished_sending_reqs = send_fut.result() if send_fut else set()
|
||||||
|
|
||||||
if finished_sending_reqs or finished_recving_reqs:
|
if finished_sending_reqs or finished_recving_reqs:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -785,25 +795,6 @@ class MooncakeConnectorWorker:
|
|||||||
len(finished_recving_reqs),
|
len(finished_recving_reqs),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle timeout to avoid stranding blocks on remote.
|
|
||||||
now = time.perf_counter()
|
|
||||||
with self.reqs_need_send.lock:
|
|
||||||
expired_reqs = [
|
|
||||||
req_id
|
|
||||||
for req_id, send_meta in self.reqs_need_send.reqs.items()
|
|
||||||
if send_meta.expire_time < now
|
|
||||||
]
|
|
||||||
for req_id in expired_reqs:
|
|
||||||
logger.warning(
|
|
||||||
"Request %s timed out after %d seconds without "
|
|
||||||
"being sent. Freeing its blocks on the producer side.",
|
|
||||||
req_id,
|
|
||||||
envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT,
|
|
||||||
)
|
|
||||||
del self.reqs_need_send.reqs[req_id]
|
|
||||||
if expired_reqs:
|
|
||||||
finished_sending_reqs.update(expired_reqs)
|
|
||||||
|
|
||||||
return finished_sending_reqs or None, finished_recving_reqs or None
|
return finished_sending_reqs or None, finished_recving_reqs or None
|
||||||
|
|
||||||
async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
|
async def receive_kv(self, path: str, req_blocks: list[tuple[str, list[int]]]):
|
||||||
@@ -844,8 +835,7 @@ class MooncakeConnectorWorker:
|
|||||||
finally:
|
finally:
|
||||||
sock.close()
|
sock.close()
|
||||||
|
|
||||||
async with self.finished_recving_reqs.lock:
|
self.finished_recving_reqs.update(req_ids)
|
||||||
self.finished_recving_reqs.set.update(req_ids)
|
|
||||||
|
|
||||||
logger.debug("pulling kv_caches for %s finished", req_ids)
|
logger.debug("pulling kv_caches for %s finished", req_ids)
|
||||||
|
|
||||||
@@ -865,6 +855,24 @@ class MooncakeConnectorWorker:
|
|||||||
|
|
||||||
return kv_pulls
|
return kv_pulls
|
||||||
|
|
||||||
|
async def record_send_reqs(self, metadata: MooncakeConnectorMetadata):
|
||||||
|
for req_id, block_ids in metadata.reqs_to_send.items():
|
||||||
|
if block_ids:
|
||||||
|
# Already gone through request_finished()
|
||||||
|
send_meta = self.reqs_need_send[req_id]
|
||||||
|
send_meta.local_block_ids = block_ids
|
||||||
|
send_meta.expire_time = (
|
||||||
|
time.perf_counter() + envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
|
||||||
|
)
|
||||||
|
send_meta.ready.set()
|
||||||
|
else:
|
||||||
|
# From update_state_after_alloc(),
|
||||||
|
# but not reach request_finished() yet
|
||||||
|
self.reqs_need_send[req_id] = SendBlockMeta(
|
||||||
|
local_block_ids=[],
|
||||||
|
ready=asyncio.Event(),
|
||||||
|
)
|
||||||
|
|
||||||
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
|
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
|
||||||
if self.kv_role != "kv_producer":
|
if self.kv_role != "kv_producer":
|
||||||
kv_pulls = self.group_kv_pull(metadata)
|
kv_pulls = self.group_kv_pull(metadata)
|
||||||
@@ -874,23 +882,9 @@ class MooncakeConnectorWorker:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.kv_role != "kv_consumer":
|
if self.kv_role != "kv_consumer":
|
||||||
with self.reqs_need_send.lock:
|
asyncio.run_coroutine_threadsafe(
|
||||||
for req_id, block_ids in metadata.reqs_to_send.items():
|
self.record_send_reqs(metadata), self.sender_loop
|
||||||
if block_ids:
|
)
|
||||||
# Already gone through request_finished()
|
|
||||||
send_meta = self.reqs_need_send.reqs[req_id]
|
|
||||||
send_meta.local_block_ids = block_ids
|
|
||||||
send_meta.ready.set()
|
|
||||||
send_meta.expire_time = (
|
|
||||||
time.perf_counter()
|
|
||||||
+ envs.VLLM_MOONCAKE_ABORT_REQUEST_TIMEOUT
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# From update_state_after_alloc(),
|
|
||||||
# but not reach request_finished() yet
|
|
||||||
self.reqs_need_send.reqs[req_id] = SendBlockMeta(
|
|
||||||
local_block_ids=[], ready=threading.Event()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def group_concurrent_contiguous(
|
def group_concurrent_contiguous(
|
||||||
@@ -917,3 +911,8 @@ def get_mooncake_side_channel_port(vllm_config: VllmConfig) -> int:
|
|||||||
+ vllm_config.parallel_config.data_parallel_index
|
+ vllm_config.parallel_config.data_parallel_index
|
||||||
* vllm_config.parallel_config.tensor_parallel_size
|
* vllm_config.parallel_config.tensor_parallel_size
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _async_loop(loop: asyncio.AbstractEventLoop):
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
loop.run_forever()
|
||||||
|
|||||||
Reference in New Issue
Block a user