[KV Connector] Fix async connector prefix cache metrics (#28585)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
Mark McLoughlin
2025-11-21 22:45:00 +00:00
committed by GitHub
parent 3137991f55
commit c6fa3895e9
3 changed files with 24 additions and 12 deletions

View File

@@ -1057,7 +1057,8 @@ def test_kv_connector_basic(is_async: bool):
) )
-def test_external_prefix_cache_metrics():
+@pytest.mark.parametrize("is_async", [False, True])
+def test_external_prefix_cache_metrics(is_async: bool):
""" """
Verify connector prefix cache metrics are updated Verify connector prefix cache metrics are updated
correctly when the scheduler processes requests with KV connector hits. correctly when the scheduler processes requests with KV connector hits.
@@ -1067,7 +1068,9 @@ def test_external_prefix_cache_metrics():
NUM_MATCHED_NEW_TOKENS = 4 NUM_MATCHED_NEW_TOKENS = 4
scheduler = create_scheduler( scheduler = create_scheduler(
enable_prefix_caching=False, enable_prefix_caching=False,
-        use_kv_connector=mock_kv(matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=False),
+        use_kv_connector=mock_kv(
+            matched_tokens=NUM_MATCHED_NEW_TOKENS, is_async=is_async
+        ),
) )
# --- Prepare simple requests --- # --- Prepare simple requests ---
@@ -1079,9 +1082,15 @@ def test_external_prefix_cache_metrics():
num_tokens=NUM_TOKENS, num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS, max_tokens=MAX_TOKENS,
) )
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
-    for req in requests:
-        scheduler.add_request(req)
+    if is_async:
+        _step_until_kv_transfer_finished(scheduler, req_ids)
# --- Trigger scheduling and simulate model output --- # --- Trigger scheduling and simulate model output ---
output = scheduler.schedule() output = scheduler.schedule()

View File

@@ -470,6 +470,7 @@ class Scheduler(SchedulerInterface):
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
continue continue
request.num_external_computed_tokens = ext_tokens
num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens
# Total computed tokens (local + external). # Total computed tokens (local + external).
@@ -576,9 +577,6 @@ class Scheduler(SchedulerInterface):
new_computed_blocks + new_blocks, new_computed_blocks + new_blocks,
num_external_computed_tokens, num_external_computed_tokens,
) )
self._update_connector_prefix_cache_stats(
request, num_external_computed_tokens
)
# Request was already popped from self.waiting # Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None. # unless it was re-added above due to new_blocks being None.
@@ -590,6 +588,8 @@ class Scheduler(SchedulerInterface):
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue continue
self._update_connector_prefix_cache_stats(request)
req_index += 1 req_index += 1
self.running.append(request) self.running.append(request)
if self.log_stats: if self.log_stats:
@@ -1380,15 +1380,13 @@ class Scheduler(SchedulerInterface):
# KV Connector Related Methods # KV Connector Related Methods
######################################################################## ########################################################################
-    def _update_connector_prefix_cache_stats(
-        self, request: Request, num_external_tokens: int
-    ) -> None:
+    def _update_connector_prefix_cache_stats(self, request: Request) -> None:
if self.connector_prefix_cache_stats is None: if self.connector_prefix_cache_stats is None:
return return
self.connector_prefix_cache_stats.record( self.connector_prefix_cache_stats.record(
num_tokens=request.num_tokens, num_tokens=request.num_tokens,
-            num_hits=num_external_tokens,
+            num_hits=request.num_external_computed_tokens,
preempted=request.num_preemptions > 0, preempted=request.num_preemptions > 0,
) )
@@ -1571,9 +1569,11 @@ class Scheduler(SchedulerInterface):
marked_invalid_block = True marked_invalid_block = True
# Truncate the computed tokens at the first failed block # Truncate the computed tokens at the first failed block
request.num_computed_tokens = idx * self.block_size request.num_computed_tokens = idx * self.block_size
-                    total_affected_tokens += (
+                    num_affected_tokens = (
                         req_num_computed_tokens - request.num_computed_tokens
                     )
total_affected_tokens += num_affected_tokens
request.num_external_computed_tokens -= num_affected_tokens
if is_affected: if is_affected:
if not marked_invalid_block: if not marked_invalid_block:

View File

@@ -121,6 +121,9 @@ class Request:
# The number of requests being preempted by the scheduler # The number of requests being preempted by the scheduler
self.num_preemptions = 0 self.num_preemptions = 0
# The number of tokens that have been computed remotely.
self.num_external_computed_tokens = 0
self.block_hashes: list[BlockHash] = [] self.block_hashes: list[BlockHash] = []
self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None
if block_hasher is not None: if block_hasher is not None: