[KV Connector] Test async mode in scheduler tests (#28550)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
@@ -3,7 +3,8 @@
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from itertools import count
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain, count
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -18,13 +19,18 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
|
||||
SharedStorageConnector,
|
||||
)
|
||||
from vllm.utils.hashing import sha256
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.core.sched.scheduler import Scheduler, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
@@ -307,6 +313,82 @@ class TestSharedStorageConnector(SharedStorageConnector):
|
||||
return attr
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MockKVConfig:
    """Immutable settings for the mock KV connector used in scheduler tests.

    Both fields have defaults, so ``MockKVConfig()`` describes a connector
    that reports zero matched tokens in synchronous mode.
    """

    # Token count the mock connector reports as matched.
    matched_tokens: int = 0
    # Flag the mock connector reports alongside the matched token count.
    is_async: bool = False
|
||||
|
||||
|
||||
class MockKVConnectorMetadata(KVConnectorMetadata):
    """Connector metadata that only carries a list of request entries."""

    def __init__(self):
        # Scheduler tests check metadata.requests
        self.requests: list[dict[str, Any]] = []
|
||||
|
||||
|
||||
class MockKVConnector(KVConnectorBase_V1):
    """Mock KV connector for scheduler tests, supporting both sync and async mode.

    The connector never moves any KV data: the load/save hooks are no-ops.
    Its scheduler-facing answers come entirely from a ``MockKVConfig`` built
    out of the ``kv_connector_extra_config`` entries ``"matched_tokens"`` and
    ``"is_async"``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        role: KVConnectorRole,
        kv_cache_config: KVCacheConfig | None = None,
    ):
        """Build the mock configuration from the connector's extra config.

        Keys missing from the extra config fall back to the defaults declared
        on ``MockKVConfig`` instead of raising ``KeyError``.
        """
        super().__init__(vllm_config, role, kv_cache_config)
        # NOTE(review): assumes kv_connector_extra_config is a dict-like
        # mapping that supports .get() -- confirm against the config class.
        extra_config = self._kv_transfer_config.kv_connector_extra_config
        defaults = MockKVConfig()
        self.config = MockKVConfig(
            matched_tokens=extra_config.get(
                "matched_tokens", defaults.matched_tokens
            ),
            is_async=extra_config.get("is_async", defaults.is_async),
        )

    def get_num_new_matched_tokens(
        self,
        request: Request,
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """Return the configured ``(matched_tokens, is_async)`` pair.

        The answer is the same for every request; ``request`` and
        ``num_computed_tokens`` are ignored.
        """
        return (self.config.matched_tokens, self.config.is_async)

    def update_state_after_alloc(
        self,
        request: Request,
        blocks: KVCacheBlocks,
        num_external_tokens: int,
    ):
        """Do nothing; the mock keeps no per-request allocation state."""
        pass

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """Build metadata listing the requests scheduled in this step.

        ``metadata.requests`` receives one ``{"req_id": ...}`` entry for each
        newly scheduled request, followed by one entry for each cached
        request whose id appears in ``resumed_req_ids``.
        """
        metadata = MockKVConnectorMetadata()
        cached_reqs = scheduler_output.scheduled_cached_reqs
        new_ids = (req.req_id for req in scheduler_output.scheduled_new_reqs)
        resumed_ids = (
            req_id
            for req_id in cached_reqs.req_ids
            if req_id in cached_reqs.resumed_req_ids
        )
        # New requests first, then resumed ones, matching the original order.
        metadata.requests.extend(
            {"req_id": req_id} for req_id in chain(new_ids, resumed_ids)
        )
        return metadata

    def start_load_kv(self, kv_caches, finished_req_ids) -> None:
        """Do nothing; the mock loads no KV data."""
        pass

    def wait_for_layer_load(self, layer_name) -> None:
        """Do nothing; there is no load to wait for."""
        pass

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs) -> None:
        """Do nothing; the mock saves no KV data."""
        pass

    def wait_for_save(self) -> None:
        """Do nothing; there is no save to wait for."""
        pass
|
||||
|
||||
|
||||
# Register both test connectors with the factory, passing each one's
# registration name, this module's name and the class name.
KVConnectorFactory.register_connector(
    "TestSharedStorageConnector", __name__, TestSharedStorageConnector.__name__
)

KVConnectorFactory.register_connector(
    "MockKVConnector", __name__, MockKVConnector.__name__
)
|
||||
|
||||
Reference in New Issue
Block a user