diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index b29df468f..a1e3d09d2 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1136,7 +1136,7 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]): prompt_logprobs_dict={}, pooler_output=[], ) - scheduler.update_from_output(output, EMPTY_OUTPUT) + initial_ecos = scheduler.update_from_output(output, EMPTY_OUTPUT) # Simulate KV transfer completion using KVConnectorOutput.finished_recving output = scheduler.schedule() @@ -1156,6 +1156,8 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]): for req_id in req_ids: assert req_id in scheduler.finished_recving_kv_req_ids + return initial_ecos + @pytest.mark.parametrize("is_async", [False, True]) def test_kv_connector_basic(is_async: bool): @@ -1286,29 +1288,72 @@ def test_kv_connector_basic(is_async: bool): @pytest.mark.parametrize("is_async", [False, True]) -def test_external_prefix_cache_metrics(is_async: bool): +@pytest.mark.parametrize("local_cache_hits", [False, True]) +def test_external_prefix_cache_metrics(is_async: bool, local_cache_hits: bool): """ Verify connector prefix cache metrics are updated correctly when the scheduler processes requests with KV connector hits. """ + BLOCK_SIZE = 16 + if local_cache_hits: + NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2 # 32 tokens + NUM_LOCAL_HITS = NUM_MATCHED_NEW_TOKENS * 2 # 64 tokens + NUM_REQUESTS = 1 + NUM_TOKENS = NUM_LOCAL_HITS * 2 # 128 tokens + else: + NUM_MATCHED_NEW_TOKENS = 4 + NUM_LOCAL_HITS = 0 + NUM_REQUESTS = 2 + NUM_TOKENS = 8 # 8 tokens + # Setup Scheduler. - NUM_MATCHED_NEW_TOKENS = 4 scheduler = create_scheduler( - enable_prefix_caching=False, + enable_prefix_caching=local_cache_hits, use_kv_connector=mock_kv( matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async ), + block_size=BLOCK_SIZE, ) - # --- Prepare simple requests --- - NUM_REQUESTS = 2 - NUM_TOKENS = 8 + if local_cache_hits: + # First, establish local cache by running a request to completion + requests = create_requests( + num_requests=1, + num_tokens=NUM_LOCAL_HITS, + max_tokens=2, + block_size=BLOCK_SIZE, + ) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + if is_async: + _step_until_kv_transfer_finished(scheduler, req_ids) + + # Run first request to completion to establish local cache + output = scheduler.schedule() + MODEL_RUNNER_OUTPUT = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[1000]] * len(req_ids), + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + _step_until_done(scheduler, output, MODEL_RUNNER_OUTPUT) + _ = scheduler.schedule() + + # --- Prepare test requests --- MAX_TOKENS = 2 requests = create_requests( num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, max_tokens=MAX_TOKENS, + block_size=BLOCK_SIZE, ) req_ids = [] req_to_index = {} @@ -1317,8 +1362,9 @@ def test_external_prefix_cache_metrics(is_async: bool): req_ids.append(request.request_id) req_to_index[request.request_id] = i + initial_ecos = None if is_async: - _step_until_kv_transfer_finished(scheduler, req_ids) + initial_ecos = _step_until_kv_transfer_finished(scheduler, req_ids) # --- Trigger scheduling and simulate model output --- output = scheduler.schedule() @@ -1338,10 +1384,23 @@ def test_external_prefix_cache_metrics(is_async: bool): assert ecos is not None and len(ecos) > 0 assert ecos[0].scheduler_stats is not None - external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats + if local_cache_hits: + # For async, local cache stats come from the first step + if initial_ecos: + local_stats = initial_ecos[0].scheduler_stats.prefix_cache_stats + else: + local_stats = ecos[0].scheduler_stats.prefix_cache_stats + assert local_stats is not None + assert local_stats.queries == NUM_TOKENS * NUM_REQUESTS + assert local_stats.hits == NUM_LOCAL_HITS * NUM_REQUESTS + + if initial_ecos: + external_stats = initial_ecos[0].scheduler_stats.connector_prefix_cache_stats + else: + external_stats = ecos[0].scheduler_stats.connector_prefix_cache_stats assert external_stats is not None - assert external_stats.queries == NUM_TOKENS * NUM_REQUESTS + assert external_stats.queries == (NUM_TOKENS - NUM_LOCAL_HITS) * NUM_REQUESTS assert external_stats.hits == NUM_MATCHED_NEW_TOKENS * NUM_REQUESTS assert external_stats.requests == NUM_REQUESTS assert external_stats.preempted_requests == 0 diff --git a/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py b/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py index 940f3a983..6cb2d3ea4 100644 --- a/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py +++ b/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py @@ -281,6 +281,17 @@ def test_sync_fail_invalid_blocks_evicted(fail_scheduler: Scheduler): f"(hash should be None), but hash is still {block.block_hash}" ) + # Verify connector prefix cache stats: + # - queries = num_prompt_tokens (total tokens not in local cache) + # - hits = num_external_computed_tokens (tokens loaded externally) + assert engine_outputs.scheduler_stats is not None + stats = engine_outputs.scheduler_stats + assert stats.connector_prefix_cache_stats is not None + conn_stats = stats.connector_prefix_cache_stats + assert conn_stats.requests == 1 + assert conn_stats.queries == num_prompt_tokens + assert conn_stats.hits == num_external_computed_tokens + def test_async_recompute_blocks_not_cached_when_invalid( recompute_scheduler: Scheduler, @@ -364,7 +375,9 @@ def test_async_recompute_blocks_not_cached_when_invalid( with patch.object( recompute_scheduler.kv_cache_manager, "evict_blocks", evict_blocks_spy ): - recompute_scheduler.update_from_output(scheduler_output, model_runner_output) + outputs = recompute_scheduler.update_from_output( + scheduler_output, model_runner_output + ) # verify evict_blocks was NOT called (async blocks excluded from eviction) assert len(evict_blocks_calls) == 0, ( @@ -386,6 +399,19 @@ def test_async_recompute_blocks_not_cached_when_invalid( f"Block {invalid_block_id} hash should be None but is {block.block_hash}" ) + # Verify connector prefix cache stats: + # - queries = num_prompt_tokens (total tokens not in local cache) + # - hits = num_external_computed_tokens (tokens loaded externally) + assert len(outputs) == 1 + engine_outputs = next(iter(outputs.values())) + assert engine_outputs.scheduler_stats is not None + stats = engine_outputs.scheduler_stats + assert stats.connector_prefix_cache_stats is not None + conn_stats = stats.connector_prefix_cache_stats + assert conn_stats.requests == 1 + assert conn_stats.queries == num_prompt_tokens + assert conn_stats.hits == num_external_computed_tokens + # now simulate async transfer completing model_runner_output_2 = create_model_runner_output( reqs=[], diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 88d1a78df..745d9ffec 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -586,6 +586,7 @@ class Scheduler(SchedulerInterface): num_external_computed_tokens = 0 load_kv_async = False + connector_prefix_cache_queries, connector_prefix_cache_hits = 0, 0 # Get already-cached tokens. if request.num_computed_tokens == 0: @@ -613,6 +614,11 @@ class Scheduler(SchedulerInterface): request.num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens + connector_prefix_cache_queries = ( + request.num_tokens - num_new_local_computed_tokens + ) + connector_prefix_cache_hits = num_external_computed_tokens + # Total computed tokens (local + external). num_computed_tokens = ( num_new_local_computed_tokens + num_external_computed_tokens @@ -728,6 +734,15 @@ class Scheduler(SchedulerInterface): self.kv_cache_manager.get_blocks(request_id), num_external_computed_tokens, ) + if ( + self.connector_prefix_cache_stats is not None + and connector_prefix_cache_queries != 0 + ): + self.connector_prefix_cache_stats.record( + num_tokens=connector_prefix_cache_queries, + num_hits=connector_prefix_cache_hits, + preempted=request.num_preemptions > 0, + ) # Request was already popped from self.waiting # unless it was re-added above due to new_blocks being None. @@ -739,8 +754,6 @@ class Scheduler(SchedulerInterface): request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue - self._update_connector_prefix_cache_stats(request) - self.running.append(request) if self.log_stats: request.record_event( @@ -1805,7 +1818,10 @@ class Scheduler(SchedulerInterface): return None prefix_cache_stats = self.kv_cache_manager.make_prefix_cache_stats() assert prefix_cache_stats is not None - connector_prefix_cache_stats = self._make_connector_prefix_cache_stats() + connector_prefix_cache_stats: PrefixCacheStats | None = None + if self.connector_prefix_cache_stats is not None: + connector_prefix_cache_stats = self.connector_prefix_cache_stats + self.connector_prefix_cache_stats = PrefixCacheStats() eviction_events = ( self.kv_metrics_collector.drain_events() if self.kv_metrics_collector is not None @@ -1866,23 +1882,6 @@ class Scheduler(SchedulerInterface): # KV Connector Related Methods ######################################################################## - def _update_connector_prefix_cache_stats(self, request: Request) -> None: - if self.connector_prefix_cache_stats is None: - return - - self.connector_prefix_cache_stats.record( - num_tokens=request.num_tokens, - num_hits=request.num_external_computed_tokens, - preempted=request.num_preemptions > 0, - ) - - def _make_connector_prefix_cache_stats(self) -> PrefixCacheStats | None: - if self.connector_prefix_cache_stats is None: - return None - stats = self.connector_prefix_cache_stats - self.connector_prefix_cache_stats = PrefixCacheStats() - return stats - def get_kv_connector(self) -> KVConnectorBase_V1 | None: return self.connector