[KVConnector] Support worker -> scheduler metadata (#31964)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
Or Ozeri
2026-03-11 19:36:37 +02:00
committed by GitHub
parent 741f4e046b
commit a1a3523a56
6 changed files with 283 additions and 29 deletions

View File

@@ -5,21 +5,27 @@ import shutil
import tempfile
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
import pytest
from tests.v1.kv_connector.unit.utils import create_vllm_config
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiConnector,
MultiKVConnectorStats,
MultiKVConnectorWorkerMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVConnectorStats,
)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import KVConnectorOutput, KVConnectorWorkerMetadata
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@@ -40,7 +46,14 @@ class MockConnectorStats(KVConnectorStats):
class MockConnector(KVConnectorBase_V1):
"""Mock connector that implements build_kv_connector_stats for testing."""
"""Mock connector for testing."""
def __new__(cls, *args, **kwargs):
# mock all KVConnectorBase_V1 functions
mock = MagicMock(spec_set=KVConnectorBase_V1)
# Override just build_kv_connector_stats
mock.build_kv_connector_stats = cls.build_kv_connector_stats
return mock
@classmethod
def build_kv_connector_stats(
@@ -70,16 +83,42 @@ class MockConnector(KVConnectorBase_V1):
pass
class MockCrossLayerConnector(MockConnector):
@property
def prefer_cross_layer_blocks(self) -> bool:
return True
# Register the mock connector
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
@pytest.fixture
def mc() -> MultiConnector:
"""MultiConnector using two mocked connectors"""
vllm_config = create_vllm_config()
mock_connector_config = {
"kv_connector": "MockConnector",
"kv_role": "kv_both",
"kv_connector_module_path": "tests.v1.kv_connector.unit.test_multi_connector",
}
vllm_config.kv_transfer_config = KVTransferConfig(
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [mock_connector_config, mock_connector_config],
},
)
kv_cache_config = KVCacheConfig(
num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[]
)
mc = MultiConnector(
vllm_config=vllm_config,
role=KVConnectorRole.WORKER,
kv_cache_config=kv_cache_config,
)
return mc
# Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool:
"""Compares two directories recursively for identical content."""
@@ -715,24 +754,6 @@ class TestMultiConnectorStats:
assert not stats.is_empty()
class TestMultiConnectorPreferCrossLayerBlocks:
def test_all_connectors_prefer_cross_layer_blocks(self):
mc = MultiConnector.__new__(MultiConnector)
mc._connectors = [
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
]
assert mc.prefer_cross_layer_blocks is True
def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
mc = MultiConnector.__new__(MultiConnector)
mc._connectors = [
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
MockConnector.__new__(MockConnector), # default False
]
assert mc.prefer_cross_layer_blocks is False
def test_multi_connector_overrides_all_base_methods():
"""
Ensure MultiConnector overrides all public methods from KVConnectorBase_V1.
@@ -767,3 +788,133 @@ Options:
1. Add delegation in MultiConnector (preferred)
2. Add to INHERITED_OK if the base implementation works correctly
""")
def test_multi_connector_prefer_cross_layer_blocks(mc):
mc._connectors[0].prefer_cross_layer_blocks = False
mc._connectors[1].prefer_cross_layer_blocks = True
assert mc.prefer_cross_layer_blocks is False
mc._connectors[0].prefer_cross_layer_blocks = True
mc._connectors[1].prefer_cross_layer_blocks = True
assert mc.prefer_cross_layer_blocks is True
def test_multi_connector_worker_metadata(mc):
class MockConnectorWorkerMetadata(KVConnectorWorkerMetadata):
def __init__(self, data: set[str]):
self.data = data
class MockConnectorWorkerMetadata0(MockConnectorWorkerMetadata):
def aggregate(
self, other: KVConnectorWorkerMetadata
) -> KVConnectorWorkerMetadata:
assert isinstance(other, MockConnectorWorkerMetadata)
return MockConnectorWorkerMetadata0(data=self.data | other.data)
class MockConnectorWorkerMetadata1(MockConnectorWorkerMetadata):
def aggregate(
self, other: KVConnectorWorkerMetadata
) -> KVConnectorWorkerMetadata:
assert isinstance(other, MockConnectorWorkerMetadata)
return MockConnectorWorkerMetadata1(data=self.data | other.data)
# -------------------- test build_worker_connector_meta -------------------
# both connectors return None
mc._connectors[0].build_connector_worker_meta.return_value = None
mc._connectors[1].build_connector_worker_meta.return_value = None
assert mc.build_connector_worker_meta() is None
# only first connector returns None
worker_meta1a = MockConnectorWorkerMetadata1({"1a"})
mc._connectors[0].build_connector_worker_meta.return_value = None
mc._connectors[1].build_connector_worker_meta.return_value = worker_meta1a
mc_worker_meta_none_1a = mc.build_connector_worker_meta()
assert isinstance(mc_worker_meta_none_1a, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_none_1a.metadata == (None, worker_meta1a)
# only second connector returns None
worker_meta0a = MockConnectorWorkerMetadata0({"0a"})
mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0a
mc._connectors[1].build_connector_worker_meta.return_value = None
mc_worker_meta_0a_none = mc.build_connector_worker_meta()
assert isinstance(mc_worker_meta_0a_none, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_0a_none.metadata == (worker_meta0a, None)
# both connectors do not return None
worker_meta0b = MockConnectorWorkerMetadata0({"0b"})
worker_meta1b = MockConnectorWorkerMetadata1({"1b"})
mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0b
mc._connectors[1].build_connector_worker_meta.return_value = worker_meta1b
mc_worker_meta_0b_1b = mc.build_connector_worker_meta()
assert isinstance(mc_worker_meta_0b_1b, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_0b_1b.metadata == (worker_meta0b, worker_meta1b)
# ----------------------------- test aggregate ----------------------------
# aggregate ({"0a"}, None) and (None, {"1a"}) -> ({"0a"}, {"1a"})
mc_worker_meta_0a_1a = mc_worker_meta_0a_none.aggregate(mc_worker_meta_none_1a)
assert isinstance(mc_worker_meta_0a_1a, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_0a_1a.metadata == (worker_meta0a, worker_meta1a)
# aggregate ({"0a"}, None) and ({"0b"}, None) -> ({"0a", "0b"}, None)
mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0b
mc._connectors[1].build_connector_worker_meta.return_value = None
mc_worker_meta_0b_none = mc.build_connector_worker_meta()
mc_worker_meta_0a_0b = mc_worker_meta_0a_none.aggregate(mc_worker_meta_0b_none)
assert isinstance(mc_worker_meta_0a_0b, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_0a_0b.metadata[1] is None
connector0_md = mc_worker_meta_0a_0b.metadata[0]
assert isinstance(connector0_md, MockConnectorWorkerMetadata0)
assert connector0_md.data == {"0a", "0b"}
# aggregate ({"0a"}, {"1a"}) and ({"0b"}, {"1b"}) -> ({"0a", "0b"}, {"1a", "1b"})
mc_worker_meta_01a_01b = mc_worker_meta_0a_1a.aggregate(mc_worker_meta_0b_1b)
assert isinstance(mc_worker_meta_01a_01b, MultiKVConnectorWorkerMetadata)
metadata = mc_worker_meta_01a_01b.metadata
assert len(metadata) == 2
connector0_md, connector1_md = metadata
assert isinstance(connector0_md, MockConnectorWorkerMetadata0)
assert isinstance(connector1_md, MockConnectorWorkerMetadata1)
assert connector0_md.data == {"0a", "0b"}
assert connector1_md.data == {"1a", "1b"}
# ---------------------- test update_connector_output ---------------------
def verify_worker_metadata(expected_metadata: MockConnectorWorkerMetadata | None):
def _verify_worker_metadata(connector_output: KVConnectorOutput):
worker_meta = connector_output.kv_connector_worker_meta
if expected_metadata is None:
assert worker_meta is None
return
assert isinstance(worker_meta, MockConnectorWorkerMetadata)
assert type(worker_meta) is type(expected_metadata)
assert expected_metadata.data == worker_meta.data
return _verify_worker_metadata
def assert_update_connector_output_called(mc: MultiConnector):
for c in mc._connectors:
c.update_connector_output.assert_called_once()
c.update_connector_output.reset_mock()
# no worker meta
kv_connector_output = KVConnectorOutput()
mc._connectors[0].update_connector_output.side_effect = verify_worker_metadata(None)
mc._connectors[1].update_connector_output.side_effect = verify_worker_metadata(None)
mc.update_connector_output(kv_connector_output)
assert_update_connector_output_called(mc)
# multi worker meta
kv_connector_output.kv_connector_worker_meta = mc_worker_meta_01a_01b
mc._connectors[0].update_connector_output.side_effect = verify_worker_metadata(
connector0_md
)
mc._connectors[1].update_connector_output.side_effect = verify_worker_metadata(
connector1_md
)
mc.update_connector_output(kv_connector_output)
assert_update_connector_output_called(mc)
assert kv_connector_output.kv_connector_worker_meta == mc_worker_meta_01a_01b

View File

@@ -85,6 +85,7 @@ class KVOutputAggregator:
finished_sending = set[str]()
finished_recving = set[str]()
aggregated_kv_connector_stats = None
aggregated_kv_connector_worker_meta = None
combined_kv_cache_events = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
@@ -127,6 +128,17 @@ class KVOutputAggregator:
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
)
# Aggregate kv_connector_worker_meta from all workers.
if aggregated_kv_connector_worker_meta is None:
# Use the first worker's kv_connector_worker_meta as accumulator.
aggregated_kv_connector_worker_meta = kv_output.kv_connector_worker_meta
elif kv_connector_worker_meta := kv_output.kv_connector_worker_meta:
aggregated_kv_connector_worker_meta = (
aggregated_kv_connector_worker_meta.aggregate(
kv_connector_worker_meta
)
)
# Combine kv_cache_events from all workers.
if combined_kv_cache_events is None:
# Use the first worker's kv_cache events as start event list.
@@ -151,6 +163,7 @@ class KVOutputAggregator:
finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None,
kv_cache_events=combined_kv_cache_events or None,
kv_connector_worker_meta=aggregated_kv_connector_worker_meta or None,
invalid_block_ids=invalid_block_ids,
expected_finished_count=self._expected_finished_count,
)

View File

@@ -36,6 +36,8 @@ The class provides the following primitives:
get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving.
build_connector_worker_meta() - builds metadata to be sent
back to the scheduler-side connector
"""
import enum
@@ -137,13 +139,34 @@ class KVConnectorHandshakeMetadata(ABC): # noqa: B024
class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Scheduler KVConnector and Worker KVConnector.
Abstract Metadata used to communicate
Scheduler KVConnector -> Worker KVConnector.
"""
pass
class KVConnectorWorkerMetadata(ABC):
"""
Abstract Metadata used to communicate back
Worker KVConnector -> Scheduler KVConnector.
Each worker can output its own metadata.
For a single engine step, all metadata objects returned by workers
will be aggregated using the `aggregate` method below, before
being passed to the Scheduler KVConnector.
"""
@abstractmethod
def aggregate(
self, other: "KVConnectorWorkerMetadata"
) -> "KVConnectorWorkerMetadata":
"""
Aggregate metadata with another `KVConnectorWorkerMetadata` object.
"""
pass
class KVConnectorBase_V1(ABC):
"""
Base class for KV connectors.
@@ -409,6 +432,16 @@ class KVConnectorBase_V1(ABC):
"""
return None
def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None:
"""
Build the KVConnector worker metadata for this engine step.
Returns:
KVConnectorWorkerMetadata: the worker metadata.
None if no worker metadata is available.
"""
return None
# ==============================
# Scheduler-side methods
# ==============================

View File

@@ -17,6 +17,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata,
KVConnectorMetadata,
KVConnectorRole,
KVConnectorWorkerMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
@@ -45,6 +46,26 @@ class MultiKVConnectorMetadata(KVConnectorMetadata):
extra_async_saves: dict[str, int] | None = None
@dataclass
class MultiKVConnectorWorkerMetadata(KVConnectorWorkerMetadata):
metadata: tuple[KVConnectorWorkerMetadata | None, ...]
def aggregate(self, other: KVConnectorWorkerMetadata) -> KVConnectorWorkerMetadata:
assert isinstance(other, MultiKVConnectorWorkerMetadata)
assert len(self.metadata) == len(other.metadata)
metadata_list = []
for metadata1, metadata2 in zip(self.metadata, other.metadata):
if metadata1 is None:
metadata_list.append(metadata2)
elif metadata2 is None:
metadata_list.append(metadata1)
else:
metadata_list.append(metadata1.aggregate(metadata2))
return MultiKVConnectorWorkerMetadata(metadata=tuple(metadata_list))
@dataclass
class MultiKVConnectorStats(KVConnectorStats):
"""
@@ -304,6 +325,18 @@ class MultiConnector(KVConnectorBase_V1):
# Currently no connectors return non-None
return None
def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None:
metadata_list: list[KVConnectorWorkerMetadata | None] | None = None
for i, c in enumerate(self._connectors):
kv_connector_worker_meta = c.build_connector_worker_meta()
if metadata_list is None and kv_connector_worker_meta is not None:
metadata_list = [None] * i
if metadata_list is not None:
metadata_list.append(kv_connector_worker_meta)
if metadata_list is None:
return None
return MultiKVConnectorWorkerMetadata(metadata=tuple(metadata_list))
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events'
# method for the MultiConnector. It should be able to get events from
# multiple connectors, handling the case where only a subset of the
@@ -361,8 +394,25 @@ class MultiConnector(KVConnectorBase_V1):
return metadata
def update_connector_output(self, connector_output: KVConnectorOutput):
for c in self._connectors:
multi_connector_worker_meta: MultiKVConnectorWorkerMetadata | None = None
if connector_output.kv_connector_worker_meta is not None:
assert isinstance(
connector_output.kv_connector_worker_meta,
MultiKVConnectorWorkerMetadata,
)
multi_connector_worker_meta = connector_output.kv_connector_worker_meta
try:
for i, c in enumerate(self._connectors):
if multi_connector_worker_meta is not None:
# set the connector-specific worker metadata
connector_output.kv_connector_worker_meta = (
multi_connector_worker_meta.metadata[i]
)
c.update_connector_output(connector_output)
finally:
# restore kv_connector_worker_meta
connector_output.kv_connector_worker_meta = multi_connector_worker_meta
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""

View File

@@ -14,9 +14,13 @@ from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.distributed.kv_events import KVConnectorKVEvents
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorWorkerMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else:
KVConnectorStats = object
KVConnectorWorkerMetadata = object
KVConnectorKVEvents = object
@@ -142,6 +146,7 @@ class KVConnectorOutput:
finished_recving: set[str] | None = None
kv_connector_stats: KVConnectorStats | None = None
kv_cache_events: KVConnectorKVEvents | None = None
kv_connector_worker_meta: KVConnectorWorkerMetadata | None = None
# IDs of externally computed KV blocks that failed to load.
# Requests referencing these blocks should be rescheduled to recompute them
invalid_block_ids: set[int] = field(default_factory=set)
@@ -159,6 +164,7 @@ class KVConnectorOutput:
and not self.kv_connector_stats
and not self.kv_cache_events
and not self.invalid_block_ids
and not self.kv_connector_worker_meta
)
@classmethod

View File

@@ -123,6 +123,7 @@ class KVConnectorModelRunnerMixin:
output.kv_connector_stats = kv_connector.get_kv_connector_stats()
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
output.kv_connector_worker_meta = kv_connector.build_connector_worker_meta()
if not defer_finalize:
kv_connector.clear_connector_metadata()