[KV Connector][Metrics] Do not count local prefix cache hits in connector queries (#30522)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
@@ -1136,7 +1136,7 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
)
|
||||
scheduler.update_from_output(output, EMPTY_OUTPUT)
|
||||
initial_ecos = scheduler.update_from_output(output, EMPTY_OUTPUT)
|
||||
|
||||
# Simulate KV transfer completion using KVConnectorOutput.finished_recving
|
||||
output = scheduler.schedule()
|
||||
@@ -1156,6 +1156,8 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
|
||||
for req_id in req_ids:
|
||||
assert req_id in scheduler.finished_recving_kv_req_ids
|
||||
|
||||
return initial_ecos
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_async", [False, True])
|
||||
def test_kv_connector_basic(is_async: bool):
|
||||
@@ -1286,29 +1288,72 @@ def test_kv_connector_basic(is_async: bool):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_async", [False, True])
|
||||
def test_external_prefix_cache_metrics(is_async: bool):
|
||||
@pytest.mark.parametrize("local_cache_hits", [False, True])
|
||||
def test_external_prefix_cache_metrics(is_async: bool, local_cache_hits: bool):
|
||||
"""
|
||||
Verify connector prefix cache metrics are updated
|
||||
correctly when the scheduler processes requests with KV connector hits.
|
||||
"""
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
if local_cache_hits:
|
||||
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 # 32 tokens
|
||||
NUM_LOCAL_HITS = NUM_MATCHED_NEW_TOKENS * 2 # 64 tokens
|
||||
NUM_REQUESTS = 1
|
||||
NUM_TOKENS = NUM_LOCAL_HITS * 2 # 128 tokens
|
||||
else:
|
||||
NUM_MATCHED_NEW_TOKENS = 4
|
||||
NUM_LOCAL_HITS = 0
|
||||
NUM_REQUESTS = 2
|
||||
NUM_TOKENS = 8 # 8 tokens
|
||||
|
||||
# Setup Scheduler.
|
||||
NUM_MATCHED_NEW_TOKENS = 4
|
||||
scheduler = create_scheduler(
|
||||
enable_prefix_caching=False,
|
||||
enable_prefix_caching=local_cache_hits,
|
||||
use_kv_connector=mock_kv(
|
||||
matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
|
||||
),
|
||||
block_size=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
# --- Prepare simple requests ---
|
||||
NUM_REQUESTS = 2
|
||||
NUM_TOKENS = 8
|
||||
if local_cache_hits:
|
||||
# First, establish local cache by running a request to completion
|
||||
requests = create_requests(
|
||||
num_requests=1,
|
||||
num_tokens=NUM_LOCAL_HITS,
|
||||
max_tokens=2,
|
||||
block_size=BLOCK_SIZE,
|
||||
)
|
||||
req_ids = []
|
||||
req_to_index = {}
|
||||
for i, request in enumerate(requests):
|
||||
scheduler.add_request(request)
|
||||
req_ids.append(request.request_id)
|
||||
req_to_index[request.request_id] = i
|
||||
|
||||
if is_async:
|
||||
_step_until_kv_transfer_finished(scheduler, req_ids)
|
||||
|
||||
# Run first request to completion to establish local cache
|
||||
output = scheduler.schedule()
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[1000]] * len(req_ids),
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
)
|
||||
_step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT)
|
||||
_ = scheduler.schedule()
|
||||
|
||||
# --- Prepare test requests ---
|
||||
MAX_TOKENS = 2
|
||||
requests = create_requests(
|
||||
num_requests=NUM_REQUESTS,
|
||||
num_tokens=NUM_TOKENS,
|
||||
max_tokens=MAX_TOKENS,
|
||||
block_size=BLOCK_SIZE,
|
||||
)
|
||||
req_ids = []
|
||||
req_to_index = {}
|
||||
@@ -1317,8 +1362,9 @@ def test_external_prefix_cache_metrics(is_async: bool):
|
||||
req_ids.append(request.request_id)
|
||||
req_to_index[request.request_id] = i
|
||||
|
||||
initial_ecos = None
|
||||
if is_async:
|
||||
_step_until_kv_transfer_finished(scheduler, req_ids)
|
||||
initial_ecos = _step_until_kv_transfer_finished(scheduler, req_ids)
|
||||
|
||||
# --- Trigger scheduling and simulate model output ---
|
||||
output = scheduler.schedule()
|
||||
@@ -1338,10 +1384,23 @@ def test_external_prefix_cache_metrics(is_async: bool):
|
||||
assert ecos is not None and len(ecos) > 0
|
||||
assert ecos[0].scheduler_stats is not None
|
||||
|
||||
external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats
|
||||
if local_cache_hits:
|
||||
# For async, local cache stats come from the first step
|
||||
if initial_ecos:
|
||||
local_stats = initial_ecos[0].scheduler_stats.prefix_cache_stats
|
||||
else:
|
||||
local_stats = ecos[0].scheduler_stats.prefix_cache_stats
|
||||
assert local_stats is not None
|
||||
assert local_stats.queries == NUM_TOKENS * NUM_REQUESTS
|
||||
assert local_stats.hits == NUM_LOCAL_HITS * NUM_REQUESTS
|
||||
|
||||
if initial_ecos:
|
||||
external_stats = initial_ecos[0].scheduler_stats.connector_prefix_cache_stats
|
||||
else:
|
||||
external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats
|
||||
assert external_stats is not None
|
||||
|
||||
assert external_stats.queries == NUM_TOKENS * NUM_REQUESTS
|
||||
assert external_stats.queries == (NUM_TOKENS - NUM_LOCAL_HITS) * NUM_REQUESTS
|
||||
assert external_stats.hits == NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS
|
||||
assert external_stats.requests == NUM_REQUESTS
|
||||
assert external_stats.preempted_requests == 0
|
||||
|
||||
@@ -281,6 +281,17 @@ def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler):
|
||||
f"(hash should be None), but hash is still {block.block_hash}"
|
||||
)
|
||||
|
||||
# Verify connector prefix cache stats:
|
||||
# - queries = num_prompt_tokens (tokens not served by the local cache; no local hits here, so all prompt tokens)
|
||||
# - hits = num_external_computed_tokens (tokens loaded externally)
|
||||
assert engine_outputs.scheduler_stats is not None
|
||||
stats = engine_outputs.scheduler_stats
|
||||
assert stats.connector_prefix_cache_stats is not None
|
||||
conn_stats = stats.connector_prefix_cache_stats
|
||||
assert conn_stats.requests == 1
|
||||
assert conn_stats.queries == num_prompt_tokens
|
||||
assert conn_stats.hits == num_external_computed_tokens
|
||||
|
||||
|
||||
def test_async_recompute_blocks_not_cached_when_invalid(
|
||||
recompute_scheduler: Scheduler,
|
||||
@@ -364,7 +375,9 @@ def test_async_recompute_blocks_not_cached_when_invalid(
|
||||
with patch.object(
|
||||
recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy
|
||||
):
|
||||
recompute_scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
outputs = recompute_scheduler.update_from_output(
|
||||
scheduler_output, model_runner_output
|
||||
)
|
||||
|
||||
# verify evict_blocks was NOT called (async blocks excluded from eviction)
|
||||
assert len(evict_blocks_calls) == 0, (
|
||||
@@ -386,6 +399,19 @@ def test_async_recompute_blocks_not_cached_when_invalid(
|
||||
f"Block {invalid_block_id} hash should be None but is {block.block_hash}"
|
||||
)
|
||||
|
||||
# Verify connector prefix cache stats:
|
||||
# - queries = num_prompt_tokens (tokens not served by the local cache; no local hits here, so all prompt tokens)
|
||||
# - hits = num_external_computed_tokens (tokens loaded externally)
|
||||
assert len(outputs) == 1
|
||||
engine_outputs = next(iter(outputs.values()))
|
||||
assert engine_outputs.scheduler_stats is not None
|
||||
stats = engine_outputs.scheduler_stats
|
||||
assert stats.connector_prefix_cache_stats is not None
|
||||
conn_stats = stats.connector_prefix_cache_stats
|
||||
assert conn_stats.requests == 1
|
||||
assert conn_stats.queries == num_prompt_tokens
|
||||
assert conn_stats.hits == num_external_computed_tokens
|
||||
|
||||
# now simulate async transfer completing
|
||||
model_runner_output_2 = create_model_runner_output(
|
||||
reqs=[],
|
||||
|
||||
@@ -586,6 +586,7 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
num_external_computed_tokens = 0
|
||||
load_kv_async = False
|
||||
connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0
|
||||
|
||||
# Get already-cached tokens.
|
||||
if request.num_computed_tokens == 0:
|
||||
@@ -613,6 +614,11 @@ class Scheduler(SchedulerInterface):
|
||||
request.num_external_computed_tokens = ext_tokens
|
||||
num_external_computed_tokens = ext_tokens
|
||||
|
||||
connector_prefix_cache_queries = (
|
||||
request.num_tokens - num_new_local_computed_tokens
|
||||
)
|
||||
connector_prefix_cache_hits = num_external_computed_tokens
|
||||
|
||||
# Total computed tokens (local + external).
|
||||
num_computed_tokens = (
|
||||
num_new_local_computed_tokens + num_external_computed_tokens
|
||||
@@ -728,6 +734,15 @@ class Scheduler(SchedulerInterface):
|
||||
self.kv_cache_manager.get_blocks(request_id),
|
||||
num_external_computed_tokens,
|
||||
)
|
||||
if (
|
||||
self.connector_prefix_cache_stats is not None
|
||||
and connector_prefix_cache_queries != 0
|
||||
):
|
||||
self.connector_prefix_cache_stats.record(
|
||||
num_tokens=connector_prefix_cache_queries,
|
||||
num_hits=connector_prefix_cache_hits,
|
||||
preempted=request.num_preemptions > 0,
|
||||
)
|
||||
|
||||
# Request was already popped from self.waiting
|
||||
# unless it was re-added above due to new_blocks being None.
|
||||
@@ -739,8 +754,6 @@ class Scheduler(SchedulerInterface):
|
||||
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
continue
|
||||
|
||||
self._update_connector_prefix_cache_stats(request)
|
||||
|
||||
self.running.append(request)
|
||||
if self.log_stats:
|
||||
request.record_event(
|
||||
@@ -1805,7 +1818,10 @@ class Scheduler(SchedulerInterface):
|
||||
return None
|
||||
prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats()
|
||||
assert prefix_cache_stats is not None
|
||||
connector_prefix_cache_stats = self._make_connector_prefix_cache_stats()
|
||||
connector_prefix_cache_stats: PrefixCacheStats | None = None
|
||||
if self.connector_prefix_cache_stats is not None:
|
||||
connector_prefix_cache_stats = self.connector_prefix_cache_stats
|
||||
self.connector_prefix_cache_stats = PrefixCacheStats()
|
||||
eviction_events = (
|
||||
self.kv_metrics_collector.drain_events()
|
||||
if self.kv_metrics_collector is not None
|
||||
@@ -1866,23 +1882,6 @@ class Scheduler(SchedulerInterface):
|
||||
# KV Connector Related Methods
|
||||
########################################################################
|
||||
|
||||
def _update_connector_prefix_cache_stats(self, request: Request) -> None:
|
||||
if self.connector_prefix_cache_stats is None:
|
||||
return
|
||||
|
||||
self.connector_prefix_cache_stats.record(
|
||||
num_tokens=request.num_tokens,
|
||||
num_hits=request.num_external_computed_tokens,
|
||||
preempted=request.num_preemptions > 0,
|
||||
)
|
||||
|
||||
def _make_connector_prefix_cache_stats(self) -> PrefixCacheStats | None:
|
||||
if self.connector_prefix_cache_stats is None:
|
||||
return None
|
||||
stats = self.connector_prefix_cache_stats
|
||||
self.connector_prefix_cache_stats = PrefixCacheStats()
|
||||
return stats
|
||||
|
||||
def get_kv_connector(self) -> KVConnectorBase_V1 | None:
|
||||
return self.connector
|
||||
|
||||
|
||||
Reference in New Issue
Block a user