feat: add data parallel rank to KVEventBatch (#18925)
This commit is contained in:
@@ -13,11 +13,13 @@ from vllm.distributed.kv_events import EventPublisherFactory
|
||||
|
||||
from .test_events import SampleBatch
|
||||
|
||||
DP_RANK = 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random_port():
|
||||
"""Generate a random port number for testing"""
|
||||
return random.randint(10000, 60000)
|
||||
return random.randint(10000, 59900)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -30,21 +32,23 @@ def publisher_config(random_port, request):
|
||||
replay_endpoint = endpoint + "-replay"
|
||||
else:
|
||||
endpoint = f"tcp://*:{random_port}"
|
||||
replay_endpoint = f"tcp://*:{random_port + 1}"
|
||||
replay_endpoint = f"tcp://*:{random_port + 100}"
|
||||
|
||||
return KVEventsConfig(enable_kv_cache_events=True,
|
||||
publisher="zmq",
|
||||
endpoint=endpoint,
|
||||
replay_endpoint=replay_endpoint,
|
||||
buffer_steps=100,
|
||||
hwm=1000,
|
||||
topic="test")
|
||||
return KVEventsConfig(
|
||||
enable_kv_cache_events=True,
|
||||
publisher="zmq",
|
||||
endpoint=endpoint,
|
||||
replay_endpoint=replay_endpoint,
|
||||
buffer_steps=100,
|
||||
hwm=1000,
|
||||
topic="test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def publisher(publisher_config):
|
||||
"""Create and return a publisher instance"""
|
||||
pub = EventPublisherFactory.create(publisher_config)
|
||||
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
|
||||
yield pub
|
||||
pub.shutdown()
|
||||
|
||||
@@ -60,7 +64,11 @@ def subscriber(publisher_config):
|
||||
if replay_endpoint and replay_endpoint.startswith("tcp://*"):
|
||||
replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")
|
||||
|
||||
sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
|
||||
sub = MockSubscriber(
|
||||
[endpoint],
|
||||
[replay_endpoint] if replay_endpoint else None,
|
||||
publisher_config.topic,
|
||||
)
|
||||
yield sub
|
||||
sub.close()
|
||||
|
||||
@@ -68,26 +76,37 @@ def subscriber(publisher_config):
|
||||
class MockSubscriber:
|
||||
"""Helper class to receive and verify published events"""
|
||||
|
||||
def __init__(self,
|
||||
pub_endpoint: str,
|
||||
replay_endpoint: Optional[str] = None,
|
||||
topic: str = "",
|
||||
decode_type=SampleBatch):
|
||||
def __init__(
|
||||
self,
|
||||
pub_endpoints: Union[str, list[str]],
|
||||
replay_endpoints: Optional[Union[str, list[str]]] = None,
|
||||
topic: str = "",
|
||||
decode_type=SampleBatch,
|
||||
):
|
||||
self.ctx = zmq.Context.instance()
|
||||
|
||||
# Set up subscriber socket
|
||||
self.sub = self.ctx.socket(zmq.SUB)
|
||||
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
|
||||
self.sub.connect(pub_endpoint)
|
||||
# Convert single endpoint to list for consistency
|
||||
if isinstance(pub_endpoints, str):
|
||||
pub_endpoints = [pub_endpoints]
|
||||
if isinstance(replay_endpoints, str):
|
||||
replay_endpoints = [replay_endpoints]
|
||||
|
||||
# Set up replay socket if provided
|
||||
self.replay = None
|
||||
if replay_endpoint:
|
||||
self.replay = self.ctx.socket(zmq.REQ)
|
||||
self.replay.connect(replay_endpoint)
|
||||
# Set up subscriber socket - connect to all endpoints
|
||||
self.sub = self.ctx.socket(zmq.SUB)
|
||||
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode("utf-8"))
|
||||
for endpoint in pub_endpoints:
|
||||
self.sub.connect(endpoint)
|
||||
|
||||
# Set up replay sockets if provided
|
||||
self.replay_sockets = []
|
||||
if replay_endpoints:
|
||||
for replay_endpoint in replay_endpoints:
|
||||
replay = self.ctx.socket(zmq.REQ)
|
||||
replay.connect(replay_endpoint)
|
||||
self.replay_sockets.append(replay)
|
||||
|
||||
self.topic = topic
|
||||
self.topic_bytes = topic.encode('utf-8')
|
||||
self.topic_bytes = topic.encode("utf-8")
|
||||
self.received_msgs: list[tuple[int, SampleBatch]] = []
|
||||
self.last_seq = -1
|
||||
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
|
||||
@@ -107,25 +126,31 @@ class MockSubscriber:
|
||||
self.received_msgs.append((seq, data))
|
||||
return seq, data
|
||||
|
||||
def request_replay(self, start_seq: int) -> None:
|
||||
def request_replay(self, start_seq: int, socket_idx: int = 0) -> None:
|
||||
"""Request replay of messages starting from start_seq"""
|
||||
if not self.replay:
|
||||
raise ValueError("Replay socket not initialized")
|
||||
if not self.replay_sockets:
|
||||
raise ValueError("Replay sockets not initialized")
|
||||
if socket_idx >= len(self.replay_sockets):
|
||||
raise ValueError(f"Invalid socket index {socket_idx}")
|
||||
|
||||
self.replay.send(start_seq.to_bytes(8, "big"))
|
||||
self.replay_sockets[socket_idx].send(start_seq.to_bytes(8, "big"))
|
||||
|
||||
def receive_replay(self) -> list[tuple[int, SampleBatch]]:
|
||||
"""Receive replayed messages"""
|
||||
if not self.replay:
|
||||
raise ValueError("Replay socket not initialized")
|
||||
def receive_replay(self,
|
||||
socket_idx: int = 0) -> list[tuple[int, SampleBatch]]:
|
||||
"""Receive replayed messages from a specific replay socket"""
|
||||
if not self.replay_sockets:
|
||||
raise ValueError("Replay sockets not initialized")
|
||||
if socket_idx >= len(self.replay_sockets):
|
||||
raise ValueError(f"Invalid socket index {socket_idx}")
|
||||
|
||||
replay_socket = self.replay_sockets[socket_idx]
|
||||
replayed: list[tuple[int, SampleBatch]] = []
|
||||
while True:
|
||||
try:
|
||||
if not self.replay.poll(1000):
|
||||
if not replay_socket.poll(1000):
|
||||
break
|
||||
|
||||
frames = self.replay.recv_multipart()
|
||||
frames = replay_socket.recv_multipart()
|
||||
if not frames or not frames[-1]:
|
||||
# End of replay marker
|
||||
break
|
||||
@@ -142,5 +167,5 @@ class MockSubscriber:
|
||||
def close(self):
|
||||
"""Clean up resources"""
|
||||
self.sub.close()
|
||||
if self.replay:
|
||||
self.replay.close()
|
||||
for replay in self.replay_sockets:
|
||||
replay.close()
|
||||
|
||||
Reference in New Issue
Block a user