[P/D] KV Load Failure Recovery/Abort Configuration (#26813)

Signed-off-by: Will Eaton <weaton@redhat.com>
Signed-off-by: Will Eaton <me@wseaton.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: chaunceyjiang <chaunceyjiang@gmail.com>
Will Eaton authored on 2025-12-10 14:00:52 -05:00 · committed by GitHub
parent e8e8cd73e5 · commit a9e4106f28
16 changed files with 1552 additions and 48 deletions

View File

@@ -397,6 +397,25 @@ class BlockPool:
            [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null]
        )

    def evict_blocks(self, block_ids: set[int]) -> None:
        """Evict blocks from the prefix cache by their block IDs.

        Only evicts blocks that are currently cached (i.e. have a hash).
        Blocks with ref_cnt > 0 are not freed from the block pool, only
        evicted from the prefix cache hash table.

        Args:
            block_ids: Set of block IDs to evict from the cache.
        """
        for block_id in block_ids:
            assert block_id < len(self.blocks), (
                f"Invalid block_id {block_id} >= {len(self.blocks)}. "
                f"This indicates a bug in the KV connector - workers should "
                f"only report block IDs that were allocated by the scheduler."
            )
            block = self.blocks[block_id]
            self._maybe_evict_cached_block(block)

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalidate prefix caching after the weights are updated,
View File

@@ -333,6 +333,14 @@ class KVCacheManager:
"""
self.coordinator.free(request.request_id)
def evict_blocks(self, block_ids: set[int]) -> None:
"""evict blocks from the prefix cache by their block IDs.
Args:
block_ids: Set of block IDs to evict from cache.
"""
self.block_pool.evict_blocks(block_ids)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalidate prefix caching after the weights are updated,

View File

@@ -106,6 +106,7 @@ class Scheduler(SchedulerInterface):
        # KV Connector pushes/pulls of remote KVs for P/D and offloading.
        self.connector = None
        self.connector_prefix_cache_stats: PrefixCacheStats | None = None
        self.recompute_kv_load_failures = True
        if self.vllm_config.kv_transfer_config is not None:
            assert not self.is_encoder_decoder, (
                "Encoder-decoder models are not currently supported with KV connectors"
@@ -117,6 +118,10 @@ class Scheduler(SchedulerInterface):
            )
            if self.log_stats:
                self.connector_prefix_cache_stats = PrefixCacheStats()
            kv_load_failure_policy = (
                self.vllm_config.kv_transfer_config.kv_load_failure_policy
            )
            self.recompute_kv_load_failures = kv_load_failure_policy == "recompute"

        self.kv_event_publisher = EventPublisherFactory.create(
            self.kv_events_config,
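The scheduler reads the policy from kv_transfer_config.kv_load_failure_policy and treats "recompute" as the recovery mode; any other value (the log below names it "fail") aborts affected requests. A hedged configuration sketch: only the kv_load_failure_policy field is confirmed by this diff, while the connector name and kv_role value are assumptions for illustration:

    from vllm.config import KVTransferConfig

    # Hypothetical setup: abort requests on KV load failure instead of
    # recomputing the lost tokens locally.
    kv_transfer_config = KVTransferConfig(
        kv_connector="NixlConnector",   # assumed connector
        kv_role="kv_both",              # assumed role
        kv_load_failure_policy="fail",  # "recompute" keeps the old behavior
    )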
@@ -1066,7 +1071,7 @@ class Scheduler(SchedulerInterface):
        for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
            assert num_tokens_scheduled > 0
            if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
-               # Skip requests that were recovered from KV load failure
                # Skip requests that failed or were rescheduled due to KV load failure
                continue
            request = self.requests.get(req_id)
            if request is None:
@@ -1177,6 +1182,21 @@ class Scheduler(SchedulerInterface):
        # This is a rare case and unlikely to impact performance.
        self.waiting.remove_requests(stopped_preempted_reqs)

        if failed_kv_load_req_ids and not self.recompute_kv_load_failures:
            requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids]
            self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR)
            for request in requests:
                outputs[request.client_index].append(
                    EngineCoreOutput(
                        request_id=request.request_id,
                        new_token_ids=[],
                        finish_reason=request.get_finished_reason(),
                        events=request.take_events(),
                        trace_headers=request.trace_headers,
                        num_cached_tokens=request.num_cached_tokens,
                    )
                )

        # KV Connector: update state for finished KV Transfers.
        if kv_connector_output:
            self._update_from_kv_xfer_finished(kv_connector_output)
@@ -1610,8 +1630,11 @@ class Scheduler(SchedulerInterface):
            self._free_blocks(self.requests[req_id])

    def _update_requests_with_invalid_blocks(
-       self, requests: Iterable[Request], invalid_block_ids: set[int]
-   ) -> tuple[set[str], int]:
        self,
        requests: Iterable[Request],
        invalid_block_ids: set[int],
        evict_blocks: bool = True,
    ) -> tuple[set[str], int, set[int]]:
        """
        Identify and update requests affected by invalid KV cache blocks.
@@ -1623,16 +1646,21 @@ class Scheduler(SchedulerInterface):
        Args:
            requests: The set of requests to scan for invalid blocks.
            invalid_block_ids: IDs of invalid blocks.
            evict_blocks: Whether to collect blocks for eviction (False for
                async requests, whose blocks are not cached yet).

        Returns:
            tuple:
                - affected_req_ids (set[str]): IDs of requests impacted by
                  invalid blocks.
                - total_affected_tokens (int): Total number of tokens that must
-                 be recomputed across all affected requests (for observability).
                  be recomputed across all affected requests.
                - blocks_to_evict (set[int]): Block IDs to evict from the cache,
                  including invalid blocks and downstream dependent blocks.
        """
        affected_req_ids: set[str] = set()
        total_affected_tokens = 0
        blocks_to_evict: set[int] = set()

        # If a block is invalid and shared by multiple requests in the batch,
        # these requests must be rescheduled, but only the first will recompute
        # it. This set tracks blocks already marked for recomputation.
@@ -1690,6 +1718,9 @@ class Scheduler(SchedulerInterface):
                )
                total_affected_tokens += num_affected_tokens
                request.num_external_computed_tokens -= num_affected_tokens
                # Collect the invalid block and all downstream dependent blocks.
                if evict_blocks:
                    blocks_to_evict.update(req_block_ids[idx:])

            if is_affected:
                if not marked_invalid_block:
@@ -1705,47 +1736,70 @@ class Scheduler(SchedulerInterface):
                affected_req_ids.add(request.request_id)

-       return affected_req_ids, total_affected_tokens
        return affected_req_ids, total_affected_tokens, blocks_to_evict
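The req_block_ids[idx:] suffix above follows from how the prefix cache keys blocks: each block's hash chains on its predecessors, so once one block is invalid, every block after it is unusable too. A small self-contained sketch of that rule (toy function, not a vLLM API):

    def blocks_to_evict_for(req_block_ids: list[int], invalid: set[int]) -> set[int]:
        # Evict the first invalid block and everything after it, since later
        # blocks' hashes chain on earlier ones.
        for idx, block_id in enumerate(req_block_ids):
            if block_id in invalid:
                return set(req_block_ids[idx:])
        return set()

    # Block 7 failed to load; blocks 8 and 9 depend on it and must go too.
    assert blocks_to_evict_for([5, 6, 7, 8, 9], {7}) == {7, 8, 9}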
    def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
-       total_requests_to_reschedule = 0
-       total_tokens_to_reschedule = 0
        """
        Handle requests affected by invalid KV cache blocks.

        Returns:
            Set of affected request IDs to skip in the update_from_output
            main loop.
        """
        should_fail = not self.recompute_kv_load_failures

-       # --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) ---
        # Handle async KV loads (not cached yet, so evict_blocks=False).
        async_load_reqs = (
            req
            for req in self.waiting
            if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
        )
-       async_affected_req_ids, num_tokens_to_reschedule = (
        async_failed_req_ids, num_failed_tokens, _ = (
            self._update_requests_with_invalid_blocks(
-               async_load_reqs, invalid_block_ids
                async_load_reqs, invalid_block_ids, evict_blocks=False
            )
        )
-       total_requests_to_reschedule += len(async_affected_req_ids)
-       total_tokens_to_reschedule += num_tokens_to_reschedule
        total_failed_requests = len(async_failed_req_ids)
        total_failed_tokens = num_failed_tokens

-       # Mark requests with async KV load failures; they will be rescheduled
-       # once loading completes.
-       self.failed_recving_kv_req_ids |= async_affected_req_ids

-       # --- Handle sync KV loads (running requests) ---
-       sync_affected_req_ids, num_tokens_to_reschedule = (
-           self._update_requests_with_invalid_blocks(self.running, invalid_block_ids)
-       )
        # Handle sync KV loads (may already be cached; collect blocks to evict).
        sync_failed_req_ids, num_failed_tokens, sync_blocks_to_evict = (
            self._update_requests_with_invalid_blocks(
                self.running, invalid_block_ids, evict_blocks=True
            )
        )
-       total_requests_to_reschedule += len(sync_affected_req_ids)
-       total_tokens_to_reschedule += num_tokens_to_reschedule
        total_failed_requests += len(sync_failed_req_ids)
        total_failed_tokens += num_failed_tokens

-       if total_requests_to_reschedule:
-           logger.warning(
-               "Recovered from KV load failure: "
-               "%d request(s) rescheduled (%d tokens affected).",
-               total_requests_to_reschedule,
-               total_tokens_to_reschedule,
-           )
        if not total_failed_requests:
            return set()

        # Evict invalid blocks and downstream dependent blocks from the cache,
        # but only when not using the recompute policy (under recompute, the
        # blocks will be recomputed and reused by other requests sharing them).
        if sync_blocks_to_evict and not self.recompute_kv_load_failures:
            self.kv_cache_manager.evict_blocks(sync_blocks_to_evict)

        if should_fail:
            all_failed_req_ids = async_failed_req_ids | sync_failed_req_ids
            logger.error(
                "Failing %d request(s) due to KV load failure "
                "(failure_policy=fail, %d tokens affected). Request IDs: %s",
                total_failed_requests,
                total_failed_tokens,
                all_failed_req_ids,
            )
            return all_failed_req_ids

        logger.warning(
            "Recovered from KV load failure: "
            "%d request(s) rescheduled (%d tokens affected).",
            total_failed_requests,
            total_failed_tokens,
        )

        # Mark async requests with KV load failures for retry once loading
        # completes.
        self.failed_recving_kv_req_ids |= async_failed_req_ids

-       # Return the IDs of affected running requests to skip in
-       # update_from_output.
-       return sync_affected_req_ids
        # Return sync affected IDs to skip in update_from_output.
        return sync_failed_req_ids
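To make the two policies concrete, here is a self-contained walk-through under toy data; the resolve helper, request IDs, and block numbers are all illustrative, not vLLM APIs:

    # Toy scenario: block 7 failed to load, and requests "A" and "B" both
    # depend on blocks [5, 6, 7, 8, 9].
    def resolve(policy: str, req_blocks: dict[str, list[int]], invalid: set[int]):
        failed: set[str] = set()
        evict: set[int] = set()
        for req_id, blocks in req_blocks.items():
            bad = [i for i, b in enumerate(blocks) if b in invalid]
            if not bad:
                continue  # request untouched by the failure
            failed.add(req_id)
            evict.update(blocks[bad[0]:])  # suffix rule from above
        if policy == "fail":
            # Requests finish with an error; the poisoned chain is evicted.
            return {"finished_with_error": failed, "evicted": evict}
        # "recompute": requests are rescheduled; blocks stay cached and are
        # recomputed in place, so nothing is evicted.
        return {"rescheduled": failed, "evicted": set()}

    batch = {"A": [5, 6, 7, 8, 9], "B": [5, 6, 7, 8, 9]}
    assert resolve("fail", batch, {7}) == {
        "finished_with_error": {"A", "B"}, "evicted": {7, 8, 9}
    }
    assert resolve("recompute", batch, {7}) == {
        "rescheduled": {"A", "B"}, "evicted": set()
    }

Under "recompute" the affected requests pay the recomputation cost and shared blocks stay cached for reuse; under "fail" the requests are aborted and the poisoned chain is evicted so no later request can hit it.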