[V1][P/D]Enhance Performance and code readability for P2pNcclConnector (#20906)

Signed-off-by: Abatom <abzhonghua@gmail.com>
This commit is contained in:
Zhonghua Deng
2025-07-17 13:13:00 +08:00
committed by GitHub
parent 76b494444f
commit 8a4e5c5f3c
4 changed files with 262 additions and 252 deletions

View File

@@ -4,7 +4,9 @@
import os
import socket
import threading
import time
import uuid
from typing import Any
import aiohttp
import msgpack
@@ -12,12 +14,25 @@ import zmq
from quart import Quart, make_response, request
count = 0
prefill_instances: dict[str, str] = {} # http_address: zmq_address
decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp)
decode_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp)
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
DEFAULT_PING_SECONDS = 5
def _remove_oldest_instances(instances: dict[str, Any]) -> None:
oldest_key = next(iter(instances), None)
while oldest_key is not None:
value = instances[oldest_key]
if value[1] > time.time():
break
print(f"🔴Remove [HTTP:{oldest_key}, ZMQ:{value[0]}, stamp:{value[1]}]")
instances.pop(oldest_key, None)
oldest_key = next(iter(instances), None)
def _listen_for_register(poller, router_socket):
while True:
@@ -31,12 +46,23 @@ def _listen_for_register(poller, router_socket):
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_instances[data["http_address"]] = data["zmq_address"]
node = prefill_instances.pop(data["http_address"], None)
prefill_instances[data["http_address"]] = (
data["zmq_address"],
time.time() + DEFAULT_PING_SECONDS,
)
_remove_oldest_instances(prefill_instances)
elif data["type"] == "D":
global decode_instances
global decode_cv
with decode_cv:
decode_instances[data["http_address"]] = data["zmq_address"]
node = decode_instances.pop(data["http_address"], None)
decode_instances[data["http_address"]] = (
data["zmq_address"],
time.time() + DEFAULT_PING_SECONDS,
)
_remove_oldest_instances(decode_instances)
else:
print(
"Unexpected, Received message from %s, data: %s",
@@ -44,6 +70,9 @@ def _listen_for_register(poller, router_socket):
data,
)
if node is None:
print(f"🔵Add [HTTP:{data['http_address']}, ZMQ:{data['zmq_address']}]")
def start_service_discovery(hostname, port):
if not hostname:
@@ -105,12 +134,14 @@ async def handle_request():
with prefill_cv:
prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
prefill_zmq_addr = prefill_zmq_addr[0]
global decode_instances
global decode_cv
with decode_cv:
decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]
decode_zmq_addr = decode_zmq_addr[0]
print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, "