[V1] Exception Handling when Loading KV Cache from Remote Store (#21534)

Signed-off-by: liuyumoye <adeline_ly2023@outlook.com>
Co-authored-by: liuyumoye <adeline_ly2023@outlook.com>
This commit is contained in:
Adeline
2025-07-28 11:34:17 +08:00
committed by GitHub
parent 04ff4be310
commit 15a72ac478
10 changed files with 229 additions and 5 deletions

View File

@@ -118,6 +118,9 @@ class Scheduler(SchedulerInterface):
# KV Connector: requests in process of async KV loading or recving
self.finished_recving_kv_req_ids: set[str] = set()
# The keys are request IDs, and the values are the corresponding number
# of tokens that have been successfully loaded from the remote KV store
self.finished_loading_dict: dict[str, int] = {}
# Encoder-related.
# Calculate encoder cache size if applicable
@@ -1094,6 +1097,27 @@ class Scheduler(SchedulerInterface):
(block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
return self.connector.request_finished(request, block_ids)
def _update_actual_load_token_num_from_remote_kv(self,
                                                 request: Request) -> bool:
    """
    KV Connector: apply the reported remote-KV load result to a request.

    Removes the request's entry from ``finished_loading_dict`` and uses
    the stored token count to update the request's scheduling state.

    Args:
        request: the request whose load result is being consumed. Its
            ``request_id`` must currently be a key of
            ``finished_loading_dict``.

    Returns:
        Always True.

    Raises:
        KeyError: if the request ID has no entry in
            ``finished_loading_dict``.
    """
    loaded_tokens = self.finished_loading_dict.pop(request.request_id)
    assert self.connector is not None

    # A non-positive count is reported to the connector when it exposes
    # `add_failure_request`; nothing is cached and the request's
    # computed-token count is left untouched in that case.
    load_failed = loaded_tokens <= 0
    if load_failed and hasattr(self.connector, "add_failure_request"):
        self.connector.add_failure_request(request)
        return True

    # When every token of the request was loaded, record one fewer
    # computed token.
    # NOTE(review): presumably so the final token is still recomputed by
    # the model -- confirm against the scheduler's other KV-load path.
    fully_loaded = loaded_tokens == request.num_tokens
    computed_tokens = loaded_tokens - 1 if fully_loaded else loaded_tokens

    self.kv_cache_manager.cache_blocks(request, computed_tokens)

    # Update the request state for scheduling.
    request.num_computed_tokens = computed_tokens
    return True
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
"""
KV Connector: check if the request_id is finished_recving.
@@ -1107,6 +1131,9 @@ class Scheduler(SchedulerInterface):
WAITING_FOR_REMOTE_KV.
"""
assert self.connector is not None
if request.request_id in self.finished_loading_dict:
return self._update_actual_load_token_num_from_remote_kv(request)
if request.request_id not in self.finished_recving_kv_req_ids:
return False
@@ -1145,3 +1172,6 @@ class Scheduler(SchedulerInterface):
for req_id in (model_runner_output.finished_sending or ()):
logger.debug("Finished sending KV transfer for request %s", req_id)
self._free_blocks(self.requests[req_id])
if model_runner_output.finished_loading_dict:
self.finished_loading_dict.update(
model_runner_output.finished_loading_dict)