Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -6,24 +6,29 @@ import time
|
||||
import msgspec
|
||||
import pytest
|
||||
|
||||
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
|
||||
NullEventPublisher)
|
||||
from vllm.distributed.kv_events import (
|
||||
EventBatch,
|
||||
EventPublisherFactory,
|
||||
NullEventPublisher,
|
||||
)
|
||||
|
||||
DP_RANK = 0
|
||||
|
||||
|
||||
class EventSample(
|
||||
msgspec.Struct,
|
||||
tag=True, # type: ignore
|
||||
array_like=True # type: ignore
|
||||
msgspec.Struct,
|
||||
tag=True, # type: ignore
|
||||
array_like=True, # type: ignore
|
||||
):
|
||||
"""Test event for publisher testing"""
|
||||
|
||||
id: int
|
||||
value: str
|
||||
|
||||
|
||||
class SampleBatch(EventBatch):
|
||||
"""Test event batch for publisher testing"""
|
||||
|
||||
events: list[EventSample]
|
||||
|
||||
|
||||
@@ -44,10 +49,8 @@ def test_basic_publishing(publisher, subscriber):
|
||||
|
||||
seq, received = result
|
||||
assert seq == 0, "Sequence number mismatch"
|
||||
assert received.ts == pytest.approx(test_batch.ts,
|
||||
abs=0.1), ("Timestamp mismatch")
|
||||
assert len(received.events) == len(
|
||||
test_batch.events), ("Number of events mismatch")
|
||||
assert received.ts == pytest.approx(test_batch.ts, abs=0.1), "Timestamp mismatch"
|
||||
assert len(received.events) == len(test_batch.events), "Number of events mismatch"
|
||||
|
||||
for i, event in enumerate(received.events):
|
||||
assert event.id == i, "Event id mismatch"
|
||||
@@ -88,9 +91,9 @@ def test_replay_mechanism(publisher, subscriber):
|
||||
assert len(replayed) > 0, "No replayed messages received"
|
||||
seqs = [seq for seq, _ in replayed]
|
||||
assert all(seq >= 10 for seq in seqs), "Replayed messages not in order"
|
||||
assert seqs == list(range(min(seqs),
|
||||
max(seqs) +
|
||||
1)), ("Replayed messages not consecutive")
|
||||
assert seqs == list(range(min(seqs), max(seqs) + 1)), (
|
||||
"Replayed messages not consecutive"
|
||||
)
|
||||
|
||||
|
||||
def test_buffer_limit(publisher, subscriber, publisher_config):
|
||||
@@ -126,6 +129,7 @@ def test_topic_filtering(publisher_config):
|
||||
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
|
||||
|
||||
from .conftest import MockSubscriber
|
||||
|
||||
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
|
||||
sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar")
|
||||
|
||||
@@ -137,11 +141,13 @@ def test_topic_filtering(publisher_config):
|
||||
|
||||
foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
|
||||
assert all(msg is not None for msg in foo_received), (
|
||||
"Subscriber with matching topic should receive messages")
|
||||
"Subscriber with matching topic should receive messages"
|
||||
)
|
||||
|
||||
bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
|
||||
assert all(msg is None for msg in bar_received), (
|
||||
"Subscriber with non-matching topic should receive no messages")
|
||||
"Subscriber with non-matching topic should receive no messages"
|
||||
)
|
||||
finally:
|
||||
pub.shutdown()
|
||||
sub_foo.close()
|
||||
@@ -178,8 +184,7 @@ def test_high_volume(publisher, subscriber):
|
||||
|
||||
publisher_thread.join()
|
||||
|
||||
assert len(received) >= num_batches * 0.9, (
|
||||
"We should have received most messages")
|
||||
assert len(received) >= num_batches * 0.9, "We should have received most messages"
|
||||
|
||||
seqs = [seq for seq, _ in received]
|
||||
assert sorted(seqs) == seqs, "Sequence numbers should be in order"
|
||||
@@ -209,13 +214,15 @@ def test_data_parallel_rank_tagging(publisher_config):
|
||||
# For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
|
||||
expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port
|
||||
expected_endpoint_1 = base_endpoint.replace(
|
||||
":5557", ":5558") # rank 1 gets port + 1
|
||||
":5557", ":5558"
|
||||
) # rank 1 gets port + 1
|
||||
else:
|
||||
# For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1
|
||||
expected_endpoint_0 = base_endpoint # rank 0 gets base
|
||||
expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1
|
||||
|
||||
from .conftest import MockSubscriber
|
||||
|
||||
sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
|
||||
sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)
|
||||
|
||||
@@ -241,15 +248,15 @@ def test_data_parallel_rank_tagging(publisher_config):
|
||||
|
||||
# Verify DP rank tagging
|
||||
assert received_0.data_parallel_rank == 0, (
|
||||
f"Expected DP rank 0, got {received_0.data_parallel_rank}")
|
||||
f"Expected DP rank 0, got {received_0.data_parallel_rank}"
|
||||
)
|
||||
assert received_1.data_parallel_rank == 1, (
|
||||
f"Expected DP rank 1, got {received_1.data_parallel_rank}")
|
||||
f"Expected DP rank 1, got {received_1.data_parallel_rank}"
|
||||
)
|
||||
|
||||
# Verify event content is correct
|
||||
assert len(
|
||||
received_0.events) == 2, "Wrong number of events from rank 0"
|
||||
assert len(
|
||||
received_1.events) == 3, "Wrong number of events from rank 1"
|
||||
assert len(received_0.events) == 2, "Wrong number of events from rank 0"
|
||||
assert len(received_1.events) == 3, "Wrong number of events from rank 1"
|
||||
|
||||
finally:
|
||||
pub_0.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user