[KVConnector] Support worker -> scheduler metadata (#31964)
Signed-off-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
@@ -5,21 +5,27 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from tests.v1.kv_connector.unit.utils import create_vllm_config
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.config import KVTransferConfig
|
from vllm.config import KVTransferConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
|
||||||
MultiConnector,
|
MultiConnector,
|
||||||
MultiKVConnectorStats,
|
MultiKVConnectorStats,
|
||||||
|
MultiKVConnectorWorkerMetadata,
|
||||||
)
|
)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||||
NixlKVConnectorStats,
|
NixlKVConnectorStats,
|
||||||
)
|
)
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
|
from vllm.v1.outputs import KVConnectorOutput, KVConnectorWorkerMetadata
|
||||||
|
|
||||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||||
|
|
||||||
@@ -40,7 +46,14 @@ class MockConnectorStats(KVConnectorStats):
|
|||||||
|
|
||||||
|
|
||||||
class MockConnector(KVConnectorBase_V1):
|
class MockConnector(KVConnectorBase_V1):
|
||||||
"""Mock connector that implements build_kv_connector_stats for testing."""
|
"""Mock connector for testing."""
|
||||||
|
|
||||||
|
def __new__(cls, *args, **kwargs):
|
||||||
|
# mock all KVConnectorBase_V1 functions
|
||||||
|
mock = MagicMock(spec_set=KVConnectorBase_V1)
|
||||||
|
# Override just build_kv_connector_stats
|
||||||
|
mock.build_kv_connector_stats = cls.build_kv_connector_stats
|
||||||
|
return mock
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def build_kv_connector_stats(
|
def build_kv_connector_stats(
|
||||||
@@ -70,16 +83,42 @@ class MockConnector(KVConnectorBase_V1):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MockCrossLayerConnector(MockConnector):
|
|
||||||
@property
|
|
||||||
def prefer_cross_layer_blocks(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
# Register the mock connector
|
# Register the mock connector
|
||||||
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
|
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mc() -> MultiConnector:
|
||||||
|
"""MultiConnector using two mocked connectors"""
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
|
||||||
|
mock_connector_config = {
|
||||||
|
"kv_connector": "MockConnector",
|
||||||
|
"kv_role": "kv_both",
|
||||||
|
"kv_connector_module_path": "tests.v1.kv_connector.unit.test_multi_connector",
|
||||||
|
}
|
||||||
|
|
||||||
|
vllm_config.kv_transfer_config = KVTransferConfig(
|
||||||
|
kv_connector="MultiConnector",
|
||||||
|
kv_role="kv_both",
|
||||||
|
kv_connector_extra_config={
|
||||||
|
"connectors": [mock_connector_config, mock_connector_config],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_cache_config = KVCacheConfig(
|
||||||
|
num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[]
|
||||||
|
)
|
||||||
|
|
||||||
|
mc = MultiConnector(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
role=KVConnectorRole.WORKER,
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return mc
|
||||||
|
|
||||||
|
|
||||||
# Helper function to compare directories recursively
|
# Helper function to compare directories recursively
|
||||||
def _compare_directories(dir1: Path, dir2: Path) -> bool:
|
def _compare_directories(dir1: Path, dir2: Path) -> bool:
|
||||||
"""Compares two directories recursively for identical content."""
|
"""Compares two directories recursively for identical content."""
|
||||||
@@ -715,24 +754,6 @@ class TestMultiConnectorStats:
|
|||||||
assert not stats.is_empty()
|
assert not stats.is_empty()
|
||||||
|
|
||||||
|
|
||||||
class TestMultiConnectorPreferCrossLayerBlocks:
|
|
||||||
def test_all_connectors_prefer_cross_layer_blocks(self):
|
|
||||||
mc = MultiConnector.__new__(MultiConnector)
|
|
||||||
mc._connectors = [
|
|
||||||
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
|
|
||||||
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
|
|
||||||
]
|
|
||||||
assert mc.prefer_cross_layer_blocks is True
|
|
||||||
|
|
||||||
def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
|
|
||||||
mc = MultiConnector.__new__(MultiConnector)
|
|
||||||
mc._connectors = [
|
|
||||||
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
|
|
||||||
MockConnector.__new__(MockConnector), # default False
|
|
||||||
]
|
|
||||||
assert mc.prefer_cross_layer_blocks is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_multi_connector_overrides_all_base_methods():
|
def test_multi_connector_overrides_all_base_methods():
|
||||||
"""
|
"""
|
||||||
Ensure MultiConnector overrides all public methods from KVConnectorBase_V1.
|
Ensure MultiConnector overrides all public methods from KVConnectorBase_V1.
|
||||||
@@ -767,3 +788,133 @@ Options:
|
|||||||
1. Add delegation in MultiConnector (preferred)
|
1. Add delegation in MultiConnector (preferred)
|
||||||
2. Add to INHERITED_OK if the base implementation works correctly
|
2. Add to INHERITED_OK if the base implementation works correctly
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_connector_prefer_cross_layer_blocks(mc):
|
||||||
|
mc._connectors[0].prefer_cross_layer_blocks = False
|
||||||
|
mc._connectors[1].prefer_cross_layer_blocks = True
|
||||||
|
assert mc.prefer_cross_layer_blocks is False
|
||||||
|
|
||||||
|
mc._connectors[0].prefer_cross_layer_blocks = True
|
||||||
|
mc._connectors[1].prefer_cross_layer_blocks = True
|
||||||
|
assert mc.prefer_cross_layer_blocks is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_multi_connector_worker_metadata(mc):
|
||||||
|
class MockConnectorWorkerMetadata(KVConnectorWorkerMetadata):
|
||||||
|
def __init__(self, data: set[str]):
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
class MockConnectorWorkerMetadata0(MockConnectorWorkerMetadata):
|
||||||
|
def aggregate(
|
||||||
|
self, other: KVConnectorWorkerMetadata
|
||||||
|
) -> KVConnectorWorkerMetadata:
|
||||||
|
assert isinstance(other, MockConnectorWorkerMetadata)
|
||||||
|
return MockConnectorWorkerMetadata0(data=self.data | other.data)
|
||||||
|
|
||||||
|
class MockConnectorWorkerMetadata1(MockConnectorWorkerMetadata):
|
||||||
|
def aggregate(
|
||||||
|
self, other: KVConnectorWorkerMetadata
|
||||||
|
) -> KVConnectorWorkerMetadata:
|
||||||
|
assert isinstance(other, MockConnectorWorkerMetadata)
|
||||||
|
return MockConnectorWorkerMetadata1(data=self.data | other.data)
|
||||||
|
|
||||||
|
# -------------------- test build_worker_connector_meta -------------------
|
||||||
|
|
||||||
|
# both connectors return None
|
||||||
|
mc._connectors[0].build_connector_worker_meta.return_value = None
|
||||||
|
mc._connectors[1].build_connector_worker_meta.return_value = None
|
||||||
|
assert mc.build_connector_worker_meta() is None
|
||||||
|
|
||||||
|
# only first connector returns None
|
||||||
|
worker_meta1a = MockConnectorWorkerMetadata1({"1a"})
|
||||||
|
mc._connectors[0].build_connector_worker_meta.return_value = None
|
||||||
|
mc._connectors[1].build_connector_worker_meta.return_value = worker_meta1a
|
||||||
|
mc_worker_meta_none_1a = mc.build_connector_worker_meta()
|
||||||
|
assert isinstance(mc_worker_meta_none_1a, MultiKVConnectorWorkerMetadata)
|
||||||
|
assert mc_worker_meta_none_1a.metadata == (None, worker_meta1a)
|
||||||
|
|
||||||
|
# only second connector returns None
|
||||||
|
worker_meta0a = MockConnectorWorkerMetadata0({"0a"})
|
||||||
|
mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0a
|
||||||
|
mc._connectors[1].build_connector_worker_meta.return_value = None
|
||||||
|
mc_worker_meta_0a_none = mc.build_connector_worker_meta()
|
||||||
|
assert isinstance(mc_worker_meta_0a_none, MultiKVConnectorWorkerMetadata)
|
||||||
|
assert mc_worker_meta_0a_none.metadata == (worker_meta0a, None)
|
||||||
|
|
||||||
|
# both connectors do not return None
|
||||||
|
worker_meta0b = MockConnectorWorkerMetadata0({"0b"})
|
||||||
|
worker_meta1b = MockConnectorWorkerMetadata1({"1b"})
|
||||||
|
mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0b
|
||||||
|
mc._connectors[1].build_connector_worker_meta.return_value = worker_meta1b
|
||||||
|
mc_worker_meta_0b_1b = mc.build_connector_worker_meta()
|
||||||
|
assert isinstance(mc_worker_meta_0b_1b, MultiKVConnectorWorkerMetadata)
|
||||||
|
assert mc_worker_meta_0b_1b.metadata == (worker_meta0b, worker_meta1b)
|
||||||
|
|
||||||
|
# ----------------------------- test aggregate ----------------------------
|
||||||
|
|
||||||
|
# aggregate ({"0a"}, None) and (None, {"1a"}) -> ({"0a"}, {"1a"})
|
||||||
|
mc_worker_meta_0a_1a = mc_worker_meta_0a_none.aggregate(mc_worker_meta_none_1a)
|
||||||
|
assert isinstance(mc_worker_meta_0a_1a, MultiKVConnectorWorkerMetadata)
|
||||||
|
assert mc_worker_meta_0a_1a.metadata == (worker_meta0a, worker_meta1a)
|
||||||
|
|
||||||
|
# aggregate ({"0a"}, None) and ({"0b"}, None) -> ({"0a", "0b"}, None)
|
||||||
|
mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0b
|
||||||
|
mc._connectors[1].build_connector_worker_meta.return_value = None
|
||||||
|
mc_worker_meta_0b_none = mc.build_connector_worker_meta()
|
||||||
|
mc_worker_meta_0a_0b = mc_worker_meta_0a_none.aggregate(mc_worker_meta_0b_none)
|
||||||
|
assert isinstance(mc_worker_meta_0a_0b, MultiKVConnectorWorkerMetadata)
|
||||||
|
assert mc_worker_meta_0a_0b.metadata[1] is None
|
||||||
|
connector0_md = mc_worker_meta_0a_0b.metadata[0]
|
||||||
|
assert isinstance(connector0_md, MockConnectorWorkerMetadata0)
|
||||||
|
assert connector0_md.data == {"0a", "0b"}
|
||||||
|
|
||||||
|
# aggregate ({"0a"}, {"1a"}) and ({"0b"}, {"1b"}) -> ({"0a", "0b"}, {"1a", "1b"})
|
||||||
|
mc_worker_meta_01a_01b = mc_worker_meta_0a_1a.aggregate(mc_worker_meta_0b_1b)
|
||||||
|
assert isinstance(mc_worker_meta_01a_01b, MultiKVConnectorWorkerMetadata)
|
||||||
|
metadata = mc_worker_meta_01a_01b.metadata
|
||||||
|
assert len(metadata) == 2
|
||||||
|
connector0_md, connector1_md = metadata
|
||||||
|
assert isinstance(connector0_md, MockConnectorWorkerMetadata0)
|
||||||
|
assert isinstance(connector1_md, MockConnectorWorkerMetadata1)
|
||||||
|
assert connector0_md.data == {"0a", "0b"}
|
||||||
|
assert connector1_md.data == {"1a", "1b"}
|
||||||
|
|
||||||
|
# ---------------------- test update_connector_output ---------------------
|
||||||
|
|
||||||
|
def verify_worker_metadata(expected_metadata: MockConnectorWorkerMetadata | None):
|
||||||
|
def _verify_worker_metadata(connector_output: KVConnectorOutput):
|
||||||
|
worker_meta = connector_output.kv_connector_worker_meta
|
||||||
|
if expected_metadata is None:
|
||||||
|
assert worker_meta is None
|
||||||
|
return
|
||||||
|
|
||||||
|
assert isinstance(worker_meta, MockConnectorWorkerMetadata)
|
||||||
|
assert type(worker_meta) is type(expected_metadata)
|
||||||
|
assert expected_metadata.data == worker_meta.data
|
||||||
|
|
||||||
|
return _verify_worker_metadata
|
||||||
|
|
||||||
|
def assert_update_connector_output_called(mc: MultiConnector):
|
||||||
|
for c in mc._connectors:
|
||||||
|
c.update_connector_output.assert_called_once()
|
||||||
|
c.update_connector_output.reset_mock()
|
||||||
|
|
||||||
|
# no worker meta
|
||||||
|
kv_connector_output = KVConnectorOutput()
|
||||||
|
mc._connectors[0].update_connector_output.side_effect = verify_worker_metadata(None)
|
||||||
|
mc._connectors[1].update_connector_output.side_effect = verify_worker_metadata(None)
|
||||||
|
mc.update_connector_output(kv_connector_output)
|
||||||
|
assert_update_connector_output_called(mc)
|
||||||
|
|
||||||
|
# multi worker meta
|
||||||
|
kv_connector_output.kv_connector_worker_meta = mc_worker_meta_01a_01b
|
||||||
|
mc._connectors[0].update_connector_output.side_effect = verify_worker_metadata(
|
||||||
|
connector0_md
|
||||||
|
)
|
||||||
|
mc._connectors[1].update_connector_output.side_effect = verify_worker_metadata(
|
||||||
|
connector1_md
|
||||||
|
)
|
||||||
|
mc.update_connector_output(kv_connector_output)
|
||||||
|
assert_update_connector_output_called(mc)
|
||||||
|
assert kv_connector_output.kv_connector_worker_meta == mc_worker_meta_01a_01b
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ class KVOutputAggregator:
|
|||||||
finished_sending = set[str]()
|
finished_sending = set[str]()
|
||||||
finished_recving = set[str]()
|
finished_recving = set[str]()
|
||||||
aggregated_kv_connector_stats = None
|
aggregated_kv_connector_stats = None
|
||||||
|
aggregated_kv_connector_worker_meta = None
|
||||||
combined_kv_cache_events = None
|
combined_kv_cache_events = None
|
||||||
invalid_block_ids = set[int]()
|
invalid_block_ids = set[int]()
|
||||||
for model_runner_output in outputs:
|
for model_runner_output in outputs:
|
||||||
@@ -127,6 +128,17 @@ class KVOutputAggregator:
|
|||||||
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
|
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Aggregate kv_connector_worker_meta from all workers.
|
||||||
|
if aggregated_kv_connector_worker_meta is None:
|
||||||
|
# Use the first worker's kv_connector_worker_meta as accumulator.
|
||||||
|
aggregated_kv_connector_worker_meta = kv_output.kv_connector_worker_meta
|
||||||
|
elif kv_connector_worker_meta := kv_output.kv_connector_worker_meta:
|
||||||
|
aggregated_kv_connector_worker_meta = (
|
||||||
|
aggregated_kv_connector_worker_meta.aggregate(
|
||||||
|
kv_connector_worker_meta
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Combine kv_cache_events from all workers.
|
# Combine kv_cache_events from all workers.
|
||||||
if combined_kv_cache_events is None:
|
if combined_kv_cache_events is None:
|
||||||
# Use the first worker's kv_cache events as start event list.
|
# Use the first worker's kv_cache events as start event list.
|
||||||
@@ -151,6 +163,7 @@ class KVOutputAggregator:
|
|||||||
finished_recving=finished_recving or None,
|
finished_recving=finished_recving or None,
|
||||||
kv_connector_stats=aggregated_kv_connector_stats or None,
|
kv_connector_stats=aggregated_kv_connector_stats or None,
|
||||||
kv_cache_events=combined_kv_cache_events or None,
|
kv_cache_events=combined_kv_cache_events or None,
|
||||||
|
kv_connector_worker_meta=aggregated_kv_connector_worker_meta or None,
|
||||||
invalid_block_ids=invalid_block_ids,
|
invalid_block_ids=invalid_block_ids,
|
||||||
expected_finished_count=self._expected_finished_count,
|
expected_finished_count=self._expected_finished_count,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,6 +36,8 @@ The class provides the following primitives:
|
|||||||
|
|
||||||
get_finished() - called with ids of finished requests, returns
|
get_finished() - called with ids of finished requests, returns
|
||||||
ids of requests that have completed async sending/recving.
|
ids of requests that have completed async sending/recving.
|
||||||
|
build_connector_worker_meta() - builds metadata to be sent
|
||||||
|
back to the scheduler-side connector
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
@@ -137,13 +139,34 @@ class KVConnectorHandshakeMetadata(ABC): # noqa: B024
|
|||||||
|
|
||||||
class KVConnectorMetadata(ABC): # noqa: B024
|
class KVConnectorMetadata(ABC): # noqa: B024
|
||||||
"""
|
"""
|
||||||
Abstract Metadata used to communicate between the
|
Abstract Metadata used to communicate
|
||||||
Scheduler KVConnector and Worker KVConnector.
|
Scheduler KVConnector -> Worker KVConnector.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class KVConnectorWorkerMetadata(ABC):
|
||||||
|
"""
|
||||||
|
Abstract Metadata used to communicate back
|
||||||
|
Worker KVConnector -> Scheduler KVConnector.
|
||||||
|
|
||||||
|
Each worker can output its own metadata.
|
||||||
|
For a single engine step, all metadata objects returned by workers
|
||||||
|
will be aggregated using the `aggregate` method below, before
|
||||||
|
being passed to the Scheduler KVConnector.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def aggregate(
|
||||||
|
self, other: "KVConnectorWorkerMetadata"
|
||||||
|
) -> "KVConnectorWorkerMetadata":
|
||||||
|
"""
|
||||||
|
Aggregate metadata with another `KVConnectorWorkerMetadata` object.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class KVConnectorBase_V1(ABC):
|
class KVConnectorBase_V1(ABC):
|
||||||
"""
|
"""
|
||||||
Base class for KV connectors.
|
Base class for KV connectors.
|
||||||
@@ -409,6 +432,16 @@ class KVConnectorBase_V1(ABC):
|
|||||||
"""
|
"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None:
|
||||||
|
"""
|
||||||
|
Build the KVConnector worker metadata for this engine step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
KVConnectorWorkerMetadata: the worker metadata.
|
||||||
|
None if no worker metadata is available.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
# ==============================
|
# ==============================
|
||||||
# Scheduler-side methods
|
# Scheduler-side methods
|
||||||
# ==============================
|
# ==============================
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
|||||||
KVConnectorHandshakeMetadata,
|
KVConnectorHandshakeMetadata,
|
||||||
KVConnectorMetadata,
|
KVConnectorMetadata,
|
||||||
KVConnectorRole,
|
KVConnectorRole,
|
||||||
|
KVConnectorWorkerMetadata,
|
||||||
)
|
)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||||
KVConnectorPromMetrics,
|
KVConnectorPromMetrics,
|
||||||
@@ -45,6 +46,26 @@ class MultiKVConnectorMetadata(KVConnectorMetadata):
|
|||||||
extra_async_saves: dict[str, int] | None = None
|
extra_async_saves: dict[str, int] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MultiKVConnectorWorkerMetadata(KVConnectorWorkerMetadata):
|
||||||
|
metadata: tuple[KVConnectorWorkerMetadata | None, ...]
|
||||||
|
|
||||||
|
def aggregate(self, other: KVConnectorWorkerMetadata) -> KVConnectorWorkerMetadata:
|
||||||
|
assert isinstance(other, MultiKVConnectorWorkerMetadata)
|
||||||
|
|
||||||
|
assert len(self.metadata) == len(other.metadata)
|
||||||
|
metadata_list = []
|
||||||
|
for metadata1, metadata2 in zip(self.metadata, other.metadata):
|
||||||
|
if metadata1 is None:
|
||||||
|
metadata_list.append(metadata2)
|
||||||
|
elif metadata2 is None:
|
||||||
|
metadata_list.append(metadata1)
|
||||||
|
else:
|
||||||
|
metadata_list.append(metadata1.aggregate(metadata2))
|
||||||
|
|
||||||
|
return MultiKVConnectorWorkerMetadata(metadata=tuple(metadata_list))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MultiKVConnectorStats(KVConnectorStats):
|
class MultiKVConnectorStats(KVConnectorStats):
|
||||||
"""
|
"""
|
||||||
@@ -304,6 +325,18 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
# Currently no connectors return non-None
|
# Currently no connectors return non-None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None:
|
||||||
|
metadata_list: list[KVConnectorWorkerMetadata | None] | None = None
|
||||||
|
for i, c in enumerate(self._connectors):
|
||||||
|
kv_connector_worker_meta = c.build_connector_worker_meta()
|
||||||
|
if metadata_list is None and kv_connector_worker_meta is not None:
|
||||||
|
metadata_list = [None] * i
|
||||||
|
if metadata_list is not None:
|
||||||
|
metadata_list.append(kv_connector_worker_meta)
|
||||||
|
if metadata_list is None:
|
||||||
|
return None
|
||||||
|
return MultiKVConnectorWorkerMetadata(metadata=tuple(metadata_list))
|
||||||
|
|
||||||
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events'
|
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events'
|
||||||
# method for the MultiConnector. It should be able to get events from
|
# method for the MultiConnector. It should be able to get events from
|
||||||
# multiple connectors, handling the case where only a subset of the
|
# multiple connectors, handling the case where only a subset of the
|
||||||
@@ -361,8 +394,25 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||||
for c in self._connectors:
|
multi_connector_worker_meta: MultiKVConnectorWorkerMetadata | None = None
|
||||||
|
if connector_output.kv_connector_worker_meta is not None:
|
||||||
|
assert isinstance(
|
||||||
|
connector_output.kv_connector_worker_meta,
|
||||||
|
MultiKVConnectorWorkerMetadata,
|
||||||
|
)
|
||||||
|
multi_connector_worker_meta = connector_output.kv_connector_worker_meta
|
||||||
|
|
||||||
|
try:
|
||||||
|
for i, c in enumerate(self._connectors):
|
||||||
|
if multi_connector_worker_meta is not None:
|
||||||
|
# set the connector-specific worker metadata
|
||||||
|
connector_output.kv_connector_worker_meta = (
|
||||||
|
multi_connector_worker_meta.metadata[i]
|
||||||
|
)
|
||||||
c.update_connector_output(connector_output)
|
c.update_connector_output(connector_output)
|
||||||
|
finally:
|
||||||
|
# restore kv_connector_worker_meta
|
||||||
|
connector_output.kv_connector_worker_meta = multi_connector_worker_meta
|
||||||
|
|
||||||
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -14,9 +14,13 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.distributed.kv_events import KVConnectorKVEvents
|
from vllm.distributed.kv_events import KVConnectorKVEvents
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
|
KVConnectorWorkerMetadata,
|
||||||
|
)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||||
else:
|
else:
|
||||||
KVConnectorStats = object
|
KVConnectorStats = object
|
||||||
|
KVConnectorWorkerMetadata = object
|
||||||
KVConnectorKVEvents = object
|
KVConnectorKVEvents = object
|
||||||
|
|
||||||
|
|
||||||
@@ -142,6 +146,7 @@ class KVConnectorOutput:
|
|||||||
finished_recving: set[str] | None = None
|
finished_recving: set[str] | None = None
|
||||||
kv_connector_stats: KVConnectorStats | None = None
|
kv_connector_stats: KVConnectorStats | None = None
|
||||||
kv_cache_events: KVConnectorKVEvents | None = None
|
kv_cache_events: KVConnectorKVEvents | None = None
|
||||||
|
kv_connector_worker_meta: KVConnectorWorkerMetadata | None = None
|
||||||
# IDs of externally computed KV blocks that failed to load.
|
# IDs of externally computed KV blocks that failed to load.
|
||||||
# Requests referencing these blocks should be rescheduled to recompute them
|
# Requests referencing these blocks should be rescheduled to recompute them
|
||||||
invalid_block_ids: set[int] = field(default_factory=set)
|
invalid_block_ids: set[int] = field(default_factory=set)
|
||||||
@@ -159,6 +164,7 @@ class KVConnectorOutput:
|
|||||||
and not self.kv_connector_stats
|
and not self.kv_connector_stats
|
||||||
and not self.kv_cache_events
|
and not self.kv_cache_events
|
||||||
and not self.invalid_block_ids
|
and not self.invalid_block_ids
|
||||||
|
and not self.kv_connector_worker_meta
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ class KVConnectorModelRunnerMixin:
|
|||||||
|
|
||||||
output.kv_connector_stats = kv_connector.get_kv_connector_stats()
|
output.kv_connector_stats = kv_connector.get_kv_connector_stats()
|
||||||
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
|
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
|
||||||
|
output.kv_connector_worker_meta = kv_connector.build_connector_worker_meta()
|
||||||
|
|
||||||
if not defer_finalize:
|
if not defer_finalize:
|
||||||
kv_connector.clear_connector_metadata()
|
kv_connector.clear_connector_metadata()
|
||||||
|
|||||||
Reference in New Issue
Block a user