[KVConnector] Support worker -> scheduler metadata (#31964)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
Or Ozeri
2026-03-11 19:36:37 +02:00
committed by GitHub
parent 741f4e046b
commit a1a3523a56
6 changed files with 283 additions and 29 deletions

View File

@@ -5,21 +5,27 @@ import shutil
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from unittest.mock import MagicMock
import pytest import pytest
from tests.v1.kv_connector.unit.utils import create_vllm_config
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1 from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiConnector, MultiConnector,
MultiKVConnectorStats, MultiKVConnectorStats,
MultiKVConnectorWorkerMetadata,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVConnectorStats, NixlKVConnectorStats,
) )
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import KVConnectorOutput, KVConnectorWorkerMetadata
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
@@ -40,7 +46,14 @@ class MockConnectorStats(KVConnectorStats):
class MockConnector(KVConnectorBase_V1): class MockConnector(KVConnectorBase_V1):
"""Mock connector that implements build_kv_connector_stats for testing.""" """Mock connector for testing."""
def __new__(cls, *args, **kwargs):
# mock all KVConnectorBase_V1 functions
mock = MagicMock(spec_set=KVConnectorBase_V1)
# Override just build_kv_connector_stats
mock.build_kv_connector_stats = cls.build_kv_connector_stats
return mock
@classmethod @classmethod
def build_kv_connector_stats( def build_kv_connector_stats(
@@ -70,16 +83,42 @@ class MockConnector(KVConnectorBase_V1):
pass pass
class MockCrossLayerConnector(MockConnector):
@property
def prefer_cross_layer_blocks(self) -> bool:
return True
# Register the mock connector # Register the mock connector
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__) KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
@pytest.fixture
def mc() -> MultiConnector:
"""MultiConnector using two mocked connectors"""
vllm_config = create_vllm_config()
mock_connector_config = {
"kv_connector": "MockConnector",
"kv_role": "kv_both",
"kv_connector_module_path": "tests.v1.kv_connector.unit.test_multi_connector",
}
vllm_config.kv_transfer_config = KVTransferConfig(
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [mock_connector_config, mock_connector_config],
},
)
kv_cache_config = KVCacheConfig(
num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[]
)
mc = MultiConnector(
vllm_config=vllm_config,
role=KVConnectorRole.WORKER,
kv_cache_config=kv_cache_config,
)
return mc
# Helper function to compare directories recursively # Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool: def _compare_directories(dir1: Path, dir2: Path) -> bool:
"""Compares two directories recursively for identical content.""" """Compares two directories recursively for identical content."""
@@ -715,24 +754,6 @@ class TestMultiConnectorStats:
assert not stats.is_empty() assert not stats.is_empty()
class TestMultiConnectorPreferCrossLayerBlocks:
def test_all_connectors_prefer_cross_layer_blocks(self):
mc = MultiConnector.__new__(MultiConnector)
mc._connectors = [
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
]
assert mc.prefer_cross_layer_blocks is True
def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
mc = MultiConnector.__new__(MultiConnector)
mc._connectors = [
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
MockConnector.__new__(MockConnector), # default False
]
assert mc.prefer_cross_layer_blocks is False
def test_multi_connector_overrides_all_base_methods(): def test_multi_connector_overrides_all_base_methods():
""" """
Ensure MultiConnector overrides all public methods from KVConnectorBase_V1. Ensure MultiConnector overrides all public methods from KVConnectorBase_V1.
@@ -767,3 +788,133 @@ Options:
1. Add delegation in MultiConnector (preferred) 1. Add delegation in MultiConnector (preferred)
2. Add to INHERITED_OK if the base implementation works correctly 2. Add to INHERITED_OK if the base implementation works correctly
""") """)
def test_multi_connector_prefer_cross_layer_blocks(mc):
mc._connectors[0].prefer_cross_layer_blocks = False
mc._connectors[1].prefer_cross_layer_blocks = True
assert mc.prefer_cross_layer_blocks is False
mc._connectors[0].prefer_cross_layer_blocks = True
mc._connectors[1].prefer_cross_layer_blocks = True
assert mc.prefer_cross_layer_blocks is True
def test_multi_connector_worker_metadata(mc):
class MockConnectorWorkerMetadata(KVConnectorWorkerMetadata):
def __init__(self, data: set[str]):
self.data = data
class MockConnectorWorkerMetadata0(MockConnectorWorkerMetadata):
def aggregate(
self, other: KVConnectorWorkerMetadata
) -> KVConnectorWorkerMetadata:
assert isinstance(other, MockConnectorWorkerMetadata)
return MockConnectorWorkerMetadata0(data=self.data | other.data)
class MockConnectorWorkerMetadata1(MockConnectorWorkerMetadata):
def aggregate(
self, other: KVConnectorWorkerMetadata
) -> KVConnectorWorkerMetadata:
assert isinstance(other, MockConnectorWorkerMetadata)
return MockConnectorWorkerMetadata1(data=self.data | other.data)
# -------------------- test build_worker_connector_meta -------------------
# both connectors return None
mc._connectors[0].build_connector_worker_meta.return_value = None
mc._connectors[1].build_connector_worker_meta.return_value = None
assert mc.build_connector_worker_meta() is None
# only first connector returns None
worker_meta1a = MockConnectorWorkerMetadata1({"1a"})
mc._connectors[0].build_connector_worker_meta.return_value = None
mc._connectors[1].build_connector_worker_meta.return_value = worker_meta1a
mc_worker_meta_none_1a = mc.build_connector_worker_meta()
assert isinstance(mc_worker_meta_none_1a, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_none_1a.metadata == (None, worker_meta1a)
# only second connector returns None
worker_meta0a = MockConnectorWorkerMetadata0({"0a"})
mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0a
mc._connectors[1].build_connector_worker_meta.return_value = None
mc_worker_meta_0a_none = mc.build_connector_worker_meta()
assert isinstance(mc_worker_meta_0a_none, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_0a_none.metadata == (worker_meta0a, None)
# both connectors do not return None
worker_meta0b = MockConnectorWorkerMetadata0({"0b"})
worker_meta1b = MockConnectorWorkerMetadata1({"1b"})
mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0b
mc._connectors[1].build_connector_worker_meta.return_value = worker_meta1b
mc_worker_meta_0b_1b = mc.build_connector_worker_meta()
assert isinstance(mc_worker_meta_0b_1b, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_0b_1b.metadata == (worker_meta0b, worker_meta1b)
# ----------------------------- test aggregate ----------------------------
# aggregate ({"0a"}, None) and (None, {"1a"}) -> ({"0a"}, {"1a"})
mc_worker_meta_0a_1a = mc_worker_meta_0a_none.aggregate(mc_worker_meta_none_1a)
assert isinstance(mc_worker_meta_0a_1a, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_0a_1a.metadata == (worker_meta0a, worker_meta1a)
# aggregate ({"0a"}, None) and ({"0b"}, None) -> ({"0a", "0b"}, None)
mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0b
mc._connectors[1].build_connector_worker_meta.return_value = None
mc_worker_meta_0b_none = mc.build_connector_worker_meta()
mc_worker_meta_0a_0b = mc_worker_meta_0a_none.aggregate(mc_worker_meta_0b_none)
assert isinstance(mc_worker_meta_0a_0b, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_0a_0b.metadata[1] is None
connector0_md = mc_worker_meta_0a_0b.metadata[0]
assert isinstance(connector0_md, MockConnectorWorkerMetadata0)
assert connector0_md.data == {"0a", "0b"}
# aggregate ({"0a"}, {"1a"}) and ({"0b"}, {"1b"}) -> ({"0a", "0b"}, {"1a", "1b"})
mc_worker_meta_01a_01b = mc_worker_meta_0a_1a.aggregate(mc_worker_meta_0b_1b)
assert isinstance(mc_worker_meta_01a_01b, MultiKVConnectorWorkerMetadata)
metadata = mc_worker_meta_01a_01b.metadata
assert len(metadata) == 2
connector0_md, connector1_md = metadata
assert isinstance(connector0_md, MockConnectorWorkerMetadata0)
assert isinstance(connector1_md, MockConnectorWorkerMetadata1)
assert connector0_md.data == {"0a", "0b"}
assert connector1_md.data == {"1a", "1b"}
# ---------------------- test update_connector_output ---------------------
def verify_worker_metadata(expected_metadata: MockConnectorWorkerMetadata | None):
def _verify_worker_metadata(connector_output: KVConnectorOutput):
worker_meta = connector_output.kv_connector_worker_meta
if expected_metadata is None:
assert worker_meta is None
return
assert isinstance(worker_meta, MockConnectorWorkerMetadata)
assert type(worker_meta) is type(expected_metadata)
assert expected_metadata.data == worker_meta.data
return _verify_worker_metadata
def assert_update_connector_output_called(mc: MultiConnector):
for c in mc._connectors:
c.update_connector_output.assert_called_once()
c.update_connector_output.reset_mock()
# no worker meta
kv_connector_output = KVConnectorOutput()
mc._connectors[0].update_connector_output.side_effect = verify_worker_metadata(None)
mc._connectors[1].update_connector_output.side_effect = verify_worker_metadata(None)
mc.update_connector_output(kv_connector_output)
assert_update_connector_output_called(mc)
# multi worker meta
kv_connector_output.kv_connector_worker_meta = mc_worker_meta_01a_01b
mc._connectors[0].update_connector_output.side_effect = verify_worker_metadata(
connector0_md
)
mc._connectors[1].update_connector_output.side_effect = verify_worker_metadata(
connector1_md
)
mc.update_connector_output(kv_connector_output)
assert_update_connector_output_called(mc)
assert kv_connector_output.kv_connector_worker_meta == mc_worker_meta_01a_01b

View File

@@ -85,6 +85,7 @@ class KVOutputAggregator:
finished_sending = set[str]() finished_sending = set[str]()
finished_recving = set[str]() finished_recving = set[str]()
aggregated_kv_connector_stats = None aggregated_kv_connector_stats = None
aggregated_kv_connector_worker_meta = None
combined_kv_cache_events = None combined_kv_cache_events = None
invalid_block_ids = set[int]() invalid_block_ids = set[int]()
for model_runner_output in outputs: for model_runner_output in outputs:
@@ -127,6 +128,17 @@ class KVOutputAggregator:
aggregated_kv_connector_stats.aggregate(kv_connector_stats) aggregated_kv_connector_stats.aggregate(kv_connector_stats)
) )
# Aggregate kv_connector_worker_meta from all workers.
if aggregated_kv_connector_worker_meta is None:
# Use the first worker's kv_connector_worker_meta as accumulator.
aggregated_kv_connector_worker_meta = kv_output.kv_connector_worker_meta
elif kv_connector_worker_meta := kv_output.kv_connector_worker_meta:
aggregated_kv_connector_worker_meta = (
aggregated_kv_connector_worker_meta.aggregate(
kv_connector_worker_meta
)
)
# Combine kv_cache_events from all workers. # Combine kv_cache_events from all workers.
if combined_kv_cache_events is None: if combined_kv_cache_events is None:
# Use the first worker's kv_cache events as start event list. # Use the first worker's kv_cache events as start event list.
@@ -151,6 +163,7 @@ class KVOutputAggregator:
finished_recving=finished_recving or None, finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None, kv_connector_stats=aggregated_kv_connector_stats or None,
kv_cache_events=combined_kv_cache_events or None, kv_cache_events=combined_kv_cache_events or None,
kv_connector_worker_meta=aggregated_kv_connector_worker_meta or None,
invalid_block_ids=invalid_block_ids, invalid_block_ids=invalid_block_ids,
expected_finished_count=self._expected_finished_count, expected_finished_count=self._expected_finished_count,
) )

View File

@@ -36,6 +36,8 @@ The class provides the following primitives:
get_finished() - called with ids of finished requests, returns get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving. ids of requests that have completed async sending/recving.
build_connector_worker_meta() - builds metadata to be sent
back to the scheduler-side connector
""" """
import enum import enum
@@ -137,13 +139,34 @@ class KVConnectorHandshakeMetadata(ABC): # noqa: B024
class KVConnectorMetadata(ABC): # noqa: B024 class KVConnectorMetadata(ABC): # noqa: B024
""" """
Abstract Metadata used to communicate between the Abstract Metadata used to communicate
Scheduler KVConnector and Worker KVConnector. Scheduler KVConnector -> Worker KVConnector.
""" """
pass pass
class KVConnectorWorkerMetadata(ABC):
"""
Abstract Metadata used to communicate back
Worker KVConnector -> Scheduler KVConnector.
Each worker can output its own metadata.
For a single engine step, all metadata objects returned by workers
will be aggregated using the `aggregate` method below, before
being passed to the Scheduler KVConnector.
"""
@abstractmethod
def aggregate(
self, other: "KVConnectorWorkerMetadata"
) -> "KVConnectorWorkerMetadata":
"""
Aggregate metadata with another `KVConnectorWorkerMetadata` object.
"""
pass
class KVConnectorBase_V1(ABC): class KVConnectorBase_V1(ABC):
""" """
Base class for KV connectors. Base class for KV connectors.
@@ -409,6 +432,16 @@ class KVConnectorBase_V1(ABC):
""" """
return None return None
def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None:
"""
Build the KVConnector worker metadata for this engine step.
Returns:
KVConnectorWorkerMetadata: the worker metadata.
None if no worker metadata is available.
"""
return None
# ============================== # ==============================
# Scheduler-side methods # Scheduler-side methods
# ============================== # ==============================

View File

@@ -17,6 +17,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorHandshakeMetadata, KVConnectorHandshakeMetadata,
KVConnectorMetadata, KVConnectorMetadata,
KVConnectorRole, KVConnectorRole,
KVConnectorWorkerMetadata,
) )
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics, KVConnectorPromMetrics,
@@ -45,6 +46,26 @@ class MultiKVConnectorMetadata(KVConnectorMetadata):
extra_async_saves: dict[str, int] | None = None extra_async_saves: dict[str, int] | None = None
@dataclass
class MultiKVConnectorWorkerMetadata(KVConnectorWorkerMetadata):
metadata: tuple[KVConnectorWorkerMetadata | None, ...]
def aggregate(self, other: KVConnectorWorkerMetadata) -> KVConnectorWorkerMetadata:
assert isinstance(other, MultiKVConnectorWorkerMetadata)
assert len(self.metadata) == len(other.metadata)
metadata_list = []
for metadata1, metadata2 in zip(self.metadata, other.metadata):
if metadata1 is None:
metadata_list.append(metadata2)
elif metadata2 is None:
metadata_list.append(metadata1)
else:
metadata_list.append(metadata1.aggregate(metadata2))
return MultiKVConnectorWorkerMetadata(metadata=tuple(metadata_list))
@dataclass @dataclass
class MultiKVConnectorStats(KVConnectorStats): class MultiKVConnectorStats(KVConnectorStats):
""" """
@@ -304,6 +325,18 @@ class MultiConnector(KVConnectorBase_V1):
# Currently no connectors return non-None # Currently no connectors return non-None
return None return None
def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None:
metadata_list: list[KVConnectorWorkerMetadata | None] | None = None
for i, c in enumerate(self._connectors):
kv_connector_worker_meta = c.build_connector_worker_meta()
if metadata_list is None and kv_connector_worker_meta is not None:
metadata_list = [None] * i
if metadata_list is not None:
metadata_list.append(kv_connector_worker_meta)
if metadata_list is None:
return None
return MultiKVConnectorWorkerMetadata(metadata=tuple(metadata_list))
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' # TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events'
# method for the MultiConnector. It should be able to get events from # method for the MultiConnector. It should be able to get events from
# multiple connectors, handling the case where only a subset of the # multiple connectors, handling the case where only a subset of the
@@ -361,8 +394,25 @@ class MultiConnector(KVConnectorBase_V1):
return metadata return metadata
def update_connector_output(self, connector_output: KVConnectorOutput): def update_connector_output(self, connector_output: KVConnectorOutput):
for c in self._connectors: multi_connector_worker_meta: MultiKVConnectorWorkerMetadata | None = None
if connector_output.kv_connector_worker_meta is not None:
assert isinstance(
connector_output.kv_connector_worker_meta,
MultiKVConnectorWorkerMetadata,
)
multi_connector_worker_meta = connector_output.kv_connector_worker_meta
try:
for i, c in enumerate(self._connectors):
if multi_connector_worker_meta is not None:
# set the connector-specific worker metadata
connector_output.kv_connector_worker_meta = (
multi_connector_worker_meta.metadata[i]
)
c.update_connector_output(connector_output) c.update_connector_output(connector_output)
finally:
# restore kv_connector_worker_meta
connector_output.kv_connector_worker_meta = multi_connector_worker_meta
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None: def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
""" """

View File

@@ -14,9 +14,13 @@ from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.distributed.kv_events import KVConnectorKVEvents from vllm.distributed.kv_events import KVConnectorKVEvents
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorWorkerMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else: else:
KVConnectorStats = object KVConnectorStats = object
KVConnectorWorkerMetadata = object
KVConnectorKVEvents = object KVConnectorKVEvents = object
@@ -142,6 +146,7 @@ class KVConnectorOutput:
finished_recving: set[str] | None = None finished_recving: set[str] | None = None
kv_connector_stats: KVConnectorStats | None = None kv_connector_stats: KVConnectorStats | None = None
kv_cache_events: KVConnectorKVEvents | None = None kv_cache_events: KVConnectorKVEvents | None = None
kv_connector_worker_meta: KVConnectorWorkerMetadata | None = None
# IDs of externally computed KV blocks that failed to load. # IDs of externally computed KV blocks that failed to load.
# Requests referencing these blocks should be rescheduled to recompute them # Requests referencing these blocks should be rescheduled to recompute them
invalid_block_ids: set[int] = field(default_factory=set) invalid_block_ids: set[int] = field(default_factory=set)
@@ -159,6 +164,7 @@ class KVConnectorOutput:
and not self.kv_connector_stats and not self.kv_connector_stats
and not self.kv_cache_events and not self.kv_cache_events
and not self.invalid_block_ids and not self.invalid_block_ids
and not self.kv_connector_worker_meta
) )
@classmethod @classmethod

View File

@@ -123,6 +123,7 @@ class KVConnectorModelRunnerMixin:
output.kv_connector_stats = kv_connector.get_kv_connector_stats() output.kv_connector_stats = kv_connector.get_kv_connector_stats()
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events() output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
output.kv_connector_worker_meta = kv_connector.build_connector_worker_meta()
if not defer_finalize: if not defer_finalize:
kv_connector.clear_connector_metadata() kv_connector.clear_connector_metadata()