diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 92122bcb0..2d9834d2e 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -47,7 +47,7 @@ def create_scheduler( enable_prefix_caching: bool = False, long_prefill_token_threshold: int = 0, disable_chunked_mm_input: bool = False, - use_kv_connector: None | bool | MockKVConfig = None, + use_kv_connector: None | bool | str | MockKVConfig = None, num_blocks: int = 10000, block_size: int = 16, max_model_len: int | None = None, @@ -107,6 +107,11 @@ def create_scheduler( "is_async": use_kv_connector.is_async, }, ) + elif isinstance(use_kv_connector, str): + kv_transfer_config = KVTransferConfig( + kv_connector=use_kv_connector, + kv_role="kv_both", + ) elif use_kv_connector: kv_transfer_config = KVTransferConfig( kv_connector="ExampleConnector", diff --git a/tests/v1/kv_connector/unit/test_scheduler_kv_connector_override.py b/tests/v1/kv_connector/unit/test_scheduler_kv_connector_override.py new file mode 100644 index 000000000..2834647fe --- /dev/null +++ b/tests/v1/kv_connector/unit/test_scheduler_kv_connector_override.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from unittest.mock import MagicMock, patch + +import pytest + +import vllm.plugins as plugins_module +from tests.v1.core.utils import create_requests, create_scheduler +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory, +) +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, +) +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.request import Request + + +class DummyConnectorMetadata(KVConnectorMetadata): + def __init__(self, block_hashes_by_req: dict[str, list[BlockHash]]): + self.block_hashes_by_req = block_hashes_by_req + + +class DummyKVConnector(KVConnectorBase_V1): + def __init__(self, vllm_config, role, kv_cache_config=None): + super().__init__(vllm_config, role, kv_cache_config) + + def get_num_new_matched_tokens( + self, request: Request, num_computed_tokens: int + ) -> tuple[int | None, bool]: + return (0, False) + + def update_state_after_alloc( + self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int + ): + pass + + def build_connector_meta( + self, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + block_hashes_by_req = getattr(scheduler_output, "block_hashes_by_req", None) + assert block_hashes_by_req is not None, ( + "DummyKVConnector expected 'block_hashes_by_req' on scheduler_output" + ) + return DummyConnectorMetadata( + block_hashes_by_req=block_hashes_by_req, + ) + + def start_load_kv(self, kv_caches, finished_req_ids): + pass + + def wait_for_layer_load(self, layer_name): + pass + + def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs): + pass + + def wait_for_save(self): + pass + + +def _my_plugin(): + """Registers the dummy KV connector and overrides _build_kv_connector_meta""" + KVConnectorFactory.register_connector( + "DummyKVConnector", + __name__, + DummyKVConnector.__name__, + ) + + def _custom_build_kv_connector_meta( + self, connector: KVConnectorBase_V1, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + block_hashes_by_req: dict[str, list[BlockHash]] = {} + for req_id in scheduler_output.num_scheduled_tokens: + request = self.requests[req_id] + block_hashes_by_req[req_id] = request.block_hashes + + scheduler_output.block_hashes_by_req = block_hashes_by_req # type: ignore[attr-defined] + return connector.build_connector_meta(scheduler_output) + + Scheduler._build_kv_connector_meta = _custom_build_kv_connector_meta + + +@pytest.fixture +def _load_plugin(): + """Load the fake plugin through the real load_general_plugins() path.""" + ep = MagicMock() + ep.name = "dummy_kv_connector_plugin" + ep.value = f"{__name__}:_my_plugin" + ep.load.return_value = _my_plugin + + # Reset the global guard so load_general_plugins() actually runs. + plugins_module.plugins_loaded = False + with patch("importlib.metadata.entry_points", return_value=[ep]): + plugins_module.load_general_plugins() + yield + # Reset again so other tests are not affected. + plugins_module.plugins_loaded = False + + +def test_connector_receives_block_hashes(_load_plugin): + block_size = 16 + num_tokens = 48 # 3 full blocks worth of tokens + scheduler = create_scheduler( + use_kv_connector="DummyKVConnector", block_size=block_size + ) + requests = create_requests( + num_requests=3, num_tokens=num_tokens, block_size=block_size + ) + for req in requests: + scheduler.add_request(req) + + output = scheduler.schedule() + + # Verify the connector metadata was built with block hashes. + meta = output.kv_connector_metadata + assert isinstance(meta, DummyConnectorMetadata) + assert len(meta.block_hashes_by_req) == 3 + + for req in requests: + assert req.request_id in meta.block_hashes_by_req + # Each request has num_tokens / block_size = 3 full block hashes. + assert len(meta.block_hashes_by_req[req.request_id]) == ( + num_tokens // block_size + ) + assert meta.block_hashes_by_req[req.request_id] == req.block_hashes diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ea2c2a6cd..486ce8deb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -910,9 +910,7 @@ class Scheduler(SchedulerInterface): # 2. Wrap up all the KV cache load / save ops into an opaque object # 3. Clear the internal states of the connector if self.connector is not None: - meta: KVConnectorMetadata = self.connector.build_connector_meta( - scheduler_output - ) + meta = self._build_kv_connector_meta(self.connector, scheduler_output) scheduler_output.kv_connector_metadata = meta # Build the connector meta for ECConnector @@ -926,6 +924,11 @@ class Scheduler(SchedulerInterface): self._update_after_schedule(scheduler_output) return scheduler_output + def _build_kv_connector_meta( + self, connector: KVConnectorBase_V1, scheduler_output: SchedulerOutput + ) -> KVConnectorMetadata: + return connector.build_connector_meta(scheduler_output) + def _preempt_request(self, request: Request, timestamp: float) -> None: """Preempt a request and put it back to the waiting queue.