[Feat][RL] Pause and Resume with keep requests for single engine (#32351)

Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Aaron Hao <ahao@anyscale.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Aaron Hao
2026-02-06 16:08:58 -08:00
committed by GitHub
parent 4a2d00eafd
commit 89a385d79f
8 changed files with 536 additions and 30 deletions

View File

@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test for pause/resume with keep mode.
This test uses concurrent tasks to verify the engine truly stops generating
during pause:
1. Generator task: continuously generates and logs time between tokens
2. Controller task: sends pause/resume commands
If the engine properly pauses, we should see a gap in token timestamps
matching the pause duration.
"""
import asyncio
import time
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
PAUSE_DURATION = 3.0 # seconds
async def main():
# Create engine with a small model
engine_args = AsyncEngineArgs(
model="facebook/opt-125m",
enforce_eager=True,
)
engine = AsyncLLM.from_engine_args(engine_args)
prompt = "Write a story about a dragon. Once upon a time"
sampling_params = SamplingParams(max_tokens=30, ignore_eos=True)
# Track token arrival times
token_times: list[tuple[int, float]] = [] # (token_count, timestamp)
pause_time: float = 0
resume_time: float = 0
pause_token_idx: int = 0 # Index in token_times when pause occurred
async def generator_task():
"""Generate tokens and record timestamps."""
async for output in engine.generate(
request_id="test-req",
prompt=prompt,
sampling_params=sampling_params,
):
token_count = len(output.outputs[0].token_ids)
token_times.append((token_count, time.monotonic()))
print(
f"Token {token_count} arrived:"
f"T={token_times[-1][1] - token_times[0][1]:.3f}s"
)
return output
async def controller_task():
"""Pause and resume the engine after some tokens generated."""
nonlocal pause_time, resume_time, pause_token_idx
# Wait for some tokens to be generated
while len(token_times) < 5:
await asyncio.sleep(0.01)
print(f"\nPausing engine (keep mode) at token {len(token_times)}")
pause_time = time.monotonic()
await engine.pause_generation(mode="keep")
pause_token_idx = len(token_times)
print(f"Paused! Sleeping for {PAUSE_DURATION}s...")
# Sleep while paused - no tokens should be generated during this time
await asyncio.sleep(PAUSE_DURATION)
print("Resuming engine...")
await engine.resume_generation()
resume_time = time.monotonic()
print("Resumed!\n")
# Run both tasks concurrently
gen_task = asyncio.create_task(generator_task())
ctrl_task = asyncio.create_task(controller_task())
final_output, _ = await asyncio.gather(gen_task, ctrl_task)
# Verify the pause actually stopped generation.
# The gap after the pause token should be approximately the sleep duration.
pause_gap = token_times[pause_token_idx][1] - token_times[pause_token_idx - 1][1]
print(
f"\nGap after pause (token {pause_token_idx - 1} -> {pause_token_idx}): "
f"{pause_gap:.3f}s"
)
if pause_gap >= PAUSE_DURATION * 0.9:
print(f"✓ Test passed! Engine paused for ~{pause_gap:.1f}s")
else:
print(
f"✗ Test failed! Expected ~{PAUSE_DURATION}s gap after pause, "
f"got {pause_gap:.3f}s"
)
raise AssertionError("Engine did not properly pause")
# Verify request completed
assert final_output.finished, "Request should have finished"
assert len(final_output.outputs[0].token_ids) == 30, "Should have all tokens"
engine.shutdown()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from contextlib import ExitStack
from unittest.mock import MagicMock
@@ -661,3 +662,301 @@ async def collect_outputs(
outputs_list.append(output)
final_output = output
return final_output
# =============================================================================
# Pause/Resume Tests
# =============================================================================
@pytest.mark.asyncio
async def test_pause_resume_basic():
"""Test basic pause/resume flag behavior and idempotency.
Tests:
- pause_generation sets the paused flag
- resume_generation clears the paused flag
- calling pause when already paused is a no-op
- calling resume when not paused is safe
- all pause modes work with no requests in flight
- rapid pause/resume cycles don't break the engine
"""
with ExitStack() as after:
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
# Initially not paused
assert not await engine.is_paused()
# Resume when not paused should be safe
await engine.resume_generation()
assert not await engine.is_paused()
# Pause sets flag
await engine.pause_generation(mode="abort")
assert await engine.is_paused()
# Pause when already paused is a no-op
await engine.pause_generation(mode="abort")
assert await engine.is_paused()
# Resume clears flag
await engine.resume_generation()
assert not await engine.is_paused()
# Test all modes with no requests in flight
for mode in ("abort", "wait", "keep"):
await engine.pause_generation(mode=mode)
# "keep" only freezes the scheduler; it does not set _paused
if mode != "keep":
assert await engine.is_paused()
await engine.resume_generation()
assert not await engine.is_paused()
# Concurrent pause/resume race conditions - should not deadlock or raise
await asyncio.gather(
engine.pause_generation(mode="abort"),
engine.resume_generation(),
engine.pause_generation(mode="abort"),
engine.resume_generation(),
)
# Ensure we end in a known state
await engine.resume_generation()
assert not await engine.is_paused()
# Engine should still work after all cycles
sampling_params = SamplingParams(max_tokens=5)
async for out in engine.generate(
request_id="post-cycles",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
):
pass
assert out.finished
@pytest.mark.asyncio
async def test_pause_abort():
"""Test that mode='abort' aborts in-flight requests immediately."""
with ExitStack() as after:
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
# Start a long-running request
sampling_params = SamplingParams(max_tokens=1000, ignore_eos=True)
outputs: list[RequestOutput] = []
async def gen():
async for out in engine.generate(
request_id="test-abort-pause",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
):
outputs.append(out)
return outputs[-1] if outputs else None
# Start generation task
gen_task = asyncio.create_task(gen())
# Wait for some tokens to be generated
while len(outputs) < 3:
await asyncio.sleep(0.01)
# Pause with abort mode
await engine.pause_generation(mode="abort")
# Wait for task to complete (should be aborted)
final_output = await gen_task
# Request should be finished (aborted)
assert final_output is not None
assert final_output.finished
assert final_output.outputs[0].finish_reason == "abort"
# Also test that new requests are blocked while paused, then resume
assert await engine.is_paused()
request_completed = False
async def gen_blocked():
nonlocal request_completed
async for out in engine.generate(
request_id="test-blocked",
prompt=TEXT_PROMPT,
sampling_params=SamplingParams(max_tokens=5),
):
pass
request_completed = True
return out
# Start a request (should block)
gen_task2 = asyncio.create_task(gen_blocked())
# Wait a bit - request should not have completed
await asyncio.sleep(0.3)
assert not request_completed, "Request should be blocked while paused"
# Resume
await engine.resume_generation()
# Now request should complete
final_output2 = await asyncio.wait_for(gen_task2, timeout=10.0)
assert request_completed
assert final_output2.finished
@pytest.mark.asyncio
async def test_pause_wait():
"""Test that mode='wait' waits for in-flight requests to complete."""
with ExitStack() as after:
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
# Start a request - use fewer tokens since wait mode waits for completion
sampling_params = SamplingParams(max_tokens=10, ignore_eos=True)
got_first_token = asyncio.Event()
request_completed = False
async def gen():
nonlocal request_completed
async for out in engine.generate(
request_id="test-wait",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
):
got_first_token.set()
request_completed = True
return out
# Start generation
gen_task = asyncio.create_task(gen())
# Wait for generation to start (event-driven)
await asyncio.wait_for(got_first_token.wait(), timeout=30.0)
# Pause with wait mode - should wait for request to finish
await engine.pause_generation(mode="wait")
# By now the request should be done (wait mode waits for completion)
assert request_completed, "Request should have completed during wait"
final_output = gen_task.result()
assert final_output.finished
# Should complete normally, not aborted
assert final_output.outputs[0].finish_reason != "eos"
@pytest.mark.asyncio
async def test_pause_keep_single_request():
"""Test that mode='keep' freezes a single request and resumes with timing gap."""
with ExitStack() as after:
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
sampling_params = SamplingParams(max_tokens=30, ignore_eos=True)
token_times: list[tuple[int, float]] = []
pause_duration = 5.0
pause_token_idx = 0
async def generator_task():
"""Generate tokens and record timestamps."""
async for output in engine.generate(
request_id="test-keep-single",
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
):
token_count = len(output.outputs[0].token_ids)
token_times.append((token_count, time.monotonic()))
return output
async def controller_task():
"""Pause and resume the engine."""
nonlocal pause_token_idx
# Wait for some tokens (event-driven, handles slow token generation)
while len(token_times) < 5:
await asyncio.sleep(0.01)
# Pause with keep mode
await engine.pause_generation(mode="keep")
pause_token_idx = len(token_times)
# Sleep while paused
await asyncio.sleep(pause_duration)
# Resume
await engine.resume_generation()
# Run both tasks with timeout for slow generation
gen_task = asyncio.create_task(generator_task())
ctrl_task = asyncio.create_task(controller_task())
final_output, _ = await asyncio.wait_for(
asyncio.gather(gen_task, ctrl_task), timeout=60.0
)
# Request should complete with all tokens
assert final_output.finished
assert len(final_output.outputs[0].token_ids) == 30
# Check the gap at the recorded pause index matches the pause duration
pause_gap = (
token_times[pause_token_idx][1] - token_times[pause_token_idx - 1][1]
)
assert pause_gap >= pause_duration * 0.8, (
f"Expected gap of ~{pause_duration}s after pause, got {pause_gap:.3f}s"
)
@pytest.mark.asyncio
async def test_pause_keep_multi_request():
"""Test that mode='keep' freezes multiple concurrent requests and all resume."""
with ExitStack() as after:
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
after.callback(engine.shutdown)
num_requests = 3
sampling_params = SamplingParams(max_tokens=10, ignore_eos=True)
completed_requests: list[str] = []
any_token_generated = asyncio.Event()
async def gen_multi(request_id: str):
async for out in engine.generate(
request_id=request_id,
prompt=TEXT_PROMPT,
sampling_params=sampling_params,
):
any_token_generated.set()
completed_requests.append(request_id)
return out
# Start multiple requests
tasks = [
asyncio.create_task(gen_multi(f"req-multi-{i}"))
for i in range(num_requests)
]
# Wait for at least one token across any request (event-driven)
await asyncio.wait_for(any_token_generated.wait(), timeout=30.0)
# Pause with keep mode
await engine.pause_generation(mode="keep")
# Wait while paused
await asyncio.sleep(0.5)
# Resume
await engine.resume_generation()
# All requests should complete
results = await asyncio.wait_for(asyncio.gather(*tasks), timeout=60.0)
assert len(completed_requests) == num_requests
for result in results:
assert result.finished
assert len(result.outputs[0].token_ids) == 10

View File

@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Iterable, Mapping
from typing import Any
from typing import TYPE_CHECKING, Any
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.weight_transfer.base import (
@@ -22,6 +22,9 @@ from vllm.tasks import SupportedTask
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
if TYPE_CHECKING:
from vllm.v1.engine import PauseMode
class EngineClient(ABC):
"""Protocol class for Clients to Engine"""
@@ -158,16 +161,22 @@ class EngineClient(ABC):
async def pause_generation(
self,
*,
mode: "PauseMode" = "abort",
wait_for_inflight_requests: bool = False,
clear_cache: bool = True,
) -> None:
"""Pause new generation/encoding requests.
Args:
wait_for_inflight_requests: When ``True`` waits for in-flight requests
to finish before pausing. When ``False`` (default), aborts in-flight
requests immediately.
clear_cache: Whether to clear KV and prefix caches after draining.
mode: How to handle in-flight requests:
- ``"abort"``: Abort all in-flight requests immediately
and return partial results with "abort" reason (default).
- ``"wait"``: Wait for in-flight requests to complete.
- ``"keep"``: Freeze requests in queue; they resume on
:meth:`resume_generation`.
wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
clear_cache: DEPRECATED. Whether to clear KV and prefix caches
after draining.
"""
...

View File

@@ -3,6 +3,7 @@
import json
from http import HTTPStatus
from typing import Annotated
from fastapi import APIRouter, FastAPI, HTTPException, Query, Request
from fastapi.responses import JSONResponse
@@ -14,6 +15,7 @@ from vllm.distributed.weight_transfer.base import (
)
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
from vllm.v1.engine import PauseMode
logger = init_logger(__name__)
@@ -28,24 +30,29 @@ router = APIRouter()
@router.post("/pause")
async def pause_generation(
raw_request: Request,
mode: Annotated[PauseMode, Query()] = "abort",
wait_for_inflight_requests: bool = Query(False),
clear_cache: bool = Query(True),
clear_cache: Annotated[bool, Query()] = True,
) -> JSONResponse:
"""Pause generation requests to allow weight updates.
Args:
wait_for_inflight_requests: When ``True`` waits for in-flight
requests to finish before pausing. When ``False`` (default),
aborts any in-flight requests immediately.
clear_cache: Whether to clear KV/prefix caches after draining.
mode: How to handle in-flight requests:
- ``"abort"``: Abort all in-flight requests immediately (default).
- ``"wait"``: Wait for in-flight requests to complete.
- ``"keep"``: Freeze requests in queue; they resume on /resume.
wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
clear_cache: DEPRECATED. Whether to clear KV/prefix caches after
draining. Ignored when mode="keep".
"""
engine = engine_client(raw_request)
try:
await engine.pause_generation(
wait_for_inflight_requests=wait_for_inflight_requests,
mode=mode,
clear_cache=clear_cache,
wait_for_inflight_requests=wait_for_inflight_requests,
)
return JSONResponse(
content={"status": "paused"},

View File

@@ -4,7 +4,7 @@
import enum
import time
from collections.abc import Mapping
from typing import Any
from typing import Any, Literal
import msgspec
import numpy as np
@@ -18,6 +18,12 @@ from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
from vllm.v1.serial_utils import UtilityResult
# Type for pause_generation mode parameter.
# - "abort": Abort all in-flight requests immediately (default).
# - "wait": Wait for in-flight requests to complete before pausing.
# - "keep": Freeze requests in queue; they resume on resume_generation().
PauseMode = Literal["abort", "wait", "keep"]
# These are possible values of RequestOutput.finish_reason,
# so form part of the external API.
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")

View File

@@ -38,7 +38,7 @@ from vllm.transformers_utils.config import maybe_register_config_serialize_by_va
from vllm.usage.usage_lib import UsageContext
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine import EngineCoreRequest, PauseMode
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
from vllm.v1.engine.input_processor import InputProcessor
@@ -170,6 +170,7 @@ class AsyncLLM(EngineClient):
# Pause / resume state for async RL workflows.
self._pause_cond = asyncio.Condition()
self._paused = False
self._client_count = client_count
self.output_handler: asyncio.Task | None = None
try:
@@ -728,7 +729,8 @@ class AsyncLLM(EngineClient):
async def pause_generation(
self,
*,
wait_for_inflight_requests: bool = False,
mode: PauseMode = "abort",
wait_for_inflight_requests: bool | None = None,
clear_cache: bool = True,
) -> None:
"""
@@ -737,27 +739,52 @@ class AsyncLLM(EngineClient):
New generation/encoding requests are blocked until resume.
Args:
wait_for_inflight_requests: When ``True`` waits for in-flight
requests to finish before pausing. When ``False`` (default),
immediately aborts any in-flight requests.
mode: How to handle in-flight requests:
- ``"abort"``: Abort all in-flight requests immediately
(default).
- ``"wait"``: Wait for in-flight requests to complete.
- ``"keep"``: Freeze requests in queue; they resume on
:meth:`resume_generation`.
wait_for_inflight_requests: DEPRECATED: use mode argument.
Whether to wait for in-flight requests to complete before pausing.
clear_cache: Whether to clear KV cache and prefix cache after
draining. Set to ``False`` to preserve cache for faster resume.
Default is ``True`` (clear caches).
"""
if wait_for_inflight_requests:
warnings.warn(
"The `wait_for_inflight_requests` parameter in "
"`AsyncLLM.pause_generation()` is deprecated. "
"Please use `mode` argument instead.",
DeprecationWarning,
stacklevel=2,
)
mode = "wait"
async with self._pause_cond:
if self._paused:
return
self._paused = True
if mode == "keep":
# Freeze requests in the scheduler - they will resume on
# resume_generation().
await self.engine_core.pause_scheduler_async()
else:
if self._client_count > 1:
raise NotImplementedError(
"pause_generation is not supported with --api-server-count > 1"
" when mode is not 'keep'"
)
async with self._pause_cond:
if not self._paused:
self._paused = True
if not wait_for_inflight_requests:
request_ids = list(self.output_processor.request_states.keys())
if request_ids:
await self.abort(request_ids, internal=True)
# Wait for running requests to drain before clearing cache.
if self.output_processor.has_unfinished_requests():
await self.output_processor.wait_for_requests_to_drain()
if mode == "abort":
request_ids = list(self.output_processor.request_states.keys())
if request_ids:
await self.abort(request_ids, internal=True)
elif mode == "wait":
if self.output_processor.has_unfinished_requests():
await self.output_processor.wait_for_requests_to_drain()
else:
raise ValueError(f"Invalid mode: {mode}")
# Clear cache
if clear_cache:
@@ -769,6 +796,7 @@ class AsyncLLM(EngineClient):
"""Resume generation after :meth:`pause_generation`."""
async with self._pause_cond:
await self.engine_core.resume_scheduler_async()
self._paused = False
self._pause_cond.notify_all() # Wake up all waiting requests

View File

@@ -209,6 +209,10 @@ class EngineCore:
self.async_scheduling = vllm_config.scheduler_config.async_scheduling
self.aborts_queue = queue.Queue[list[str]]()
# Pause state for "keep" mode - freezes requests in queue.
self._scheduler_paused = False
# Mark the startup heap as static so that it's ignored by GC.
# Reduces pause times of oldest generation collections.
freeze_gc_heap()
@@ -322,6 +326,20 @@ class EngineCore:
# (i.e. client-aborted vs stop criteria met).
self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)
def pause_scheduler(self) -> None:
"""Pause the scheduler, keeping requests frozen in queue.
Requests are kept frozen in queue and can be resumed later.
"""
self._scheduler_paused = True
def resume_scheduler(self) -> None:
"""Resume the scheduler after a pause.
Resumes processing of frozen requests in the queue.
"""
self._scheduler_paused = False
@contextmanager
def log_error_detail(self, scheduler_output: SchedulerOutput):
"""Execute the model and log detailed info on failure."""
@@ -375,6 +393,10 @@ class EngineCore:
was executed.
"""
# If paused, don't schedule any work.
if self._scheduler_paused:
return {}, False
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
@@ -425,6 +447,10 @@ class EngineCore:
batch in the job queue is finished.
3. Update the scheduler from the output.
"""
# If paused, don't schedule any work.
if self._scheduler_paused:
return {}, False
batch_queue = self.batch_queue
assert batch_queue is not None
@@ -1007,6 +1033,7 @@ class EngineCoreProc(EngineCore):
not self.engines_running
and not self.scheduler.has_requests()
and not self.batch_queue
and not self._scheduler_paused
):
if self.input_queue.empty():
# Drain aborts queue; all aborts are also processed via input_queue.

View File

@@ -105,7 +105,7 @@ class EngineCoreClient(ABC):
client_addresses: dict[str, str] | None = None,
client_count: int = 1,
client_index: int = 0,
) -> "MPClient":
) -> "AsyncMPClient":
parallel_config = vllm_config.parallel_config
client_args = (
vllm_config,
@@ -976,6 +976,16 @@ class AsyncMPClient(MPClient):
if request_ids and not self.resources.engine_dead:
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
async def pause_scheduler_async(self) -> None:
"""Pause the scheduler, keeping requests frozen in queue.
Blocks until the EngineCore acknowledges the pause.
"""
await self.call_utility_async("pause_scheduler")
async def resume_scheduler_async(self) -> None:
"""Resume the scheduler after a pause."""
await self.call_utility_async("resume_scheduler")
async def profile_async(self, is_start: bool = True) -> None:
await self.call_utility_async("profile", is_start)
@@ -1188,6 +1198,18 @@ class DPAsyncMPClient(AsyncMPClient):
def get_core_engine_for_request(self, request: EngineCoreRequest):
return self.core_engine
async def pause_scheduler_async(self) -> None:
"""Pause the scheduler, keeping requests frozen in queue."""
raise NotImplementedError(
"pause_scheduler_async is not yet supported for data parallel"
)
async def resume_scheduler_async(self) -> None:
"""Resume the scheduler after a pause."""
raise NotImplementedError(
"resume_scheduler_async is not yet supported for data parallel"
)
class DPLBAsyncMPClient(DPAsyncMPClient):
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)