[P/D] Dynamic kv_output_aggregator collect size (#26734)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -703,7 +703,7 @@ def test_kv_connector_stats_aggregation():
|
||||
|
||||
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
|
||||
# done in MultiprocExecutor.execute_model
|
||||
aggregator = KVOutputAggregator(world_size=3)
|
||||
aggregator = KVOutputAggregator(expected_finished_count=3)
|
||||
|
||||
# Create stats for multiple workers with different transfer patterns
|
||||
worker1_stats = NixlKVConnectorStats()
|
||||
@@ -768,7 +768,7 @@ def test_multi_kv_connector_stats_aggregation():
|
||||
KVOutputAggregator (used by MultiprocExecutor).
|
||||
"""
|
||||
|
||||
aggregator = KVOutputAggregator(world_size=3)
|
||||
aggregator = KVOutputAggregator(expected_finished_count=3)
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -16,11 +16,13 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
|
||||
finished_sending: set[str] | None = None,
|
||||
finished_recving: set[str] | None = None,
|
||||
invalid_block_ids: set[int] | None = None,
|
||||
expected_finished_count: int = 0,
|
||||
):
|
||||
self.kv_connector_output = KVConnectorOutput(
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
invalid_block_ids=invalid_block_ids or set(),
|
||||
expected_finished_count=expected_finished_count,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -33,7 +35,7 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
|
||||
|
||||
|
||||
def test_aggregate_workers_output():
|
||||
aggregator = KVOutputAggregator(world_size=2)
|
||||
aggregator = KVOutputAggregator(expected_finished_count=2)
|
||||
|
||||
output1 = DummyModelRunnerOutput()
|
||||
output2 = DummyModelRunnerOutput()
|
||||
@@ -85,7 +87,7 @@ def test_aggregate_workers_output():
|
||||
|
||||
|
||||
def test_async_aggregate_workers_output():
|
||||
aggregator = KVOutputAggregator(world_size=2)
|
||||
aggregator = KVOutputAggregator(expected_finished_count=2)
|
||||
|
||||
future1: Future[DummyModelRunnerOutput] = Future()
|
||||
future2: Future[DummyModelRunnerOutput] = Future()
|
||||
@@ -158,3 +160,40 @@ def test_async_aggregate_workers_output():
|
||||
assert aggregated.finished_sending is None
|
||||
assert aggregated.finished_recving == {"req2"}
|
||||
assert aggregated.invalid_block_ids == {3, 4, 5}
|
||||
|
||||
|
||||
def test_aggregate_workers_output_with_expected_finished_count():
|
||||
# We create the aggregator expecting to collect from 4 workers
|
||||
aggregator = KVOutputAggregator(expected_finished_count=4)
|
||||
assert aggregator._expected_finished_count == 4
|
||||
# Some request with default expected finished requests
|
||||
output1 = DummyModelRunnerOutput(finished_sending={"req1"})
|
||||
aggregated = aggregator.aggregate([output1])
|
||||
# still expecting to collect from 4 workers
|
||||
assert aggregator._send_remaining_count["req1"] == 3
|
||||
assert not aggregated.kv_connector_output.finished_sending
|
||||
assert not aggregated.kv_connector_output.finished_recving
|
||||
|
||||
# Workers discover and find that in this setup they only need to
|
||||
# collect from 2
|
||||
output1 = DummyModelRunnerOutput(
|
||||
finished_sending={"req1"}, expected_finished_count=2
|
||||
)
|
||||
output2 = DummyModelRunnerOutput(
|
||||
finished_recving={"req2"}, expected_finished_count=2
|
||||
)
|
||||
output3 = DummyModelRunnerOutput(finished_recving={"req2"})
|
||||
# Req2 only needs 2 acks
|
||||
aggregated = aggregator.aggregate([output1, output2, output3])
|
||||
assert aggregated.kv_connector_output.expected_finished_count == 2
|
||||
|
||||
assert not aggregated.kv_connector_output.finished_sending
|
||||
|
||||
# Req2 is finished
|
||||
assert "req2" not in aggregator._recv_remaining_count
|
||||
assert aggregated.kv_connector_output.finished_recving == {"req2"}
|
||||
|
||||
# Req1 is still waiting for 2 more acks (expected_finished_count has no effect)
|
||||
# NOTE: This is to showcase dynamic update. Workers are responsible for
|
||||
# ensuring "req1" termination in this case
|
||||
assert aggregator._send_remaining_count["req1"] == 2
|
||||
Reference in New Issue
Block a user