[V1][P/D]Enhance Performance and code readability for P2pNcclConnector (#20906)
Signed-off-by: Abatom <abzhonghua@gmail.com>
This commit is contained in:
@@ -4,7 +4,9 @@
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import msgpack
|
||||
@@ -12,12 +14,25 @@ import zmq
|
||||
from quart import Quart, make_response, request
|
||||
|
||||
count = 0
|
||||
prefill_instances: dict[str, str] = {} # http_address: zmq_address
|
||||
decode_instances: dict[str, str] = {} # http_address: zmq_address
|
||||
prefill_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp)
|
||||
decode_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp)
|
||||
|
||||
prefill_cv = threading.Condition()
|
||||
decode_cv = threading.Condition()
|
||||
|
||||
DEFAULT_PING_SECONDS = 5
|
||||
|
||||
|
||||
def _remove_oldest_instances(instances: dict[str, Any]) -> None:
|
||||
oldest_key = next(iter(instances), None)
|
||||
while oldest_key is not None:
|
||||
value = instances[oldest_key]
|
||||
if value[1] > time.time():
|
||||
break
|
||||
print(f"🔴Remove [HTTP:{oldest_key}, ZMQ:{value[0]}, stamp:{value[1]}]")
|
||||
instances.pop(oldest_key, None)
|
||||
oldest_key = next(iter(instances), None)
|
||||
|
||||
|
||||
def _listen_for_register(poller, router_socket):
|
||||
while True:
|
||||
@@ -31,12 +46,23 @@ def _listen_for_register(poller, router_socket):
|
||||
global prefill_instances
|
||||
global prefill_cv
|
||||
with prefill_cv:
|
||||
prefill_instances[data["http_address"]] = data["zmq_address"]
|
||||
node = prefill_instances.pop(data["http_address"], None)
|
||||
prefill_instances[data["http_address"]] = (
|
||||
data["zmq_address"],
|
||||
time.time() + DEFAULT_PING_SECONDS,
|
||||
)
|
||||
_remove_oldest_instances(prefill_instances)
|
||||
|
||||
elif data["type"] == "D":
|
||||
global decode_instances
|
||||
global decode_cv
|
||||
with decode_cv:
|
||||
decode_instances[data["http_address"]] = data["zmq_address"]
|
||||
node = decode_instances.pop(data["http_address"], None)
|
||||
decode_instances[data["http_address"]] = (
|
||||
data["zmq_address"],
|
||||
time.time() + DEFAULT_PING_SECONDS,
|
||||
)
|
||||
_remove_oldest_instances(decode_instances)
|
||||
else:
|
||||
print(
|
||||
"Unexpected, Received message from %s, data: %s",
|
||||
@@ -44,6 +70,9 @@ def _listen_for_register(poller, router_socket):
|
||||
data,
|
||||
)
|
||||
|
||||
if node is None:
|
||||
print(f"🔵Add [HTTP:{data['http_address']}, ZMQ:{data['zmq_address']}]")
|
||||
|
||||
|
||||
def start_service_discovery(hostname, port):
|
||||
if not hostname:
|
||||
@@ -105,12 +134,14 @@ async def handle_request():
|
||||
with prefill_cv:
|
||||
prefill_list = list(prefill_instances.items())
|
||||
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
|
||||
prefill_zmq_addr = prefill_zmq_addr[0]
|
||||
|
||||
global decode_instances
|
||||
global decode_cv
|
||||
with decode_cv:
|
||||
decode_list = list(decode_instances.items())
|
||||
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]
|
||||
decode_zmq_addr = decode_zmq_addr[0]
|
||||
|
||||
print(
|
||||
f"handle_request count: {count}, [HTTP:{prefill_addr}, "
|
||||
|
||||
Reference in New Issue
Block a user