diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index bb77c4f2b..e073321c6 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -703,7 +703,7 @@ def test_kv_connector_stats_aggregation(): # Create KVOutputAggregator for 3 workers (simulating TP=3), same thing # done in MultiprocExecutor.execute_model - aggregator = KVOutputAggregator(world_size=3) + aggregator = KVOutputAggregator(expected_finished_count=3) # Create stats for multiple workers with different transfer patterns worker1_stats = NixlKVConnectorStats() @@ -768,7 +768,7 @@ def test_multi_kv_connector_stats_aggregation(): KVOutputAggregator (used by MultiprocExecutor). """ - aggregator = KVOutputAggregator(world_size=3) + aggregator = KVOutputAggregator(expected_finished_count=3) from dataclasses import dataclass diff --git a/tests/v1/kv_connector/unit/test_output_aggreagator.py b/tests/v1/kv_connector/unit/test_output_aggregator.py similarity index 73% rename from tests/v1/kv_connector/unit/test_output_aggreagator.py rename to tests/v1/kv_connector/unit/test_output_aggregator.py index 2635b256b..4dba203eb 100644 --- a/tests/v1/kv_connector/unit/test_output_aggreagator.py +++ b/tests/v1/kv_connector/unit/test_output_aggregator.py @@ -16,11 +16,13 @@ class DummyModelRunnerOutput(ModelRunnerOutput): finished_sending: set[str] | None = None, finished_recving: set[str] | None = None, invalid_block_ids: set[int] | None = None, + expected_finished_count: int = 0, ): self.kv_connector_output = KVConnectorOutput( finished_sending=finished_sending, finished_recving=finished_recving, invalid_block_ids=invalid_block_ids or set(), + expected_finished_count=expected_finished_count, ) def __repr__(self): @@ -33,7 +35,7 @@ class DummyModelRunnerOutput(ModelRunnerOutput): def test_aggregate_workers_output(): - aggregator = KVOutputAggregator(world_size=2) + aggregator = KVOutputAggregator(expected_finished_count=2) output1 = DummyModelRunnerOutput() output2 = DummyModelRunnerOutput() @@ -85,7 +87,7 @@ def test_aggregate_workers_output(): def test_async_aggregate_workers_output(): - aggregator = KVOutputAggregator(world_size=2) + aggregator = KVOutputAggregator(expected_finished_count=2) future1: Future[DummyModelRunnerOutput] = Future() future2: Future[DummyModelRunnerOutput] = Future() @@ -158,3 +160,40 @@ def test_async_aggregate_workers_output(): assert aggregated.finished_sending is None assert aggregated.finished_recving == {"req2"} assert aggregated.invalid_block_ids == {3, 4, 5} + + +def test_aggregate_workers_output_with_expected_finished_count(): + # We create the aggregator expecting to collect from 4 workers + aggregator = KVOutputAggregator(expected_finished_count=4) + assert aggregator._expected_finished_count == 4 + # Some request with default expected finished requests + output1 = DummyModelRunnerOutput(finished_sending={"req1"}) + aggregated = aggregator.aggregate([output1]) + # still expecting to collect from 4 workers + assert aggregator._send_remaining_count["req1"] == 3 + assert not aggregated.kv_connector_output.finished_sending + assert not aggregated.kv_connector_output.finished_recving + + # Workers discover and find that in this setup they only need to + # collect from 2 + output1 = DummyModelRunnerOutput( + finished_sending={"req1"}, expected_finished_count=2 + ) + output2 = DummyModelRunnerOutput( + finished_recving={"req2"}, expected_finished_count=2 + ) + output3 = DummyModelRunnerOutput(finished_recving={"req2"}) + # Req2 only needs 2 acks + aggregated = aggregator.aggregate([output1, output2, output3]) + assert aggregated.kv_connector_output.expected_finished_count == 2 + + assert not aggregated.kv_connector_output.finished_sending + + # Req2 is finished + assert "req2" not in aggregator._recv_remaining_count + assert aggregated.kv_connector_output.finished_recving == {"req2"} + + # Req1 is still waiting for 2 more acks (expected_finished_count has no effect) + # NOTE: This is to showcase dynamic update. Workers are responsible for + # ensuring "req1" termination in this case + assert aggregator._send_remaining_count["req1"] == 2 diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 0fe678b9c..22af489a8 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -4,10 +4,9 @@ KV cache helper for store. """ -from collections import defaultdict from collections.abc import Sequence from concurrent.futures import CancelledError, Future -from typing import Literal, cast +from typing import TYPE_CHECKING, Literal, cast import torch @@ -18,6 +17,9 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.logger import init_logger from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput +if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase + logger = init_logger(__name__) @@ -124,11 +126,16 @@ class KVOutputAggregator: """Utility class to aggregate the output of all workers into a single output corresponding to Rank 0 for scheduler.""" - def __init__(self, world_size: int): + def __init__(self, expected_finished_count: int): # Complete transfer tracker. Used to track finished requests # [req_id -> n_remaining_workers] - self._recv_remaining_count = defaultdict[str, int](lambda: world_size) - self._send_remaining_count = defaultdict[str, int](lambda: world_size) + self._recv_remaining_count = dict[str, int]() + self._send_remaining_count = dict[str, int]() + self._expected_finished_count = expected_finished_count + + @classmethod + def from_connector(cls, connector: "KVConnectorBase", world_size: int): + return cls(connector.get_finished_count() or world_size) def aggregate( self, outputs: list[ModelRunnerOutput], output_rank: int = 0 @@ -141,7 +148,10 @@ class KVOutputAggregator: finished_set: set[str], ) -> None: for req_id in req_ids or (): - remaining_count_dict[req_id] -= 1 + remaining_count = remaining_count_dict.get( + req_id, self._expected_finished_count + ) + remaining_count_dict[req_id] = remaining_count - 1 if remaining_count_dict[req_id] == 0: finished_set.add(req_id) del remaining_count_dict[req_id] @@ -154,6 +164,19 @@ class KVOutputAggregator: kv_output = model_runner_output.kv_connector_output if not kv_output: continue + # Allow the worker to dynamically update the expected number of + # finished sending/recving for new requests. + if ( + kv_output.expected_finished_count > 0 + and kv_output.expected_finished_count != self._expected_finished_count + ): + logger.debug( + "Expected finished requests updated from %d to %d", + self._expected_finished_count, + kv_output.expected_finished_count, + ) + self._expected_finished_count = kv_output.expected_finished_count + update_finished_set( kv_output.finished_sending, self._send_remaining_count, finished_sending ) @@ -186,6 +209,7 @@ class KVOutputAggregator: finished_recving=finished_recving or None, kv_connector_stats=aggregated_kv_connector_stats or None, invalid_block_ids=invalid_block_ids, + expected_finished_count=self._expected_finished_count, ) return output diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index ab5d2ecdc..989e2f664 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -413,7 +413,8 @@ class KVConnectorBase_V1(ABC): def get_finished_count(self) -> int | None: """ Get the count of requests expected to complete send/receive operations - via this connector. + via this connector. This method is used to initialize the + KVOutputAggregator, overwriting the default world_size. Returns: int: expected sending or receiving completion count. diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 00d3821bc..27cf2fbe8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -160,9 +160,7 @@ class EngineCore: ) self.use_spec_decode = vllm_config.speculative_config is not None if self.scheduler.connector is not None: # type: ignore - self.model_executor.init_kv_output_aggregator( - self.scheduler.connector.get_finished_count() # type: ignore - ) + self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore self.mm_registry = mm_registry = MULTIMODAL_REGISTRY self.mm_receiver_cache = engine_receiver_cache_from_config( diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 609a681dc..9fe1912c7 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from concurrent.futures import Future from functools import cached_property -from typing import Literal, TypeVar, overload +from typing import TYPE_CHECKING, Literal, TypeVar, overload from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator @@ -19,6 +19,9 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerBase +if TYPE_CHECKING: + from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase + logger = init_logger(__name__) _R = TypeVar("_R") @@ -233,10 +236,10 @@ class Executor(ABC): """Shutdown the executor.""" self.collective_rpc("shutdown") - def init_kv_output_aggregator(self, finished_count: int | None) -> None: + def init_kv_output_aggregator(self, connector: "KVConnectorBase") -> None: """Init KVOutputAggregator""" - self.kv_output_aggregator = KVOutputAggregator( - finished_count or self.parallel_config.world_size + self.kv_output_aggregator = KVOutputAggregator.from_connector( + connector, self.parallel_config.world_size ) @cached_property # Avoid unnecessary RPC calls diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index c224555da..7eef0ca0a 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -86,8 +86,14 @@ class KVConnectorOutput: finished_recving: set[str] | None = None kv_connector_stats: KVConnectorStats | None = None # IDs of externally computed KV blocks that failed to load. - # Requests referencing these blocks should be rescheduled to recompute them. + # Requests referencing these blocks should be rescheduled to recompute them invalid_block_ids: set[int] = field(default_factory=set) + # Configuration describing how many finished sending/receiving + # notifications should be expected for each request. This allows + # handshake-based connectors like Nixl to update the KVOutputAggregator. + # It captures a static setup info and should almost always remain constant + # for a given connector after discovery. Default value entails no change. + expected_finished_count: int = 0 def is_empty(self): return (