[KV Connector][Metrics] Do not count local prefix cache hits in connector queries (#30522)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin
2026-02-05 07:57:27 +00:00
committed by GitHub
parent 6abb0454ad
commit 2abd97592f
3 changed files with 115 additions and 31 deletions

View File

@@ -1136,7 +1136,7 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
prompt_logprobs_dict={},
pooler_output=[],
)
scheduler.update_from_output(output, EMPTY_OUTPUT)
initial_ecos = scheduler.update_from_output(output, EMPTY_OUTPUT)
# Simulate KV transfer completion using KVConnectorOutput.finished_recving
output = scheduler.schedule()
@@ -1156,6 +1156,8 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
for req_id in req_ids:
assert req_id in scheduler.finished_recving_kv_req_ids
return initial_ecos
@pytest.mark.parametrize("is_async", [False, True])
def test_kv_connector_basic(is_async: bool):
@@ -1286,29 +1288,72 @@ def test_kv_connector_basic(is_async: bool):
@pytest.mark.parametrize("is_async", [False, True])
def test_external_prefix_cache_metrics(is_async: bool):
@pytest.mark.parametrize("local_cache_hits", [False, True])
def test_external_prefix_cache_metrics(is_async: bool, local_cache_hits: bool):
"""
Verify connector prefix cache metrics are updated
correctly when the scheduler processes requests with KV connector hits.
"""
BLOCK_SIZE = 16
if local_cache_hits:
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 # 32 tokens
NUM_LOCAL_HITS = NUM_MATCHED_NEW_TOKENS * 2 # 64 tokens
NUM_REQUESTS = 1
NUM_TOKENS = NUM_LOCAL_HITS * 2 # 128 tokens
else:
NUM_MATCHED_NEW_TOKENS = 4
NUM_LOCAL_HITS = 0
NUM_REQUESTS = 2
NUM_TOKENS = 8 # 8 tokens
# Setup Scheduler.
NUM_MATCHED_NEW_TOKENS = 4
scheduler = create_scheduler(
enable_prefix_caching=False,
enable_prefix_caching=local_cache_hits,
use_kv_connector=mock_kv(
matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
),
block_size=BLOCK_SIZE,
)
# --- Prepare simple requests ---
NUM_REQUESTS = 2
NUM_TOKENS = 8
if local_cache_hits:
# First, establish local cache by running a request to completion
requests = create_requests(
num_requests=1,
num_tokens=NUM_LOCAL_HITS,
max_tokens=2,
block_size=BLOCK_SIZE,
)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
if is_async:
_step_until_kv_transfer_finished(scheduler, req_ids)
# Run first request to completion to establish local cache
output = scheduler.schedule()
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
_ = scheduler.schedule()
# --- Prepare test requests ---
MAX_TOKENS = 2
requests = create_requests(
num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS,
block_size=BLOCK_SIZE,
)
req_ids = []
req_to_index = {}
@@ -1317,8 +1362,9 @@ def test_external_prefix_cache_metrics(is_async: bool):
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
initial_ecos = None
if is_async:
_step_until_kv_transfer_finished(scheduler, req_ids)
initial_ecos = _step_until_kv_transfer_finished(scheduler, req_ids)
# --- Trigger scheduling and simulate model output ---
output = scheduler.schedule()
@@ -1338,10 +1384,23 @@ def test_external_prefix_cache_metrics(is_async: bool):
assert ecos is not None and len(ecos) > 0
assert ecos[0].scheduler_stats is not None
external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats
if local_cache_hits:
# For async, local cache stats come from the first step
if initial_ecos:
local_stats = initial_ecos[0].scheduler_stats.prefix_cache_stats
else:
local_stats = ecos[0].scheduler_stats.prefix_cache_stats
assert local_stats is not None
assert local_stats.queries == NUM_TOKENS * NUM_REQUESTS
assert local_stats.hits == NUM_LOCAL_HITS * NUM_REQUESTS
if initial_ecos:
external_stats = initial_ecos[0].scheduler_stats.connector_prefix_cache_stats
else:
external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats
assert external_stats is not None
assert external_stats.queries == NUM_TOKENS * NUM_REQUESTS
assert external_stats.queries == (NUM_TOKENS - NUM_LOCAL_HITS) * NUM_REQUESTS
assert external_stats.hits == NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS
assert external_stats.requests == NUM_REQUESTS
assert external_stats.preempted_requests == 0

View File

@@ -281,6 +281,17 @@ def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler):
f"(hash should be None), but hash is still {block.block_hash}"
)
# Verify connector prefix cache stats:
# - queries = num_prompt_tokens (total tokens not in local cache)
# - hits = num_external_computed_tokens (tokens loaded externally)
assert engine_outputs.scheduler_stats is not None
stats = engine_outputs.scheduler_stats
assert stats.connector_prefix_cache_stats is not None
conn_stats = stats.connector_prefix_cache_stats
assert conn_stats.requests == 1
assert conn_stats.queries == num_prompt_tokens
assert conn_stats.hits == num_external_computed_tokens
def test_async_recompute_blocks_not_cached_when_invalid(
recompute_scheduler: Scheduler,
@@ -364,7 +375,9 @@ def test_async_recompute_blocks_not_cached_when_invalid(
with patch.object(
recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy
):
recompute_scheduler.update_from_output(scheduler_output, model_runner_output)
outputs = recompute_scheduler.update_from_output(
scheduler_output, model_runner_output
)
# verify evict_blocks was NOT called (async blocks excluded from eviction)
assert len(evict_blocks_calls) == 0, (
@@ -386,6 +399,19 @@ def test_async_recompute_blocks_not_cached_when_invalid(
f"Block {invalid_block_id} hash should be None but is {block.block_hash}"
)
# Verify connector prefix cache stats:
# - queries = num_prompt_tokens (total tokens not in local cache)
# - hits = num_external_computed_tokens (tokens loaded externally)
assert len(outputs) == 1
engine_outputs = next(iter(outputs.values()))
assert engine_outputs.scheduler_stats is not None
stats = engine_outputs.scheduler_stats
assert stats.connector_prefix_cache_stats is not None
conn_stats = stats.connector_prefix_cache_stats
assert conn_stats.requests == 1
assert conn_stats.queries == num_prompt_tokens
assert conn_stats.hits == num_external_computed_tokens
# now simulate async transfer completing
model_runner_output_2 = create_model_runner_output(
reqs=[],

View File

@@ -586,6 +586,7 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens = 0
load_kv_async = False
connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0
# Get already-cached tokens.
if request.num_computed_tokens == 0:
@@ -613,6 +614,11 @@ class Scheduler(SchedulerInterface):
request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens
connector_prefix_cache_queries = (
request.num_tokens - num_new_local_computed_tokens
)
connector_prefix_cache_hits = num_external_computed_tokens
# Total computed tokens (local + external).
num_computed_tokens = (
num_new_local_computed_tokens + num_external_computed_tokens
@@ -728,6 +734,15 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.get_blocks(request_id),
num_external_computed_tokens,
)
if (
self.connector_prefix_cache_stats is not None
and connector_prefix_cache_queries != 0
):
self.connector_prefix_cache_stats.record(
num_tokens=connector_prefix_cache_queries,
num_hits=connector_prefix_cache_hits,
preempted=request.num_preemptions > 0,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
@@ -739,8 +754,6 @@ class Scheduler(SchedulerInterface):
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
self._update_connector_prefix_cache_stats(request)
self.running.append(request)
if self.log_stats:
request.record_event(
@@ -1805,7 +1818,10 @@ class Scheduler(SchedulerInterface):
return None
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
assert prefix_cache_stats is not None
connector_prefix_cache_stats = self._make_connector_prefix_cache_stats()
connector_prefix_cache_stats: PrefixCacheStats | None = None
if self.connector_prefix_cache_stats is not None:
connector_prefix_cache_stats = self.connector_prefix_cache_stats
self.connector_prefix_cache_stats = PrefixCacheStats()
eviction_events = (
self.kv_metrics_collector.drain_events()
if self.kv_metrics_collector is not None
@@ -1866,23 +1882,6 @@ class Scheduler(SchedulerInterface):
# KV Connector Related Methods
########################################################################
def _update_connector_prefix_cache_stats(self, request: Request) -> None:
if self.connector_prefix_cache_stats is None:
return
self.connector_prefix_cache_stats.record(
num_tokens=request.num_tokens,
num_hits=request.num_external_computed_tokens,
preempted=request.num_preemptions > 0,
)
def _make_connector_prefix_cache_stats(self) -> PrefixCacheStats | None:
if self.connector_prefix_cache_stats is None:
return None
stats = self.connector_prefix_cache_stats
self.connector_prefix_cache_stats = PrefixCacheStats()
return stats
def get_kv_connector(self) -> KVConnectorBase_V1 | None:
return self.connector