[Bugfix] Do not crash V0 engine on input errors (#13101)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Joe Runde committed via GitHub on 2025-02-26 04:07:29 -07:00
parent ec8a5e5386
commit 3f808cc044
5 changed files with 172 additions and 6 deletions

vllm/engine/llm_engine.py

@@ -60,6 +60,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
 from vllm.utils import (Counter, Device, deprecate_kwargs,
                         resolve_obj_by_qualname, weak_bind)
 from vllm.version import __version__ as VLLM_VERSION
+from vllm.worker.model_runner_base import InputProcessingError
 
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
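
The new import references an exception type this commit adds in vllm/worker/model_runner_base.py, which is not shown in this excerpt. A minimal sketch of what such a class could look like, assuming it carries the offending request's id (the handler further down reads e.request_id); the exact fields in the commit may differ:

class InputProcessingError(Exception):
    """Raised when preparing the inputs for a single sequence group fails.

    Carrying the request id lets the engine fail one request without
    crashing the whole batch.
    """

    def __init__(self, request_id: str, message: str) -> None:
        # The id of the offending request, so the engine can abort just it.
        self.request_id = request_id
        self.message = message
        super().__init__(message)

    def __str__(self) -> str:
        return (f"Failed to process inputs for request {self.request_id}: "
                f"{self.message}")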
@@ -410,6 +411,10 @@ class LLMEngine:
         self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
 
+        # Flag to set when an input fails to process and the engine should run
+        # the next step without re-scheduling.
+        self._skip_scheduling_next_step = False
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
@@ -1334,7 +1339,11 @@
         # Skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
-        if not self._has_remaining_steps(seq_group_metadata_list):
+        # The scheduler is also skipped if a single request caused the last
+        # engine step to fail, and the previous schedule needs to be rerun.
+        if not self._has_remaining_steps(
+                seq_group_metadata_list
+        ) and not self._skip_scheduling_next_step:
             # Schedule iteration
             (seq_group_metadata_list, scheduler_outputs,
              allow_async_output_proc
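
In outline, this branch implements a replay-or-reschedule decision. A simplified sketch of that control flow with hypothetical names (MiniEngine and _cached_schedule are illustrative, not the engine's actual internals, which cache richer state via _cache_scheduler_outputs_for_multi_step):

class MiniEngine:
    def __init__(self, scheduler):
        self.scheduler = scheduler
        self._skip_scheduling_next_step = False
        self._cached_schedule = None

    def next_schedule(self):
        if self._skip_scheduling_next_step and self._cached_schedule:
            # Replay the previous schedule (minus any aborted request); the
            # flag is cleared only once a model step actually succeeds.
            return self._cached_schedule
        # Otherwise ask the scheduler for a fresh batch as usual.
        return self.scheduler.schedule()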
@@ -1388,8 +1397,23 @@
                 execute_model_req.async_callback = self.async_callbacks[
                     virtual_engine]
 
-            outputs = self.model_executor.execute_model(
-                execute_model_req=execute_model_req)
+            try:
+                outputs = self.model_executor.execute_model(
+                    execute_model_req=execute_model_req)
+                self._skip_scheduling_next_step = False
+            except InputProcessingError as e:
+                # The input for this request cannot be processed, so we must
+                # abort it. If there are remaining requests in the batch that
+                # have been scheduled, they will be retried on the next step.
+                invalid_request_id = e.request_id
+                self._abort_and_cache_schedule(
+                    request_id=invalid_request_id,
+                    virtual_engine=virtual_engine,
+                    seq_group_metadata_list=seq_group_metadata_list,
+                    scheduler_outputs=scheduler_outputs,
+                    allow_async_output_proc=allow_async_output_proc)
+                # Raise so the caller is notified that this request failed
+                raise
 
             # We need to do this here so that last step's sampled_token_ids can
             # be passed to the next iteration for PP.
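
Because the error is re-raised after the abort, callers can turn it into a per-request failure instead of an engine crash. A hedged sketch of how a serving loop might consume this behavior (run_engine_loop and report_failure are illustrative names, not part of this commit):

from vllm.worker.model_runner_base import InputProcessingError

def run_engine_loop(engine, report_failure):
    # engine is an LLMEngine; report_failure is any per-request error callback.
    while engine.has_unfinished_requests():
        try:
            for output in engine.step():
                ...  # deliver finished outputs to their clients
        except InputProcessingError as e:
            # Only the offending request was aborted; the rest of the batch
            # was cached and is retried on the next call to step().
            report_failure(e.request_id, str(e))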
@@ -1464,6 +1488,38 @@
         return ctx.request_outputs
 
+    def _abort_and_cache_schedule(
+            self, request_id: str, virtual_engine: int,
+            seq_group_metadata_list: List[SequenceGroupMetadata],
+            scheduler_outputs: SchedulerOutputs,
+            allow_async_output_proc: bool) -> None:
+        """Aborts a single request, and caches the scheduler outputs minus that
+        request. This allows the next step to continue processing the remaining
+        requests without having to re-run the scheduler."""
+        # Abort the request and remove its sequence group from the current
+        # schedule
+        self.abort_request(request_id)
+        for i, metadata in enumerate(seq_group_metadata_list):
+            if metadata.request_id == request_id:
+                del seq_group_metadata_list[i]
+                break
+
+        for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
+            if group.seq_group.request_id == request_id:
+                del scheduler_outputs.scheduled_seq_groups[i]
+                break
+
+        # If there are still other sequence groups left in the schedule, cache
+        # them and flag the engine to reuse the schedule.
+        if len(seq_group_metadata_list) > 0:
+            self._skip_scheduling_next_step = True
+            # Reuse multi-step caching logic
+            self._cache_scheduler_outputs_for_multi_step(
+                virtual_engine=virtual_engine,
+                scheduler_outputs=scheduler_outputs,
+                seq_group_metadata_list=seq_group_metadata_list,
+                allow_async_output_proc=allow_async_output_proc)
+
     def _has_remaining_steps(
         self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
     ) -> bool:
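
One detail worth noting in _abort_and_cache_schedule: each pruning loop deletes from the list it is enumerating, which is safe only because it breaks immediately afterward. A self-contained illustration of the same pattern (SimpleNamespace stands in for the metadata objects):

from types import SimpleNamespace

scheduled = [SimpleNamespace(request_id=r) for r in ("a", "bad", "c")]

# Delete exactly one entry by id, then stop: continuing to iterate after the
# del would skip the element that shifted into the freed slot.
for i, metadata in enumerate(scheduled):
    if metadata.request_id == "bad":
        del scheduled[i]
        break

assert [m.request_id for m in scheduled] == ["a", "c"]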