From a1a3523a5647a58e00096ca7430e9f1ad4a50a97 Mon Sep 17 00:00:00 2001 From: Or Ozeri Date: Wed, 11 Mar 2026 19:36:37 +0200 Subject: [PATCH] [KVConnector] Support worker -> scheduler metadata (#31964) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Or Ozeri Co-authored-by: Nicolò Lucchesi --- .../kv_connector/unit/test_multi_connector.py | 201 +++++++++++++++--- .../kv_transfer/kv_connector/utils.py | 13 ++ .../kv_transfer/kv_connector/v1/base.py | 37 +++- .../kv_connector/v1/multi_connector.py | 54 ++++- vllm/v1/outputs.py | 6 + .../worker/kv_connector_model_runner_mixin.py | 1 + 6 files changed, 283 insertions(+), 29 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 0541dcaa5..6acc48629 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -5,21 +5,27 @@ import shutil import tempfile from pathlib import Path from typing import Any +from unittest.mock import MagicMock import pytest +from tests.v1.kv_connector.unit.utils import create_vllm_config from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1 from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( MultiConnector, MultiKVConnectorStats, + MultiKVConnectorWorkerMetadata, ) from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( NixlKVConnectorStats, ) +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.outputs import KVConnectorOutput, KVConnectorWorkerMetadata MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @@ 
-40,7 +46,14 @@ class MockConnectorStats(KVConnectorStats): class MockConnector(KVConnectorBase_V1): - """Mock connector that implements build_kv_connector_stats for testing.""" + """Mock connector for testing.""" + + def __new__(cls, *args, **kwargs): + # mock all KVConnectorBase_V1 functions + mock = MagicMock(spec_set=KVConnectorBase_V1) + # Override just build_kv_connector_stats + mock.build_kv_connector_stats = cls.build_kv_connector_stats + return mock @classmethod def build_kv_connector_stats( @@ -70,16 +83,42 @@ class MockConnector(KVConnectorBase_V1): pass -class MockCrossLayerConnector(MockConnector): - @property - def prefer_cross_layer_blocks(self) -> bool: - return True - - # Register the mock connector KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__) +@pytest.fixture +def mc() -> MultiConnector: + """MultiConnector using two mocked connectors""" + vllm_config = create_vllm_config() + + mock_connector_config = { + "kv_connector": "MockConnector", + "kv_role": "kv_both", + "kv_connector_module_path": "tests.v1.kv_connector.unit.test_multi_connector", + } + + vllm_config.kv_transfer_config = KVTransferConfig( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [mock_connector_config, mock_connector_config], + }, + ) + + kv_cache_config = KVCacheConfig( + num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[] + ) + + mc = MultiConnector( + vllm_config=vllm_config, + role=KVConnectorRole.WORKER, + kv_cache_config=kv_cache_config, + ) + + return mc + + # Helper function to compare directories recursively def _compare_directories(dir1: Path, dir2: Path) -> bool: """Compares two directories recursively for identical content.""" @@ -715,24 +754,6 @@ class TestMultiConnectorStats: assert not stats.is_empty() -class TestMultiConnectorPreferCrossLayerBlocks: - def test_all_connectors_prefer_cross_layer_blocks(self): - mc = MultiConnector.__new__(MultiConnector) - mc._connectors 
= [ - MockCrossLayerConnector.__new__(MockCrossLayerConnector), - MockCrossLayerConnector.__new__(MockCrossLayerConnector), - ] - assert mc.prefer_cross_layer_blocks is True - - def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self): - mc = MultiConnector.__new__(MultiConnector) - mc._connectors = [ - MockCrossLayerConnector.__new__(MockCrossLayerConnector), - MockConnector.__new__(MockConnector), # default False - ] - assert mc.prefer_cross_layer_blocks is False - - def test_multi_connector_overrides_all_base_methods(): """ Ensure MultiConnector overrides all public methods from KVConnectorBase_V1. @@ -767,3 +788,133 @@ Options: 1. Add delegation in MultiConnector (preferred) 2. Add to INHERITED_OK if the base implementation works correctly """) + + +def test_multi_connector_prefer_cross_layer_blocks(mc): + mc._connectors[0].prefer_cross_layer_blocks = False + mc._connectors[1].prefer_cross_layer_blocks = True + assert mc.prefer_cross_layer_blocks is False + + mc._connectors[0].prefer_cross_layer_blocks = True + mc._connectors[1].prefer_cross_layer_blocks = True + assert mc.prefer_cross_layer_blocks is True + + +def test_multi_connector_worker_metadata(mc): + class MockConnectorWorkerMetadata(KVConnectorWorkerMetadata): + def __init__(self, data: set[str]): + self.data = data + + class MockConnectorWorkerMetadata0(MockConnectorWorkerMetadata): + def aggregate( + self, other: KVConnectorWorkerMetadata + ) -> KVConnectorWorkerMetadata: + assert isinstance(other, MockConnectorWorkerMetadata) + return MockConnectorWorkerMetadata0(data=self.data | other.data) + + class MockConnectorWorkerMetadata1(MockConnectorWorkerMetadata): + def aggregate( + self, other: KVConnectorWorkerMetadata + ) -> KVConnectorWorkerMetadata: + assert isinstance(other, MockConnectorWorkerMetadata) + return MockConnectorWorkerMetadata1(data=self.data | other.data) + + # -------------------- test build_connector_worker_meta ------------------- + + # both connectors return None + 
mc._connectors[0].build_connector_worker_meta.return_value = None + mc._connectors[1].build_connector_worker_meta.return_value = None + assert mc.build_connector_worker_meta() is None + + # only first connector returns None + worker_meta1a = MockConnectorWorkerMetadata1({"1a"}) + mc._connectors[0].build_connector_worker_meta.return_value = None + mc._connectors[1].build_connector_worker_meta.return_value = worker_meta1a + mc_worker_meta_none_1a = mc.build_connector_worker_meta() + assert isinstance(mc_worker_meta_none_1a, MultiKVConnectorWorkerMetadata) + assert mc_worker_meta_none_1a.metadata == (None, worker_meta1a) + + # only second connector returns None + worker_meta0a = MockConnectorWorkerMetadata0({"0a"}) + mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0a + mc._connectors[1].build_connector_worker_meta.return_value = None + mc_worker_meta_0a_none = mc.build_connector_worker_meta() + assert isinstance(mc_worker_meta_0a_none, MultiKVConnectorWorkerMetadata) + assert mc_worker_meta_0a_none.metadata == (worker_meta0a, None) + + # both connectors do not return None + worker_meta0b = MockConnectorWorkerMetadata0({"0b"}) + worker_meta1b = MockConnectorWorkerMetadata1({"1b"}) + mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0b + mc._connectors[1].build_connector_worker_meta.return_value = worker_meta1b + mc_worker_meta_0b_1b = mc.build_connector_worker_meta() + assert isinstance(mc_worker_meta_0b_1b, MultiKVConnectorWorkerMetadata) + assert mc_worker_meta_0b_1b.metadata == (worker_meta0b, worker_meta1b) + + # ----------------------------- test aggregate ---------------------------- + + # aggregate ({"0a"}, None) and (None, {"1a"}) -> ({"0a"}, {"1a"}) + mc_worker_meta_0a_1a = mc_worker_meta_0a_none.aggregate(mc_worker_meta_none_1a) + assert isinstance(mc_worker_meta_0a_1a, MultiKVConnectorWorkerMetadata) + assert mc_worker_meta_0a_1a.metadata == (worker_meta0a, worker_meta1a) + + # aggregate ({"0a"}, None) and 
({"0b"}, None) -> ({"0a", "0b"}, None) + mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0b + mc._connectors[1].build_connector_worker_meta.return_value = None + mc_worker_meta_0b_none = mc.build_connector_worker_meta() + mc_worker_meta_0a_0b = mc_worker_meta_0a_none.aggregate(mc_worker_meta_0b_none) + assert isinstance(mc_worker_meta_0a_0b, MultiKVConnectorWorkerMetadata) + assert mc_worker_meta_0a_0b.metadata[1] is None + connector0_md = mc_worker_meta_0a_0b.metadata[0] + assert isinstance(connector0_md, MockConnectorWorkerMetadata0) + assert connector0_md.data == {"0a", "0b"} + + # aggregate ({"0a"}, {"1a"}) and ({"0b"}, {"1b"}) -> ({"0a", "0b"}, {"1a", "1b"}) + mc_worker_meta_01a_01b = mc_worker_meta_0a_1a.aggregate(mc_worker_meta_0b_1b) + assert isinstance(mc_worker_meta_01a_01b, MultiKVConnectorWorkerMetadata) + metadata = mc_worker_meta_01a_01b.metadata + assert len(metadata) == 2 + connector0_md, connector1_md = metadata + assert isinstance(connector0_md, MockConnectorWorkerMetadata0) + assert isinstance(connector1_md, MockConnectorWorkerMetadata1) + assert connector0_md.data == {"0a", "0b"} + assert connector1_md.data == {"1a", "1b"} + + # ---------------------- test update_connector_output --------------------- + + def verify_worker_metadata(expected_metadata: MockConnectorWorkerMetadata | None): + def _verify_worker_metadata(connector_output: KVConnectorOutput): + worker_meta = connector_output.kv_connector_worker_meta + if expected_metadata is None: + assert worker_meta is None + return + + assert isinstance(worker_meta, MockConnectorWorkerMetadata) + assert type(worker_meta) is type(expected_metadata) + assert expected_metadata.data == worker_meta.data + + return _verify_worker_metadata + + def assert_update_connector_output_called(mc: MultiConnector): + for c in mc._connectors: + c.update_connector_output.assert_called_once() + c.update_connector_output.reset_mock() + + # no worker meta + kv_connector_output = 
KVConnectorOutput() + mc._connectors[0].update_connector_output.side_effect = verify_worker_metadata(None) + mc._connectors[1].update_connector_output.side_effect = verify_worker_metadata(None) + mc.update_connector_output(kv_connector_output) + assert_update_connector_output_called(mc) + + # multi worker meta + kv_connector_output.kv_connector_worker_meta = mc_worker_meta_01a_01b + mc._connectors[0].update_connector_output.side_effect = verify_worker_metadata( + connector0_md + ) + mc._connectors[1].update_connector_output.side_effect = verify_worker_metadata( + connector1_md + ) + mc.update_connector_output(kv_connector_output) + assert_update_connector_output_called(mc) + assert kv_connector_output.kv_connector_worker_meta == mc_worker_meta_01a_01b diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 51487e516..155395e84 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -85,6 +85,7 @@ class KVOutputAggregator: finished_sending = set[str]() finished_recving = set[str]() aggregated_kv_connector_stats = None + aggregated_kv_connector_worker_meta = None combined_kv_cache_events = None invalid_block_ids = set[int]() for model_runner_output in outputs: @@ -127,6 +128,17 @@ class KVOutputAggregator: aggregated_kv_connector_stats.aggregate(kv_connector_stats) ) + # Aggregate kv_connector_worker_meta from all workers. + if aggregated_kv_connector_worker_meta is None: + # Use the first worker's kv_connector_worker_meta as accumulator. + aggregated_kv_connector_worker_meta = kv_output.kv_connector_worker_meta + elif kv_connector_worker_meta := kv_output.kv_connector_worker_meta: + aggregated_kv_connector_worker_meta = ( + aggregated_kv_connector_worker_meta.aggregate( + kv_connector_worker_meta + ) + ) + # Combine kv_cache_events from all workers. 
if combined_kv_cache_events is None: # Use the first worker's kv_cache events as start event list. @@ -151,6 +163,7 @@ class KVOutputAggregator: finished_recving=finished_recving or None, kv_connector_stats=aggregated_kv_connector_stats or None, kv_cache_events=combined_kv_cache_events or None, + kv_connector_worker_meta=aggregated_kv_connector_worker_meta or None, invalid_block_ids=invalid_block_ids, expected_finished_count=self._expected_finished_count, ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 3d9027adf..2abbe6bf6 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -36,6 +36,8 @@ The class provides the following primitives: get_finished() - called with ids of finished requests, returns ids of requests that have completed async sending/recving. + build_connector_worker_meta() - builds metadata to be sent + back to the scheduler-side connector """ import enum @@ -137,13 +139,34 @@ class KVConnectorHandshakeMetadata(ABC): # noqa: B024 class KVConnectorMetadata(ABC): # noqa: B024 """ - Abstract Metadata used to communicate between the - Scheduler KVConnector and Worker KVConnector. + Abstract Metadata used to communicate + Scheduler KVConnector -> Worker KVConnector. """ pass +class KVConnectorWorkerMetadata(ABC): + """ + Abstract Metadata used to communicate back + Worker KVConnector -> Scheduler KVConnector. + + Each worker can output its own metadata. + For a single engine step, all metadata objects returned by workers + will be aggregated using the `aggregate` method below, before + being passed to the Scheduler KVConnector. + """ + + @abstractmethod + def aggregate( + self, other: "KVConnectorWorkerMetadata" + ) -> "KVConnectorWorkerMetadata": + """ + Aggregate metadata with another `KVConnectorWorkerMetadata` object. + """ + pass + + class KVConnectorBase_V1(ABC): """ Base class for KV connectors. 
@@ -409,6 +432,16 @@ class KVConnectorBase_V1(ABC): """ return None + def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None: + """ + Build the KVConnector worker metadata for this engine step. + + Returns: + KVConnectorWorkerMetadata: the worker metadata. + None if no worker metadata is available. + """ + return None + # ============================== # Scheduler-side methods # ============================== diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 7052886cd..7cc80129a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -17,6 +17,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorHandshakeMetadata, KVConnectorMetadata, KVConnectorRole, + KVConnectorWorkerMetadata, ) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorPromMetrics, @@ -45,6 +46,26 @@ class MultiKVConnectorMetadata(KVConnectorMetadata): extra_async_saves: dict[str, int] | None = None +@dataclass +class MultiKVConnectorWorkerMetadata(KVConnectorWorkerMetadata): + metadata: tuple[KVConnectorWorkerMetadata | None, ...] 
+ + def aggregate(self, other: KVConnectorWorkerMetadata) -> KVConnectorWorkerMetadata: + assert isinstance(other, MultiKVConnectorWorkerMetadata) + + assert len(self.metadata) == len(other.metadata) + metadata_list = [] + for metadata1, metadata2 in zip(self.metadata, other.metadata): + if metadata1 is None: + metadata_list.append(metadata2) + elif metadata2 is None: + metadata_list.append(metadata1) + else: + metadata_list.append(metadata1.aggregate(metadata2)) + + return MultiKVConnectorWorkerMetadata(metadata=tuple(metadata_list)) + + @dataclass class MultiKVConnectorStats(KVConnectorStats): """ @@ -304,6 +325,18 @@ class MultiConnector(KVConnectorBase_V1): # Currently no connectors return non-None return None + def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None: + metadata_list: list[KVConnectorWorkerMetadata | None] | None = None + for i, c in enumerate(self._connectors): + kv_connector_worker_meta = c.build_connector_worker_meta() + if metadata_list is None and kv_connector_worker_meta is not None: + metadata_list = [None] * i + if metadata_list is not None: + metadata_list.append(kv_connector_worker_meta) + if metadata_list is None: + return None + return MultiKVConnectorWorkerMetadata(metadata=tuple(metadata_list)) + # TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' # method for the MultiConnector. 
It should be able to get events from # multiple connectors, handling the case where only a subset of the @@ -361,8 +394,25 @@ class MultiConnector(KVConnectorBase_V1): return metadata def update_connector_output(self, connector_output: KVConnectorOutput): - for c in self._connectors: - c.update_connector_output(connector_output) + multi_connector_worker_meta: MultiKVConnectorWorkerMetadata | None = None + if connector_output.kv_connector_worker_meta is not None: + assert isinstance( + connector_output.kv_connector_worker_meta, + MultiKVConnectorWorkerMetadata, + ) + multi_connector_worker_meta = connector_output.kv_connector_worker_meta + + try: + for i, c in enumerate(self._connectors): + if multi_connector_worker_meta is not None: + # set the connector-specific worker metadata + connector_output.kv_connector_worker_meta = ( + multi_connector_worker_meta.metadata[i] + ) + c.update_connector_output(connector_output) + finally: + # restore kv_connector_worker_meta + connector_output.kv_connector_worker_meta = multi_connector_worker_meta def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None: """ diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 22b06f0e2..8eb58de4f 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -14,9 +14,13 @@ from vllm.v1.core.sched.output import SchedulerOutput if TYPE_CHECKING: from vllm.distributed.kv_events import KVConnectorKVEvents + from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorWorkerMetadata, + ) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats else: KVConnectorStats = object + KVConnectorWorkerMetadata = object KVConnectorKVEvents = object @@ -142,6 +146,7 @@ class KVConnectorOutput: finished_recving: set[str] | None = None kv_connector_stats: KVConnectorStats | None = None kv_cache_events: KVConnectorKVEvents | None = None + kv_connector_worker_meta: KVConnectorWorkerMetadata | None = None # IDs of externally computed KV blocks that 
failed to load. # Requests referencing these blocks should be rescheduled to recompute them invalid_block_ids: set[int] = field(default_factory=set) @@ -159,6 +164,7 @@ class KVConnectorOutput: and not self.kv_connector_stats and not self.kv_cache_events and not self.invalid_block_ids + and not self.kv_connector_worker_meta ) @classmethod diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index 338c54c13..2921594a3 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -123,6 +123,7 @@ class KVConnectorModelRunnerMixin: output.kv_connector_stats = kv_connector.get_kv_connector_stats() output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events() + output.kv_connector_worker_meta = kv_connector.build_connector_worker_meta() if not defer_finalize: kv_connector.clear_connector_metadata()