[P/D] Dynamic kv_output_aggregator collect size (#26734)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-10-22 18:07:58 +02:00
committed by GitHub
parent 58fab50d82
commit 4dfdb821c8
7 changed files with 90 additions and 19 deletions

View File

@@ -703,7 +703,7 @@ def test_kv_connector_stats_aggregation():
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
# done in MultiprocExecutor.execute_model
aggregator = KVOutputAggregator(world_size=3)
aggregator = KVOutputAggregator(expected_finished_count=3)
# Create stats for multiple workers with different transfer patterns
worker1_stats = NixlKVConnectorStats()
@@ -768,7 +768,7 @@ def test_multi_kv_connector_stats_aggregation():
KVOutputAggregator (used by MultiprocExecutor).
"""
aggregator = KVOutputAggregator(world_size=3)
aggregator = KVOutputAggregator(expected_finished_count=3)
from dataclasses import dataclass

View File

@@ -16,11 +16,13 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
finished_sending: set[str] | None = None,
finished_recving: set[str] | None = None,
invalid_block_ids: set[int] | None = None,
expected_finished_count: int = 0,
):
self.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
invalid_block_ids=invalid_block_ids or set(),
expected_finished_count=expected_finished_count,
)
def __repr__(self):
@@ -33,7 +35,7 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
def test_aggregate_workers_output():
aggregator = KVOutputAggregator(world_size=2)
aggregator = KVOutputAggregator(expected_finished_count=2)
output1 = DummyModelRunnerOutput()
output2 = DummyModelRunnerOutput()
@@ -85,7 +87,7 @@ def test_aggregate_workers_output():
def test_async_aggregate_workers_output():
aggregator = KVOutputAggregator(world_size=2)
aggregator = KVOutputAggregator(expected_finished_count=2)
future1: Future[DummyModelRunnerOutput] = Future()
future2: Future[DummyModelRunnerOutput] = Future()
@@ -158,3 +160,40 @@ def test_async_aggregate_workers_output():
assert aggregated.finished_sending is None
assert aggregated.finished_recving == {"req2"}
assert aggregated.invalid_block_ids == {3, 4, 5}
def test_aggregate_workers_output_with_expected_finished_count():
# We create the aggregator expecting to collect from 4 workers
aggregator = KVOutputAggregator(expected_finished_count=4)
assert aggregator._expected_finished_count == 4
# Some request with default expected finished requests
output1 = DummyModelRunnerOutput(finished_sending={"req1"})
aggregated = aggregator.aggregate([output1])
# still expecting to collect from 4 workers
assert aggregator._send_remaining_count["req1"] == 3
assert not aggregated.kv_connector_output.finished_sending
assert not aggregated.kv_connector_output.finished_recving
# Workers discover and find that in this setup they only need to
# collect from 2
output1 = DummyModelRunnerOutput(
finished_sending={"req1"}, expected_finished_count=2
)
output2 = DummyModelRunnerOutput(
finished_recving={"req2"}, expected_finished_count=2
)
output3 = DummyModelRunnerOutput(finished_recving={"req2"})
# Req2 only needs 2 acks
aggregated = aggregator.aggregate([output1, output2, output3])
assert aggregated.kv_connector_output.expected_finished_count == 2
assert not aggregated.kv_connector_output.finished_sending
# Req2 is finished
assert "req2" not in aggregator._recv_remaining_count
assert aggregated.kv_connector_output.finished_recving == {"req2"}
# Req1 is still waiting for 2 more acks (expected_finished_count has no effect)
# NOTE: This is to showcase dynamic update. Workers are responsible for
# ensuring "req1" termination in this case
assert aggregator._send_remaining_count["req1"] == 2