[Feat][RL] Pause and Resume with keep requests for single engine (#32351)

Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Aaron Hao <ahao@anyscale.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Aaron Hao authored 2026-02-06 16:08:58 -08:00; committed by GitHub
parent 4a2d00eafd
commit 89a385d79f
8 changed files with 536 additions and 30 deletions

vllm/v1/engine/__init__.py

@@ -4,7 +4,7 @@
 import enum
 import time
 from collections.abc import Mapping
-from typing import Any
+from typing import Any, Literal
 
 import msgspec
 import numpy as np
@@ -18,6 +18,12 @@ from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import LogprobsLists, LogprobsTensors
 from vllm.v1.serial_utils import UtilityResult
 
+# Type for pause_generation mode parameter.
+# - "abort": Abort all in-flight requests immediately (default).
+# - "wait": Wait for in-flight requests to complete before pausing.
+# - "keep": Freeze requests in queue; they resume on resume_generation().
+PauseMode = Literal["abort", "wait", "keep"]
+
 # These are possible values of RequestOutput.finish_reason,
 # so form part of the external API.
 FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
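
To make the three modes concrete, here is a minimal usage sketch driving this parameter through AsyncLLM.pause_generation() (changed later in this commit). The engine handle `llm` and the surrounding coroutine are illustrative, not part of the diff:

from vllm.v1.engine.async_llm import AsyncLLM


async def pause_for_weight_sync(llm: AsyncLLM) -> None:
    # "keep" (new in this commit): freeze in-flight requests in the
    # scheduler queue; they continue where they left off after
    # resume_generation(). Preserving the KV cache is the natural
    # pairing for this mode.
    await llm.pause_generation(mode="keep", clear_cache=False)
    # ... push updated weights to the workers here ...
    await llm.resume_generation()

    # "abort" (default): drop all in-flight requests immediately.
    await llm.pause_generation(mode="abort")
    await llm.resume_generation()

    # "wait": let in-flight requests finish before pausing.
    await llm.pause_generation(mode="wait")
    await llm.resume_generation()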

vllm/v1/engine/async_llm.py

@@ -38,7 +38,7 @@ from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils.async_utils import cancel_task_threadsafe
 from vllm.utils.collection_utils import as_list
-from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine import EngineCoreRequest, PauseMode
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
 from vllm.v1.engine.input_processor import InputProcessor
@@ -170,6 +170,7 @@ class AsyncLLM(EngineClient):
         # Pause / resume state for async RL workflows.
         self._pause_cond = asyncio.Condition()
         self._paused = False
+        self._client_count = client_count
 
         self.output_handler: asyncio.Task | None = None
         try:
@@ -728,7 +729,8 @@ class AsyncLLM(EngineClient):
     async def pause_generation(
         self,
         *,
-        wait_for_inflight_requests: bool = False,
+        mode: PauseMode = "abort",
+        wait_for_inflight_requests: bool | None = None,
         clear_cache: bool = True,
     ) -> None:
         """
@@ -737,27 +739,52 @@ class AsyncLLM(EngineClient):
         New generation/encoding requests are blocked until resume.
 
         Args:
-            wait_for_inflight_requests: When ``True`` waits for in-flight
-                requests to finish before pausing. When ``False`` (default),
-                immediately aborts any in-flight requests.
+            mode: How to handle in-flight requests:
+                - ``"abort"``: Abort all in-flight requests immediately
+                  (default).
+                - ``"wait"``: Wait for in-flight requests to complete.
+                - ``"keep"``: Freeze requests in queue; they resume on
+                  :meth:`resume_generation`.
+            wait_for_inflight_requests: DEPRECATED: use mode argument.
+                Whether to wait for in-flight requests to complete before pausing.
             clear_cache: Whether to clear KV cache and prefix cache after
                 draining. Set to ``False`` to preserve cache for faster resume.
                 Default is ``True`` (clear caches).
         """
+        if wait_for_inflight_requests:
+            warnings.warn(
+                "The `wait_for_inflight_requests` parameter in "
+                "`AsyncLLM.pause_generation()` is deprecated. "
+                "Please use `mode` argument instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            mode = "wait"
+
+        async with self._pause_cond:
+            if self._paused:
+                return
+            self._paused = True
+            if mode == "keep":
+                # Freeze requests in the scheduler - they will resume on
+                # resume_generation().
+                await self.engine_core.pause_scheduler_async()
+            else:
+                if self._client_count > 1:
+                    raise NotImplementedError(
+                        "pause_generation is not supported with --api-server-count > 1"
+                        " when mode is not 'keep'"
+                    )
-        async with self._pause_cond:
-            if not self._paused:
-                self._paused = True
-                if not wait_for_inflight_requests:
-                    request_ids = list(self.output_processor.request_states.keys())
-                    if request_ids:
-                        await self.abort(request_ids, internal=True)
-                # Wait for running requests to drain before clearing cache.
-                if self.output_processor.has_unfinished_requests():
-                    await self.output_processor.wait_for_requests_to_drain()
+                if mode == "abort":
+                    request_ids = list(self.output_processor.request_states.keys())
+                    if request_ids:
+                        await self.abort(request_ids, internal=True)
+                elif mode == "wait":
+                    if self.output_processor.has_unfinished_requests():
+                        await self.output_processor.wait_for_requests_to_drain()
+                else:
+                    raise ValueError(f"Invalid mode: {mode}")
 
             # Clear cache
             if clear_cache:
@@ -769,6 +796,7 @@ class AsyncLLM(EngineClient):
         """Resume generation after :meth:`pause_generation`."""
         async with self._pause_cond:
+            await self.engine_core.resume_scheduler_async()
             self._paused = False
             self._pause_cond.notify_all()  # Wake up all waiting requests
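
The gate that blocks new requests while paused is a standard asyncio.Condition pattern: pause sets a flag, resume clears it and wakes all waiters, and request paths wait on the predicate. A self-contained sketch of just that pattern (class and method names here are illustrative, not vLLM API):

import asyncio


class PauseGate:
    """Illustrative stand-in for the _pause_cond/_paused pair above."""

    def __init__(self) -> None:
        self._cond = asyncio.Condition()
        self._paused = False

    async def pause(self) -> None:
        async with self._cond:
            self._paused = True

    async def resume(self) -> None:
        async with self._cond:
            self._paused = False
            # Wake every coroutine parked in wait_until_unpaused().
            self._cond.notify_all()

    async def wait_until_unpaused(self) -> None:
        async with self._cond:
            await self._cond.wait_for(lambda: not self._paused)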

vllm/v1/engine/core.py

@@ -209,6 +209,10 @@ class EngineCore:
         self.async_scheduling = vllm_config.scheduler_config.async_scheduling
         self.aborts_queue = queue.Queue[list[str]]()
 
+        # Pause state for "keep" mode - freezes requests in queue.
+        self._scheduler_paused = False
+
         # Mark the startup heap as static so that it's ignored by GC.
         # Reduces pause times of oldest generation collections.
         freeze_gc_heap()
@@ -322,6 +326,20 @@ class EngineCore:
         # (i.e. client-aborted vs stop criteria met).
         self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)
 
+    def pause_scheduler(self) -> None:
+        """Pause the scheduler, keeping requests frozen in queue.
+
+        Requests are kept frozen in queue and can be resumed later.
+        """
+        self._scheduler_paused = True
+
+    def resume_scheduler(self) -> None:
+        """Resume the scheduler after a pause.
+
+        Resumes processing of frozen requests in the queue.
+        """
+        self._scheduler_paused = False
+
     @contextmanager
     def log_error_detail(self, scheduler_output: SchedulerOutput):
         """Execute the model and log detailed info on failure."""
@@ -375,6 +393,10 @@ class EngineCore:
             was executed.
         """
+        # If paused, don't schedule any work.
+        if self._scheduler_paused:
+            return {}, False
+
         # Check for any requests remaining in the scheduler - unfinished,
         # or finished and not yet removed from the batch.
         if not self.scheduler.has_requests():
@@ -425,6 +447,10 @@ class EngineCore:
            batch in the job queue is finished.
         3. Update the scheduler from the output.
         """
+        # If paused, don't schedule any work.
+        if self._scheduler_paused:
+            return {}, False
+
         batch_queue = self.batch_queue
         assert batch_queue is not None
@@ -1007,6 +1033,7 @@ class EngineCoreProc(EngineCore):
             not self.engines_running
             and not self.scheduler.has_requests()
             and not self.batch_queue
+            and not self._scheduler_paused
         ):
             if self.input_queue.empty():
                 # Drain aborts queue; all aborts are also processed via input_queue.
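
Net effect of the three guards above: while paused, step() and its batch-queue variant schedule nothing, and the EngineCoreProc busy loop refuses to park itself as idle, so frozen requests survive until resume. A condensed control-flow sketch (simplified, not the vLLM source verbatim):

class CoreSketch:
    def __init__(self) -> None:
        self._scheduler_paused = False

    def pause_scheduler(self) -> None:
        self._scheduler_paused = True  # queued requests stay frozen

    def resume_scheduler(self) -> None:
        self._scheduler_paused = False  # frozen requests become schedulable

    def step(self) -> tuple[dict, bool]:
        if self._scheduler_paused:
            # Paused: no outputs, no pending model execution; queued
            # requests are left untouched until resume_scheduler().
            return {}, False
        # Real scheduling and model execution elided in this sketch.
        return {}, True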

vllm/v1/engine/core_client.py

@@ -105,7 +105,7 @@ class EngineCoreClient(ABC):
         client_addresses: dict[str, str] | None = None,
         client_count: int = 1,
         client_index: int = 0,
-    ) -> "MPClient":
+    ) -> "AsyncMPClient":
         parallel_config = vllm_config.parallel_config
         client_args = (
             vllm_config,
@@ -976,6 +976,16 @@ class AsyncMPClient(MPClient):
         if request_ids and not self.resources.engine_dead:
             await self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
+    async def pause_scheduler_async(self) -> None:
+        """Pause the scheduler, keeping requests frozen in queue.
+
+        Blocks until the EngineCore acknowledges the pause.
+        """
+        await self.call_utility_async("pause_scheduler")
+
+    async def resume_scheduler_async(self) -> None:
+        """Resume the scheduler after a pause."""
+        await self.call_utility_async("resume_scheduler")
+
     async def profile_async(self, is_start: bool = True) -> None:
         await self.call_utility_async("profile", is_start)
@@ -1188,6 +1198,18 @@ class DPAsyncMPClient(AsyncMPClient):
     def get_core_engine_for_request(self, request: EngineCoreRequest):
         return self.core_engine
 
+    async def pause_scheduler_async(self) -> None:
+        """Pause the scheduler, keeping requests frozen in queue."""
+        raise NotImplementedError(
+            "pause_scheduler_async is not yet supported for data parallel"
+        )
+
+    async def resume_scheduler_async(self) -> None:
+        """Resume the scheduler after a pause."""
+        raise NotImplementedError(
+            "resume_scheduler_async is not yet supported for data parallel"
+        )
+
 class DPLBAsyncMPClient(DPAsyncMPClient):
     """Asyncio-compatible client for multi-proc, multi-engine (data parallel)