[KV Connector] Test async mode in scheduler tests (#28550)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin
2025-11-13 23:30:59 +00:00
committed by GitHub
parent e64011f29a
commit 6e25b1cddf
3 changed files with 165 additions and 45 deletions

View File

@@ -31,11 +31,11 @@ from vllm.v1.kv_cache_interface import (
KVCacheConfig,
KVCacheGroupSpec,
)
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
from .utils import EOS_TOKEN_ID, create_requests, create_scheduler, mock_kv
pytestmark = pytest.mark.cpu_test
@@ -888,27 +888,65 @@ def _step_until_done(
all_finished = all_done
def test_kv_connector_basic():
def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
"""Cycle requests through a KV transfer cyle."""
# Requests should first transition to WAITING_FOR_REMOTE_KVS
output = scheduler.schedule()
assert len(scheduler.waiting) == len(req_ids)
assert len(scheduler.running) == 0
assert len(output.scheduled_new_reqs) == 0
for req in scheduler.requests.values():
assert req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
# No model execution yet
EMPTY_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, EMPTY_OUTPUT)
# Simulate KV transfer completion using KVConnectorOutput.finished_recving
output = scheduler.schedule()
assert len(scheduler.waiting) == len(req_ids)
assert len(scheduler.running) == 0
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
kv_connector_output=KVConnectorOutput(finished_recving=req_ids),
)
scheduler.update_from_output(output, MODEL_RUNNER_OUTPUT)
for req_id in req_ids:
assert req_id in scheduler.finished_recving_kv_req_ids
@pytest.mark.parametrize("is_async", [False, True])
def test_kv_connector_basic(is_async: bool):
"""
Test whether Scheduler with KVConnector schedules tokens, allocates
memory, and cleans up requests as expected under normal operation.
"""
# Setup Scheduler.
BLOCK_SIZE = 16
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=True,
use_kv_connector=mock_kv(
matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
),
block_size=BLOCK_SIZE,
)
NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
BLOCK_SIZE = scheduler.cache_config.block_size
# Mock External Cache Hit.
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS,
False,
)
######################################################
# FIRST SET OF REQUESTS - External Hit Only
@@ -928,6 +966,9 @@ def test_kv_connector_basic():
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
if is_async:
_step_until_kv_transfer_finished(scheduler, req_ids)
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
@@ -978,6 +1019,9 @@ def test_kv_connector_basic():
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
if is_async:
_step_until_kv_transfer_finished(scheduler, req_ids)
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
@@ -1020,17 +1064,10 @@ def test_external_prefix_cache_metrics():
"""
# Setup Scheduler.
NUM_MATCHED_NEW_TOKENS = 4
scheduler = create_scheduler(
enable_prefix_caching=False,
use_kv_connector=True,
)
# Mock connector to simulate a partial external cache hit
NUM_MATCHED_NEW_TOKENS = 4
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS,
False,
use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
)
# --- Prepare simple requests ---
@@ -1085,21 +1122,16 @@ def test_kv_connector_unable_to_allocate(use_ec_connector, ec_role):
# Setup Scheduler With Mock External Cache Hit.
BLOCK_SIZE = 4
NUM_BLOCKS = 10
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=True,
use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
# encoder connector should not affect test results
use_ec_connector=use_ec_connector,
ec_role=ec_role,
)
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS,
False,
)
# Create two requests. The second request will not be able to
# allocate slots because it will not have enough blocks.
@@ -1174,9 +1206,10 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
BLOCK_SIZE = 2
# NOTE: there is 1 null block, so this is 6 blocks.
NUM_BLOCKS = 7
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=True,
use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
# encoder connector should not affect test results
@@ -1184,13 +1217,6 @@ def test_kv_connector_handles_preemption(use_ec_connector, ec_role):
ec_role=ec_role,
)
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
scheduler.connector.get_num_new_matched_tokens.return_value = (
NUM_MATCHED_NEW_TOKENS,
False,
)
# Create two requests.
# Both can be scheduled at first, but the second request
# will be preempted and re-scheduled.