[P/D] KV Load Failure Recovery/Abort Configuration (#26813)
Signed-off-by: Will Eaton <weaton@redhat.com> Signed-off-by: Will Eaton <me@wseaton.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -397,6 +397,25 @@ class BlockPool:
|
||||
[block for block in blocks_list if block.ref_cnt == 0 and not block.is_null]
|
||||
)
|
||||
|
||||
def evict_blocks(self, block_ids: set[int]) -> None:
|
||||
"""evict blocks from the prefix cache by their block IDs.
|
||||
|
||||
only evicts blocks that are currently cached (have a hash). blocks
|
||||
with ref_cnt > 0 are not freed from the block pool, only evicted
|
||||
from the prefix cache hash table.
|
||||
|
||||
Args:
|
||||
block_ids: Set of block IDs to evict from cache.
|
||||
"""
|
||||
for block_id in block_ids:
|
||||
assert block_id < len(self.blocks), (
|
||||
f"Invalid block_id {block_id} >= {len(self.blocks)}. "
|
||||
f"This indicates a bug in the KV connector - workers should "
|
||||
f"only report block IDs that were allocated by the scheduler."
|
||||
)
|
||||
block = self.blocks[block_id]
|
||||
self._maybe_evict_cached_block(block)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache. This function may be used in RLHF
|
||||
flows to invalid prefix caching after the weights are updated,
|
||||
|
||||
@@ -333,6 +333,14 @@ class KVCacheManager:
|
||||
"""
|
||||
self.coordinator.free(request.request_id)
|
||||
|
||||
def evict_blocks(self, block_ids: set[int]) -> None:
|
||||
"""evict blocks from the prefix cache by their block IDs.
|
||||
|
||||
Args:
|
||||
block_ids: Set of block IDs to evict from cache.
|
||||
"""
|
||||
self.block_pool.evict_blocks(block_ids)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache. This function may be used in RLHF
|
||||
flows to invalidate prefix caching after the weights are updated,
|
||||
|
||||
@@ -106,6 +106,7 @@ class Scheduler(SchedulerInterface):
|
||||
# KV Connector pushes/pull of remote KVs for P/D and offloading.
|
||||
self.connector = None
|
||||
self.connector_prefix_cache_stats: PrefixCacheStats | None = None
|
||||
self.recompute_kv_load_failures = True
|
||||
if self.vllm_config.kv_transfer_config is not None:
|
||||
assert not self.is_encoder_decoder, (
|
||||
"Encoder-decoder models are not currently supported with KV connectors"
|
||||
@@ -117,6 +118,10 @@ class Scheduler(SchedulerInterface):
|
||||
)
|
||||
if self.log_stats:
|
||||
self.connector_prefix_cache_stats = PrefixCacheStats()
|
||||
kv_load_failure_policy = (
|
||||
self.vllm_config.kv_transfer_config.kv_load_failure_policy
|
||||
)
|
||||
self.recompute_kv_load_failures = kv_load_failure_policy == "recompute"
|
||||
|
||||
self.kv_event_publisher = EventPublisherFactory.create(
|
||||
self.kv_events_config,
|
||||
@@ -1066,7 +1071,7 @@ class Scheduler(SchedulerInterface):
|
||||
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
|
||||
assert num_tokens_scheduled > 0
|
||||
if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids:
|
||||
# Skip requests that were recovered from KV load failure
|
||||
# skip failed or rescheduled requests from KV load failure
|
||||
continue
|
||||
request = self.requests.get(req_id)
|
||||
if request is None:
|
||||
@@ -1177,6 +1182,21 @@ class Scheduler(SchedulerInterface):
|
||||
# This is a rare case and unlikely to impact performance.
|
||||
self.waiting.remove_requests(stopped_preempted_reqs)
|
||||
|
||||
if failed_kv_load_req_ids and not self.recompute_kv_load_failures:
|
||||
requests = [self.requests[req_id] for req_id in failed_kv_load_req_ids]
|
||||
self.finish_requests(failed_kv_load_req_ids, RequestStatus.FINISHED_ERROR)
|
||||
for request in requests:
|
||||
outputs[request.client_index].append(
|
||||
EngineCoreOutput(
|
||||
request_id=request.request_id,
|
||||
new_token_ids=[],
|
||||
finish_reason=request.get_finished_reason(),
|
||||
events=request.take_events(),
|
||||
trace_headers=request.trace_headers,
|
||||
num_cached_tokens=request.num_cached_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
# KV Connector: update state for finished KV Transfers.
|
||||
if kv_connector_output:
|
||||
self._update_from_kv_xfer_finished(kv_connector_output)
|
||||
@@ -1610,8 +1630,11 @@ class Scheduler(SchedulerInterface):
|
||||
self._free_blocks(self.requests[req_id])
|
||||
|
||||
def _update_requests_with_invalid_blocks(
|
||||
self, requests: Iterable[Request], invalid_block_ids: set[int]
|
||||
) -> tuple[set[str], int]:
|
||||
self,
|
||||
requests: Iterable[Request],
|
||||
invalid_block_ids: set[int],
|
||||
evict_blocks: bool = True,
|
||||
) -> tuple[set[str], int, set[int]]:
|
||||
"""
|
||||
Identify and update requests affected by invalid KV cache blocks.
|
||||
|
||||
@@ -1623,16 +1646,21 @@ class Scheduler(SchedulerInterface):
|
||||
Args:
|
||||
requests: The set of requests to scan for invalid blocks.
|
||||
invalid_block_ids: IDs of invalid blocks.
|
||||
evict_blocks: Whether to collect blocks for eviction (False for
|
||||
async requests which aren't cached yet).
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- affected_req_ids (set[str]): IDs of requests impacted by
|
||||
invalid blocks.
|
||||
- total_affected_tokens (int): Total number of tokens that must
|
||||
be recomputed across all affected requests (for observability).
|
||||
be recomputed across all affected requests.
|
||||
- blocks_to_evict (set[int]): Block IDs to evict from cache,
|
||||
including invalid blocks and downstream dependent blocks.
|
||||
"""
|
||||
affected_req_ids: set[str] = set()
|
||||
total_affected_tokens = 0
|
||||
blocks_to_evict: set[int] = set()
|
||||
# If a block is invalid and shared by multiple requests in the batch,
|
||||
# these requests must be rescheduled, but only the first will recompute
|
||||
# it. This set tracks blocks already marked for recomputation.
|
||||
@@ -1690,6 +1718,9 @@ class Scheduler(SchedulerInterface):
|
||||
)
|
||||
total_affected_tokens += num_affected_tokens
|
||||
request.num_external_computed_tokens -= num_affected_tokens
|
||||
# collect invalid block and all downstream dependent blocks
|
||||
if evict_blocks:
|
||||
blocks_to_evict.update(req_block_ids[idx:])
|
||||
|
||||
if is_affected:
|
||||
if not marked_invalid_block:
|
||||
@@ -1705,47 +1736,70 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
affected_req_ids.add(request.request_id)
|
||||
|
||||
return affected_req_ids, total_affected_tokens
|
||||
return affected_req_ids, total_affected_tokens, blocks_to_evict
|
||||
|
||||
def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]:
|
||||
total_requests_to_reschedule = 0
|
||||
total_tokens_to_reschedule = 0
|
||||
"""
|
||||
Handle requests affected by invalid KV cache blocks.
|
||||
|
||||
# --- Handle async KV loads (WAITING_FOR_REMOTE_KVS) ---
|
||||
Returns:
|
||||
Set of affected request IDs to skip in update_from_output main loop.
|
||||
"""
|
||||
should_fail = not self.recompute_kv_load_failures
|
||||
|
||||
# handle async KV loads (not cached yet, evict_blocks=False)
|
||||
async_load_reqs = (
|
||||
req
|
||||
for req in self.waiting
|
||||
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
)
|
||||
async_affected_req_ids, num_tokens_to_reschedule = (
|
||||
async_failed_req_ids, num_failed_tokens, _ = (
|
||||
self._update_requests_with_invalid_blocks(
|
||||
async_load_reqs, invalid_block_ids
|
||||
async_load_reqs, invalid_block_ids, evict_blocks=False
|
||||
)
|
||||
)
|
||||
|
||||
total_requests_to_reschedule += len(async_affected_req_ids)
|
||||
total_tokens_to_reschedule += num_tokens_to_reschedule
|
||||
total_failed_requests = len(async_failed_req_ids)
|
||||
total_failed_tokens = num_failed_tokens
|
||||
|
||||
# Mark requests with async KV load failures; they will be rescheduled
|
||||
# once loading completes.
|
||||
self.failed_recving_kv_req_ids |= async_affected_req_ids
|
||||
|
||||
# --- Handle sync KV loads (running requests) ---
|
||||
sync_affected_req_ids, num_tokens_to_reschedule = (
|
||||
self._update_requests_with_invalid_blocks(self.running, invalid_block_ids)
|
||||
# handle sync loads (may be cached, collect blocks for eviction)
|
||||
sync_failed_req_ids, num_failed_tokens, sync_blocks_to_evict = (
|
||||
self._update_requests_with_invalid_blocks(
|
||||
self.running, invalid_block_ids, evict_blocks=True
|
||||
)
|
||||
)
|
||||
|
||||
total_requests_to_reschedule += len(sync_affected_req_ids)
|
||||
total_tokens_to_reschedule += num_tokens_to_reschedule
|
||||
total_failed_requests += len(sync_failed_req_ids)
|
||||
total_failed_tokens += num_failed_tokens
|
||||
|
||||
if total_requests_to_reschedule:
|
||||
logger.warning(
|
||||
"Recovered from KV load failure: "
|
||||
"%d request(s) rescheduled (%d tokens affected).",
|
||||
total_requests_to_reschedule,
|
||||
total_tokens_to_reschedule,
|
||||
if not total_failed_requests:
|
||||
return set()
|
||||
|
||||
# evict invalid blocks and downstream dependent blocks from cache
|
||||
# only when not using recompute policy (where blocks will be recomputed
|
||||
# and reused by other requests sharing them)
|
||||
if sync_blocks_to_evict and not self.recompute_kv_load_failures:
|
||||
self.kv_cache_manager.evict_blocks(sync_blocks_to_evict)
|
||||
|
||||
if should_fail:
|
||||
all_failed_req_ids = async_failed_req_ids | sync_failed_req_ids
|
||||
logger.error(
|
||||
"Failing %d request(s) due to KV load failure "
|
||||
"(failure_policy=fail, %d tokens affected). Request IDs: %s",
|
||||
total_failed_requests,
|
||||
total_failed_tokens,
|
||||
all_failed_req_ids,
|
||||
)
|
||||
return all_failed_req_ids
|
||||
|
||||
# Return the IDs of affected running requests to skip in
|
||||
# update_from_output.
|
||||
return sync_affected_req_ids
|
||||
logger.warning(
|
||||
"Recovered from KV load failure: "
|
||||
"%d request(s) rescheduled (%d tokens affected).",
|
||||
total_failed_requests,
|
||||
total_failed_tokens,
|
||||
)
|
||||
|
||||
# Mark async requests with KV load failures for retry once loading completes
|
||||
self.failed_recving_kv_req_ids |= async_failed_req_ids
|
||||
# Return sync affected IDs to skip in update_from_output
|
||||
return sync_failed_req_ids
|
||||
|
||||
Reference in New Issue
Block a user