[KVConnector] Support worker -> scheduler metadata (#31964)
Signed-off-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
@@ -5,21 +5,27 @@ import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.v1.kv_connector.unit.utils import create_vllm_config
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
|
||||
MultiConnector,
|
||||
MultiKVConnectorStats,
|
||||
MultiKVConnectorWorkerMetadata,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
NixlKVConnectorStats,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import KVConnectorOutput, KVConnectorWorkerMetadata
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
|
||||
@@ -40,7 +46,14 @@ class MockConnectorStats(KVConnectorStats):
|
||||
|
||||
|
||||
class MockConnector(KVConnectorBase_V1):
|
||||
"""Mock connector that implements build_kv_connector_stats for testing."""
|
||||
"""Mock connector for testing."""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
# mock all KVConnectorBase_V1 functions
|
||||
mock = MagicMock(spec_set=KVConnectorBase_V1)
|
||||
# Override just build_kv_connector_stats
|
||||
mock.build_kv_connector_stats = cls.build_kv_connector_stats
|
||||
return mock
|
||||
|
||||
@classmethod
|
||||
def build_kv_connector_stats(
|
||||
@@ -70,16 +83,42 @@ class MockConnector(KVConnectorBase_V1):
|
||||
pass
|
||||
|
||||
|
||||
class MockCrossLayerConnector(MockConnector):
    """MockConnector variant whose `prefer_cross_layer_blocks` is True."""

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        # Unconditionally opt in to cross-layer blocks.
        return True
|
||||
|
||||
|
||||
# Register the mock connector
|
||||
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
|
||||
|
||||
|
||||
@pytest.fixture
def mc() -> MultiConnector:
    """Worker-role MultiConnector configured with two MockConnector entries."""
    config = create_vllm_config()

    # One child spec, listed twice, so the MultiConnector holds two children.
    child_spec = {
        "kv_connector": "MockConnector",
        "kv_role": "kv_both",
        "kv_connector_module_path": "tests.v1.kv_connector.unit.test_multi_connector",
    }
    config.kv_transfer_config = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"connectors": [child_spec] * 2},
    )

    # Empty KV cache layout.
    empty_kv_cache = KVCacheConfig(
        num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[]
    )

    return MultiConnector(
        vllm_config=config,
        role=KVConnectorRole.WORKER,
        kv_cache_config=empty_kv_cache,
    )
|
||||
|
||||
|
||||
# Helper function to compare directories recursively
|
||||
def _compare_directories(dir1: Path, dir2: Path) -> bool:
|
||||
"""Compares two directories recursively for identical content."""
|
||||
@@ -715,24 +754,6 @@ class TestMultiConnectorStats:
|
||||
assert not stats.is_empty()
|
||||
|
||||
|
||||
class TestMultiConnectorPreferCrossLayerBlocks:
    """Checks `MultiConnector.prefer_cross_layer_blocks` against its children."""

    @staticmethod
    def _multi_connector_with(*connector_types):
        """Return a MultiConnector built via `__new__` holding bare children."""
        multi = MultiConnector.__new__(MultiConnector)
        multi._connectors = [kind.__new__(kind) for kind in connector_types]
        return multi

    def test_all_connectors_prefer_cross_layer_blocks(self):
        multi = self._multi_connector_with(
            MockCrossLayerConnector, MockCrossLayerConnector
        )
        assert multi.prefer_cross_layer_blocks is True

    def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
        # The plain MockConnector keeps the default (False) preference.
        multi = self._multi_connector_with(MockCrossLayerConnector, MockConnector)
        assert multi.prefer_cross_layer_blocks is False
|
||||
|
||||
|
||||
def test_multi_connector_overrides_all_base_methods():
|
||||
"""
|
||||
Ensure MultiConnector overrides all public methods from KVConnectorBase_V1.
|
||||
@@ -767,3 +788,133 @@ Options:
|
||||
1. Add delegation in MultiConnector (preferred)
|
||||
2. Add to INHERITED_OK if the base implementation works correctly
|
||||
""")
|
||||
|
||||
|
||||
def test_multi_connector_prefer_cross_layer_blocks(mc):
    """The MultiConnector prefers cross-layer blocks only if both children do."""
    first, second = mc._connectors[0], mc._connectors[1]

    # (first child's preference, expected MultiConnector preference);
    # the second child prefers cross-layer blocks in both scenarios.
    for first_prefers, expected in ((False, False), (True, True)):
        first.prefer_cross_layer_blocks = first_prefers
        second.prefer_cross_layer_blocks = True
        assert mc.prefer_cross_layer_blocks is expected
|
||||
|
||||
|
||||
def test_multi_connector_worker_metadata(mc):
    """
    Exercise the worker -> scheduler metadata path of MultiConnector.

    Three areas are covered in order, each reusing values built by the
    previous one:
      1. `MultiConnector.build_connector_worker_meta`
      2. `MultiKVConnectorWorkerMetadata.aggregate`
      3. `MultiConnector.update_connector_output`
    """

    # Minimal metadata payload: a set of string tags that aggregation unions.
    class MockConnectorWorkerMetadata(KVConnectorWorkerMetadata):
        def __init__(self, data: set[str]):
            self.data = data

    # Distinct concrete type per child connector, so the test can check that
    # each slot keeps (and is dispatched with) its own metadata type.
    class MockConnectorWorkerMetadata0(MockConnectorWorkerMetadata):
        def aggregate(
            self, other: KVConnectorWorkerMetadata
        ) -> KVConnectorWorkerMetadata:
            assert isinstance(other, MockConnectorWorkerMetadata)
            return MockConnectorWorkerMetadata0(data=self.data | other.data)

    class MockConnectorWorkerMetadata1(MockConnectorWorkerMetadata):
        def aggregate(
            self, other: KVConnectorWorkerMetadata
        ) -> KVConnectorWorkerMetadata:
            assert isinstance(other, MockConnectorWorkerMetadata)
            return MockConnectorWorkerMetadata1(data=self.data | other.data)

    # -------------------- test build_worker_connector_meta -------------------

    # both connectors return None
    mc._connectors[0].build_connector_worker_meta.return_value = None
    mc._connectors[1].build_connector_worker_meta.return_value = None
    assert mc.build_connector_worker_meta() is None

    # only first connector returns None
    worker_meta1a = MockConnectorWorkerMetadata1({"1a"})
    mc._connectors[0].build_connector_worker_meta.return_value = None
    mc._connectors[1].build_connector_worker_meta.return_value = worker_meta1a
    mc_worker_meta_none_1a = mc.build_connector_worker_meta()
    assert isinstance(mc_worker_meta_none_1a, MultiKVConnectorWorkerMetadata)
    assert mc_worker_meta_none_1a.metadata == (None, worker_meta1a)

    # only second connector returns None
    worker_meta0a = MockConnectorWorkerMetadata0({"0a"})
    mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0a
    mc._connectors[1].build_connector_worker_meta.return_value = None
    mc_worker_meta_0a_none = mc.build_connector_worker_meta()
    assert isinstance(mc_worker_meta_0a_none, MultiKVConnectorWorkerMetadata)
    assert mc_worker_meta_0a_none.metadata == (worker_meta0a, None)

    # both connectors do not return None
    worker_meta0b = MockConnectorWorkerMetadata0({"0b"})
    worker_meta1b = MockConnectorWorkerMetadata1({"1b"})
    mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0b
    mc._connectors[1].build_connector_worker_meta.return_value = worker_meta1b
    mc_worker_meta_0b_1b = mc.build_connector_worker_meta()
    assert isinstance(mc_worker_meta_0b_1b, MultiKVConnectorWorkerMetadata)
    assert mc_worker_meta_0b_1b.metadata == (worker_meta0b, worker_meta1b)

    # ----------------------------- test aggregate ----------------------------

    # aggregate ({"0a"}, None) and (None, {"1a"}) -> ({"0a"}, {"1a"})
    mc_worker_meta_0a_1a = mc_worker_meta_0a_none.aggregate(mc_worker_meta_none_1a)
    assert isinstance(mc_worker_meta_0a_1a, MultiKVConnectorWorkerMetadata)
    assert mc_worker_meta_0a_1a.metadata == (worker_meta0a, worker_meta1a)

    # aggregate ({"0a"}, None) and ({"0b"}, None) -> ({"0a", "0b"}, None)
    mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0b
    mc._connectors[1].build_connector_worker_meta.return_value = None
    mc_worker_meta_0b_none = mc.build_connector_worker_meta()
    mc_worker_meta_0a_0b = mc_worker_meta_0a_none.aggregate(mc_worker_meta_0b_none)
    assert isinstance(mc_worker_meta_0a_0b, MultiKVConnectorWorkerMetadata)
    assert mc_worker_meta_0a_0b.metadata[1] is None
    connector0_md = mc_worker_meta_0a_0b.metadata[0]
    assert isinstance(connector0_md, MockConnectorWorkerMetadata0)
    assert connector0_md.data == {"0a", "0b"}

    # aggregate ({"0a"}, {"1a"}) and ({"0b"}, {"1b"}) -> ({"0a", "0b"}, {"1a", "1b"})
    mc_worker_meta_01a_01b = mc_worker_meta_0a_1a.aggregate(mc_worker_meta_0b_1b)
    assert isinstance(mc_worker_meta_01a_01b, MultiKVConnectorWorkerMetadata)
    metadata = mc_worker_meta_01a_01b.metadata
    assert len(metadata) == 2
    # connector0_md / connector1_md are reused below as the expected
    # per-connector values in the update_connector_output section.
    connector0_md, connector1_md = metadata
    assert isinstance(connector0_md, MockConnectorWorkerMetadata0)
    assert isinstance(connector1_md, MockConnectorWorkerMetadata1)
    assert connector0_md.data == {"0a", "0b"}
    assert connector1_md.data == {"1a", "1b"}

    # ---------------------- test update_connector_output ---------------------

    # Factory for a `side_effect` callable: when a child connector's
    # update_connector_output is invoked, it checks that the output it was
    # handed carries exactly the expected per-connector metadata.
    def verify_worker_metadata(expected_metadata: MockConnectorWorkerMetadata | None):
        def _verify_worker_metadata(connector_output: KVConnectorOutput):
            worker_meta = connector_output.kv_connector_worker_meta
            if expected_metadata is None:
                assert worker_meta is None
                return

            assert isinstance(worker_meta, MockConnectorWorkerMetadata)
            # Exact type match, so connector 0 never sees connector 1's metadata.
            assert type(worker_meta) is type(expected_metadata)
            assert expected_metadata.data == worker_meta.data

        return _verify_worker_metadata

    # Each child must have been called exactly once; the call record is then
    # reset so the next scenario starts from a clean count.
    def assert_update_connector_output_called(mc: MultiConnector):
        for c in mc._connectors:
            c.update_connector_output.assert_called_once()
            c.update_connector_output.reset_mock()

    # no worker meta
    kv_connector_output = KVConnectorOutput()
    mc._connectors[0].update_connector_output.side_effect = verify_worker_metadata(None)
    mc._connectors[1].update_connector_output.side_effect = verify_worker_metadata(None)
    mc.update_connector_output(kv_connector_output)
    assert_update_connector_output_called(mc)

    # multi worker meta
    kv_connector_output.kv_connector_worker_meta = mc_worker_meta_01a_01b
    mc._connectors[0].update_connector_output.side_effect = verify_worker_metadata(
        connector0_md
    )
    mc._connectors[1].update_connector_output.side_effect = verify_worker_metadata(
        connector1_md
    )
    mc.update_connector_output(kv_connector_output)
    assert_update_connector_output_called(mc)
    # After the call the output must again hold the MultiConnector-level metadata.
    assert kv_connector_output.kv_connector_worker_meta == mc_worker_meta_01a_01b
|
||||
|
||||
@@ -85,6 +85,7 @@ class KVOutputAggregator:
|
||||
finished_sending = set[str]()
|
||||
finished_recving = set[str]()
|
||||
aggregated_kv_connector_stats = None
|
||||
aggregated_kv_connector_worker_meta = None
|
||||
combined_kv_cache_events = None
|
||||
invalid_block_ids = set[int]()
|
||||
for model_runner_output in outputs:
|
||||
@@ -127,6 +128,17 @@ class KVOutputAggregator:
|
||||
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
|
||||
)
|
||||
|
||||
# Aggregate kv_connector_worker_meta from all workers.
|
||||
if aggregated_kv_connector_worker_meta is None:
|
||||
# Use the first worker's kv_connector_worker_meta as accumulator.
|
||||
aggregated_kv_connector_worker_meta = kv_output.kv_connector_worker_meta
|
||||
elif kv_connector_worker_meta := kv_output.kv_connector_worker_meta:
|
||||
aggregated_kv_connector_worker_meta = (
|
||||
aggregated_kv_connector_worker_meta.aggregate(
|
||||
kv_connector_worker_meta
|
||||
)
|
||||
)
|
||||
|
||||
# Combine kv_cache_events from all workers.
|
||||
if combined_kv_cache_events is None:
|
||||
# Use the first worker's kv_cache events as start event list.
|
||||
@@ -151,6 +163,7 @@ class KVOutputAggregator:
|
||||
finished_recving=finished_recving or None,
|
||||
kv_connector_stats=aggregated_kv_connector_stats or None,
|
||||
kv_cache_events=combined_kv_cache_events or None,
|
||||
kv_connector_worker_meta=aggregated_kv_connector_worker_meta or None,
|
||||
invalid_block_ids=invalid_block_ids,
|
||||
expected_finished_count=self._expected_finished_count,
|
||||
)
|
||||
|
||||
@@ -36,6 +36,8 @@ The class provides the following primitives:
|
||||
|
||||
get_finished() - called with ids of finished requests, returns
|
||||
ids of requests that have completed async sending/recving.
|
||||
build_connector_worker_meta() - builds metadata to be sent
|
||||
back to the scheduler-side connector
|
||||
"""
|
||||
|
||||
import enum
|
||||
@@ -137,13 +139,34 @@ class KVConnectorHandshakeMetadata(ABC): # noqa: B024
|
||||
|
||||
class KVConnectorMetadata(ABC):  # noqa: B024
    """
    Abstract Metadata used to communicate
    Scheduler KVConnector -> Worker KVConnector.

    The opposite direction (Worker KVConnector -> Scheduler KVConnector)
    is carried by `KVConnectorWorkerMetadata`.
    """

    pass
|
||||
|
||||
|
||||
class KVConnectorWorkerMetadata(ABC):
    """
    Base type for metadata flowing back
    Worker KVConnector -> Scheduler KVConnector.

    Every worker may emit its own metadata object. Within a single engine
    step, the objects returned by the workers are combined through
    `aggregate`, and the combined result is what gets passed to the
    Scheduler KVConnector.
    """

    @abstractmethod
    def aggregate(
        self, other: "KVConnectorWorkerMetadata"
    ) -> "KVConnectorWorkerMetadata":
        """
        Combine this metadata with `other`.

        Args:
            other: another `KVConnectorWorkerMetadata` object to merge in.

        Returns:
            The aggregated `KVConnectorWorkerMetadata`.
        """
|
||||
|
||||
|
||||
class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
Base class for KV connectors.
|
||||
@@ -409,6 +432,16 @@ class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None:
    """
    Build the KVConnector worker metadata for this engine step.

    This is the metadata sent back to the scheduler-side connector.
    The base implementation produces none.

    Returns:
        KVConnectorWorkerMetadata: the worker metadata.
        None if no worker metadata is available.
    """
    return None
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
|
||||
@@ -17,6 +17,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorHandshakeMetadata,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
KVConnectorWorkerMetadata,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorPromMetrics,
|
||||
@@ -45,6 +46,26 @@ class MultiKVConnectorMetadata(KVConnectorMetadata):
|
||||
extra_async_saves: dict[str, int] | None = None
|
||||
|
||||
|
||||
@dataclass
class MultiKVConnectorWorkerMetadata(KVConnectorWorkerMetadata):
    """
    Worker metadata for a MultiConnector: one slot per child connector.

    A slot is None when the corresponding connector produced no metadata.
    """

    # Per-connector metadata; None marks a connector with nothing to report.
    metadata: tuple[KVConnectorWorkerMetadata | None, ...]

    def aggregate(self, other: KVConnectorWorkerMetadata) -> KVConnectorWorkerMetadata:
        """
        Merge `other` into a new MultiKVConnectorWorkerMetadata, slot by slot.

        For each slot: if one side is None the other side is kept as-is;
        if both are present, the slot's own `aggregate` combines them.
        `other` must be a MultiKVConnectorWorkerMetadata with the same
        number of slots.
        """
        assert isinstance(other, MultiKVConnectorWorkerMetadata)
        assert len(self.metadata) == len(other.metadata)

        def _merge_slot(mine, theirs):
            if mine is None:
                return theirs
            if theirs is None:
                return mine
            return mine.aggregate(theirs)

        merged = tuple(
            _merge_slot(mine, theirs)
            for mine, theirs in zip(self.metadata, other.metadata)
        )
        return MultiKVConnectorWorkerMetadata(metadata=merged)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiKVConnectorStats(KVConnectorStats):
|
||||
"""
|
||||
@@ -304,6 +325,18 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
# Currently no connectors return non-None
|
||||
return None
|
||||
|
||||
def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None:
    """
    Collect worker metadata from every child connector.

    Returns:
        None when every child returned None; otherwise a
        MultiKVConnectorWorkerMetadata whose `metadata` tuple holds one
        entry per child, in connector order, with None for children that
        returned nothing.
    """
    # Each child is queried exactly once, in connector order.
    per_connector = tuple(c.build_connector_worker_meta() for c in self._connectors)

    if all(meta is None for meta in per_connector):
        return None
    return MultiKVConnectorWorkerMetadata(metadata=per_connector)
|
||||
|
||||
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events'
|
||||
# method for the MultiConnector. It should be able to get events from
|
||||
# multiple connectors, handling the case where only a subset of the
|
||||
@@ -361,8 +394,25 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
return metadata
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
    """
    Forward `connector_output` to every child connector.

    When the output carries worker metadata it must be a
    MultiKVConnectorWorkerMetadata. Before each child is called,
    `connector_output.kv_connector_worker_meta` is temporarily replaced
    with that child's own entry (`metadata[i]`), so the child only sees
    its connector-specific metadata. The original value is put back
    afterwards, even if a child raises.

    Args:
        connector_output: the output object passed to each child; its
            `kv_connector_worker_meta` field is mutated during the call
            and restored before returning.
    """
    multi_connector_worker_meta: MultiKVConnectorWorkerMetadata | None = None
    if connector_output.kv_connector_worker_meta is not None:
        assert isinstance(
            connector_output.kv_connector_worker_meta,
            MultiKVConnectorWorkerMetadata,
        )
        multi_connector_worker_meta = connector_output.kv_connector_worker_meta

    try:
        for i, c in enumerate(self._connectors):
            if multi_connector_worker_meta is not None:
                # set the connector-specific worker metadata
                connector_output.kv_connector_worker_meta = (
                    multi_connector_worker_meta.metadata[i]
                )
            c.update_connector_output(connector_output)
    finally:
        # restore kv_connector_worker_meta
        connector_output.kv_connector_worker_meta = multi_connector_worker_meta
|
||||
|
||||
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
||||
"""
|
||||
|
||||
@@ -14,9 +14,13 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.distributed.kv_events import KVConnectorKVEvents
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorWorkerMetadata,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
else:
|
||||
KVConnectorStats = object
|
||||
KVConnectorWorkerMetadata = object
|
||||
KVConnectorKVEvents = object
|
||||
|
||||
|
||||
@@ -142,6 +146,7 @@ class KVConnectorOutput:
|
||||
finished_recving: set[str] | None = None
|
||||
kv_connector_stats: KVConnectorStats | None = None
|
||||
kv_cache_events: KVConnectorKVEvents | None = None
|
||||
kv_connector_worker_meta: KVConnectorWorkerMetadata | None = None
|
||||
# IDs of externally computed KV blocks that failed to load.
|
||||
# Requests referencing these blocks should be rescheduled to recompute them
|
||||
invalid_block_ids: set[int] = field(default_factory=set)
|
||||
@@ -159,6 +164,7 @@ class KVConnectorOutput:
|
||||
and not self.kv_connector_stats
|
||||
and not self.kv_cache_events
|
||||
and not self.invalid_block_ids
|
||||
and not self.kv_connector_worker_meta
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -123,6 +123,7 @@ class KVConnectorModelRunnerMixin:
|
||||
|
||||
output.kv_connector_stats = kv_connector.get_kv_connector_stats()
|
||||
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
|
||||
output.kv_connector_worker_meta = kv_connector.build_connector_worker_meta()
|
||||
|
||||
if not defer_finalize:
|
||||
kv_connector.clear_connector_metadata()
|
||||
|
||||
Reference in New Issue
Block a user