Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -6,24 +6,29 @@ import time
import msgspec
import pytest
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
NullEventPublisher)
from vllm.distributed.kv_events import (
EventBatch,
EventPublisherFactory,
NullEventPublisher,
)
DP_RANK = 0
class EventSample(
msgspec.Struct,
tag=True, # type: ignore
array_like=True # type: ignore
msgspec.Struct,
tag=True, # type: ignore
array_like=True, # type: ignore
):
"""Test event for publisher testing"""
id: int
value: str
class SampleBatch(EventBatch):
"""Test event batch for publisher testing"""
events: list[EventSample]
@@ -44,10 +49,8 @@ def test_basic_publishing(publisher, subscriber):
seq, received = result
assert seq == 0, "Sequence number mismatch"
assert received.ts == pytest.approx(test_batch.ts,
abs=0.1), ("Timestamp mismatch")
assert len(received.events) == len(
test_batch.events), ("Number of events mismatch")
assert received.ts == pytest.approx(test_batch.ts, abs=0.1), "Timestamp mismatch"
assert len(received.events) == len(test_batch.events), "Number of events mismatch"
for i, event in enumerate(received.events):
assert event.id == i, "Event id mismatch"
@@ -88,9 +91,9 @@ def test_replay_mechanism(publisher, subscriber):
assert len(replayed) > 0, "No replayed messages received"
seqs = [seq for seq, _ in replayed]
assert all(seq >= 10 for seq in seqs), "Replayed messages not in order"
assert seqs == list(range(min(seqs),
max(seqs) +
1)), ("Replayed messages not consecutive")
assert seqs == list(range(min(seqs), max(seqs) + 1)), (
"Replayed messages not consecutive"
)
def test_buffer_limit(publisher, subscriber, publisher_config):
@@ -126,6 +129,7 @@ def test_topic_filtering(publisher_config):
pub = EventPublisherFactory.create(publisher_config, DP_RANK)
from .conftest import MockSubscriber
sub_foo = MockSubscriber(publisher_config.endpoint, None, "foo")
sub_bar = MockSubscriber(publisher_config.endpoint, None, "bar")
@@ -137,11 +141,13 @@ def test_topic_filtering(publisher_config):
foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
assert all(msg is not None for msg in foo_received), (
"Subscriber with matching topic should receive messages")
"Subscriber with matching topic should receive messages"
)
bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
assert all(msg is None for msg in bar_received), (
"Subscriber with non-matching topic should receive no messages")
"Subscriber with non-matching topic should receive no messages"
)
finally:
pub.shutdown()
sub_foo.close()
@@ -178,8 +184,7 @@ def test_high_volume(publisher, subscriber):
publisher_thread.join()
assert len(received) >= num_batches * 0.9, (
"We should have received most messages")
assert len(received) >= num_batches * 0.9, "We should have received most messages"
seqs = [seq for seq, _ in received]
assert sorted(seqs) == seqs, "Sequence numbers should be in order"
@@ -209,13 +214,15 @@ def test_data_parallel_rank_tagging(publisher_config):
# For TCP endpoints: tcp://localhost:5557 -> tcp://localhost:5557, tcp://localhost:5558
expected_endpoint_0 = base_endpoint # rank 0 gets port + 0 = same port
expected_endpoint_1 = base_endpoint.replace(
":5557", ":5558") # rank 1 gets port + 1
":5557", ":5558"
) # rank 1 gets port + 1
else:
# For inproc endpoints: inproc://test -> inproc://test_dp0, inproc://test_dp1
expected_endpoint_0 = base_endpoint # rank 0 gets base
expected_endpoint_1 = base_endpoint + "_dp1" # rank 1 gets _dp1
from .conftest import MockSubscriber
sub_0 = MockSubscriber(expected_endpoint_0, None, publisher_config.topic)
sub_1 = MockSubscriber(expected_endpoint_1, None, publisher_config.topic)
@@ -241,15 +248,15 @@ def test_data_parallel_rank_tagging(publisher_config):
# Verify DP rank tagging
assert received_0.data_parallel_rank == 0, (
f"Expected DP rank 0, got {received_0.data_parallel_rank}")
f"Expected DP rank 0, got {received_0.data_parallel_rank}"
)
assert received_1.data_parallel_rank == 1, (
f"Expected DP rank 1, got {received_1.data_parallel_rank}")
f"Expected DP rank 1, got {received_1.data_parallel_rank}"
)
# Verify event content is correct
assert len(
received_0.events) == 2, "Wrong number of events from rank 0"
assert len(
received_1.events) == 3, "Wrong number of events from rank 1"
assert len(received_0.events) == 2, "Wrong number of events from rank 0"
assert len(received_1.events) == 3, "Wrong number of events from rank 1"
finally:
pub_0.shutdown()