[V1][Metrics] add support for kv event publishing (#16750)

Signed-off-by: alec-flowers <aflowers@nvidia.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Alec
2025-04-30 16:44:45 +02:00
committed by GitHub
parent 77073c77bc
commit 0be6d05b5e
15 changed files with 1185 additions and 53 deletions

View File

@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Optional, Union
import msgspec
import msgspec.msgpack
import pytest
import zmq
from vllm.config import KVEventsConfig
from vllm.distributed.kv_events import EventPublisherFactory
from .test_events import SampleBatch
@pytest.fixture
def random_port():
"""Generate a random port number for testing"""
return random.randint(10000, 60000)
@pytest.fixture
def publisher_config(random_port, request):
"""Create a publisher config with inproc transport"""
how = request.param if hasattr(request, "param") else "inproc"
if how == "inproc":
endpoint = f"inproc://test-{random_port}"
replay_endpoint = endpoint + "-replay"
else:
endpoint = f"tcp://*:{random_port}"
replay_endpoint = f"tcp://*:{random_port + 1}"
return KVEventsConfig(enable_kv_cache_events=True,
publisher="zmq",
endpoint=endpoint,
replay_endpoint=replay_endpoint,
buffer_steps=100,
hwm=1000,
topic="test")
@pytest.fixture
def publisher(publisher_config):
"""Create and return a publisher instance"""
pub = EventPublisherFactory.create(publisher_config)
yield pub
pub.shutdown()
@pytest.fixture
def subscriber(publisher_config):
"""Create and return a subscriber for testing"""
endpoint = publisher_config.endpoint
replay_endpoint = publisher_config.replay_endpoint
if endpoint.startswith("tcp://*"):
endpoint = endpoint.replace("*", "127.0.0.1")
if replay_endpoint and replay_endpoint.startswith("tcp://*"):
replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")
sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
yield sub
sub.close()
class MockSubscriber:
"""Helper class to receive and verify published events"""
def __init__(self,
pub_endpoint: str,
replay_endpoint: Optional[str] = None,
topic: str = "",
decode_type=SampleBatch):
self.ctx = zmq.Context.instance()
# Set up subscriber socket
self.sub = self.ctx.socket(zmq.SUB)
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
self.sub.connect(pub_endpoint)
# Set up replay socket if provided
self.replay = None
if replay_endpoint:
self.replay = self.ctx.socket(zmq.REQ)
self.replay.connect(replay_endpoint)
self.topic = topic
self.topic_bytes = topic.encode('utf-8')
self.received_msgs: list[tuple[int, SampleBatch]] = []
self.last_seq = -1
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
def receive_one(self,
timeout=1000) -> Union[tuple[int, SampleBatch], None]:
"""Receive a single message with timeout"""
if not self.sub.poll(timeout):
return None
topic_bytes, seq_bytes, payload = self.sub.recv_multipart()
assert topic_bytes == self.topic_bytes
seq = int.from_bytes(seq_bytes, "big")
data = self.decoder.decode(payload)
self.last_seq = seq
self.received_msgs.append((seq, data))
return seq, data
def request_replay(self, start_seq: int) -> None:
"""Request replay of messages starting from start_seq"""
if not self.replay:
raise ValueError("Replay socket not initialized")
self.replay.send(start_seq.to_bytes(8, "big"))
def receive_replay(self) -> list[tuple[int, SampleBatch]]:
"""Receive replayed messages"""
if not self.replay:
raise ValueError("Replay socket not initialized")
replayed: list[tuple[int, SampleBatch]] = []
while True:
try:
if not self.replay.poll(1000):
break
frames = self.replay.recv_multipart()
if not frames or not frames[-1]:
# End of replay marker
break
seq_bytes, payload = frames
seq = int.from_bytes(seq_bytes, "big")
data = self.decoder.decode(payload)
replayed.append((seq, data))
except zmq.ZMQError as _:
break
return replayed
def close(self):
"""Clean up resources"""
self.sub.close()
if self.replay:
self.replay.close()

View File

@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
import threading
import time
import msgspec
import pytest
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
NullEventPublisher)
class EventSample(
msgspec.Struct,
tag=True, # type: ignore
array_like=True # type: ignore
):
"""Test event for publisher testing"""
id: int
value: str
class SampleBatch(EventBatch):
"""Test event batch for publisher testing"""
events: list[EventSample]
def create_test_events(count: int) -> SampleBatch:
"""Create a batch of test events"""
events = [EventSample(id=i, value=f"test-{i}") for i in range(count)]
return SampleBatch(ts=time.time(), events=events)
def test_basic_publishing(publisher, subscriber):
"""Test basic event publishing works"""
test_batch = create_test_events(5)
publisher.publish(test_batch)
result = subscriber.receive_one(timeout=1000)
assert result is not None, "No message received"
seq, received = result
assert seq == 0, "Sequence number mismatch"
assert received.ts == pytest.approx(test_batch.ts,
abs=0.1), ("Timestamp mismatch")
assert len(received.events) == len(
test_batch.events), ("Number of events mismatch")
for i, event in enumerate(received.events):
assert event.id == i, "Event id mismatch"
assert event.value == f"test-{i}", "Event value mismatch"
def test_multiple_events(publisher, subscriber):
"""Test publishing and receiving multiple event batches"""
for _ in range(10):
batch = create_test_events(2)
publisher.publish(batch)
received = []
for _ in range(10):
data = subscriber.receive_one(timeout=100)
if data:
received.append(data)
assert len(received) == 10, "Number of messages mismatch"
seqs = [seq for seq, _ in received]
assert seqs == list(range(10)), "Sequence numbers mismatch"
def test_replay_mechanism(publisher, subscriber):
"""Test the replay mechanism works correctly"""
for _ in range(19):
batch = create_test_events(1)
publisher.publish(batch)
time.sleep(0.5) # Need publisher to process above requests
subscriber.request_replay(10)
batch = create_test_events(1)
publisher.publish(batch) # 20th message
replayed = subscriber.receive_replay()
assert len(replayed) > 0, "No replayed messages received"
seqs = [seq for seq, _ in replayed]
assert all(seq >= 10 for seq in seqs), "Replayed messages not in order"
assert seqs == list(range(min(seqs),
max(seqs) +
1)), ("Replayed messages not consecutive")
def test_buffer_limit(publisher, subscriber, publisher_config):
"""Test buffer limit behavior"""
buffer_size = publisher_config.buffer_steps
# Publish more events than the buffer can hold
for i in range(buffer_size + 10):
batch = create_test_events(1)
publisher.publish(batch)
time.sleep(0.5) # Need publisher to process above requests
subscriber.request_replay(0)
batch = create_test_events(1)
publisher.publish(batch)
replayed = subscriber.receive_replay()
assert len(replayed) <= buffer_size, "Can't replay more than buffer size"
oldest_seq = min(seq for seq, _ in replayed)
assert oldest_seq >= 10, "The oldest sequence should be at least 10"
def test_topic_filtering(publisher_config):
"""
Test that a subscriber only receives messages matching its topic filter
"""
publisher_config.replay_endpoint = None
cfg = publisher_config.model_copy()
cfg.topic = "foo"
pub = EventPublisherFactory.create(cfg)
from .conftest import MockSubscriber
sub_foo = MockSubscriber(cfg.endpoint, None, "foo")
sub_bar = MockSubscriber(cfg.endpoint, None, "bar")
try:
time.sleep(0.1)
for _ in range(3):
pub.publish(create_test_events(1))
foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
assert all(msg is not None for msg in foo_received), (
"Subscriber with matching topic should receive messages")
bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
assert all(msg is None for msg in bar_received), (
"Subscriber with non-matching topic should receive no messages")
finally:
pub.shutdown()
sub_foo.close()
sub_bar.close()
def test_high_volume(publisher, subscriber):
"""Test publishing and receiving a high volume of events"""
num_batches = 10_000
events_per_batch = 100
# Publish events in a separate thread to not block
def publish_events():
for i in range(num_batches):
batch = create_test_events(events_per_batch)
publisher.publish(batch)
# Small delay to avoid overwhelming
if i % 100 == 0:
time.sleep(0.01)
received: list[tuple[int, SampleBatch]] = []
publisher_thread = threading.Thread(target=publish_events)
publisher_thread.start()
start_time = time.time()
while len(received) < num_batches:
if time.time() - start_time > 10: # Timeout after 10 seconds
break
result = subscriber.receive_one(timeout=100)
if result:
received.append(result)
publisher_thread.join()
assert len(received) >= num_batches * 0.9, (
"We should have received most messages")
seqs = [seq for seq, _ in received]
assert sorted(seqs) == seqs, "Sequence numbers should be in order"
def test_null_publisher():
"""Test that NullEventPublisher can be used without errors"""
publisher = NullEventPublisher()
# This should not raise any errors
batch = create_test_events(5)
publisher.publish(batch)
publisher.shutdown()