[Feat][RL] Pause and Resume with keep requests for single engine (#32351)
Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Aaron Hao <ahao@anyscale.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
108
examples/offline_inference/pause_resume.py
Normal file
108
examples/offline_inference/pause_resume.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Test for pause/resume with keep mode.
|
||||
|
||||
This test uses concurrent tasks to verify the engine truly stops generating
|
||||
during pause:
|
||||
1. Generator task: continuously generates and logs time between tokens
|
||||
2. Controller task: sends pause/resume commands
|
||||
|
||||
If the engine properly pauses, we should see a gap in token timestamps
|
||||
matching the pause duration.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
PAUSE_DURATION = 3.0 # seconds
|
||||
|
||||
|
||||
async def main():
    """Run the pause/resume (keep mode) demonstration end to end.

    Starts a generation request and, concurrently, pauses the engine in
    "keep" mode for PAUSE_DURATION seconds, then verifies a matching gap
    appears in the token arrival timestamps and that the request still
    completes with all tokens.
    """
    # Create engine with a small model
    engine_args = AsyncEngineArgs(
        model="facebook/opt-125m",
        enforce_eager=True,
    )
    engine = AsyncLLM.from_engine_args(engine_args)

    prompt = "Write a story about a dragon. Once upon a time"
    sampling_params = SamplingParams(max_tokens=30, ignore_eos=True)

    # Track token arrival times
    token_times: list[tuple[int, float]] = []  # (token_count, timestamp)
    pause_time: float = 0
    resume_time: float = 0
    pause_token_idx: int = 0  # Index in token_times when pause occurred

    async def generator_task():
        """Generate tokens and record timestamps."""
        async for output in engine.generate(
            request_id="test-req",
            prompt=prompt,
            sampling_params=sampling_params,
        ):
            token_count = len(output.outputs[0].token_ids)
            token_times.append((token_count, time.monotonic()))
            # Print time relative to the first token for easy gap spotting.
            print(
                f"Token {token_count} arrived:"
                f"T={token_times[-1][1] - token_times[0][1]:.3f}s"
            )
        return output

    async def controller_task():
        """Pause and resume the engine after some tokens generated."""
        nonlocal pause_time, resume_time, pause_token_idx

        # Wait for some tokens to be generated
        while len(token_times) < 5:
            await asyncio.sleep(0.01)

        print(f"\nPausing engine (keep mode) at token {len(token_times)}")
        pause_time = time.monotonic()
        await engine.pause_generation(mode="keep")
        pause_token_idx = len(token_times)
        print(f"Paused! Sleeping for {PAUSE_DURATION}s...")

        # Sleep while paused - no tokens should be generated during this time
        await asyncio.sleep(PAUSE_DURATION)

        print("Resuming engine...")
        await engine.resume_generation()
        resume_time = time.monotonic()
        print("Resumed!\n")

    # Run both tasks concurrently
    gen_task = asyncio.create_task(generator_task())
    ctrl_task = asyncio.create_task(controller_task())

    final_output, _ = await asyncio.gather(gen_task, ctrl_task)

    # Verify the pause actually stopped generation.
    # The gap after the pause token should be approximately the sleep duration.
    pause_gap = token_times[pause_token_idx][1] - token_times[pause_token_idx - 1][1]
    print(
        f"\nGap after pause (token {pause_token_idx - 1} -> {pause_token_idx}): "
        f"{pause_gap:.3f}s"
    )
    # Allow 10% slack for scheduling jitter around the sleep.
    if pause_gap >= PAUSE_DURATION * 0.9:
        print(f"✓ Test passed! Engine paused for ~{pause_gap:.1f}s")
    else:
        print(
            f"✗ Test failed! Expected ~{PAUSE_DURATION}s gap after pause, "
            f"got {pause_gap:.3f}s"
        )
        raise AssertionError("Engine did not properly pause")

    # Verify request completed
    assert final_output.finished, "Request should have finished"
    assert len(final_output.outputs[0].token_ids) == 30, "Should have all tokens"

    engine.shutdown()
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from contextlib import ExitStack
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -661,3 +662,301 @@ async def collect_outputs(
|
||||
outputs_list.append(output)
|
||||
final_output = output
|
||||
return final_output
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pause/Resume Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pause_resume_basic():
    """Test basic pause/resume flag behavior and idempotency.

    Tests:
    - pause_generation sets the paused flag
    - resume_generation clears the paused flag
    - calling pause when already paused is a no-op
    - calling resume when not paused is safe
    - all pause modes work with no requests in flight
    - rapid pause/resume cycles don't break the engine
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        # Guarantee engine shutdown even if an assertion fails.
        after.callback(engine.shutdown)

        # Initially not paused
        assert not await engine.is_paused()

        # Resume when not paused should be safe
        await engine.resume_generation()
        assert not await engine.is_paused()

        # Pause sets flag
        await engine.pause_generation(mode="abort")
        assert await engine.is_paused()

        # Pause when already paused is a no-op
        await engine.pause_generation(mode="abort")
        assert await engine.is_paused()

        # Resume clears flag
        await engine.resume_generation()
        assert not await engine.is_paused()

        # Test all modes with no requests in flight
        for mode in ("abort", "wait", "keep"):
            await engine.pause_generation(mode=mode)
            # "keep" only freezes the scheduler; it does not set _paused
            if mode != "keep":
                assert await engine.is_paused()
            await engine.resume_generation()
            assert not await engine.is_paused()

        # Concurrent pause/resume race conditions - should not deadlock or raise
        await asyncio.gather(
            engine.pause_generation(mode="abort"),
            engine.resume_generation(),
            engine.pause_generation(mode="abort"),
            engine.resume_generation(),
        )

        # Ensure we end in a known state
        await engine.resume_generation()
        assert not await engine.is_paused()

        # Engine should still work after all cycles
        sampling_params = SamplingParams(max_tokens=5)
        async for out in engine.generate(
            request_id="post-cycles",
            prompt=TEXT_PROMPT,
            sampling_params=sampling_params,
        ):
            pass
        assert out.finished
||||
|
||||
@pytest.mark.asyncio
async def test_pause_abort():
    """Test that mode='abort' aborts in-flight requests immediately.

    Also verifies that new requests are blocked while paused and complete
    after resume_generation().
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Start a long-running request
        # max_tokens=1000 + ignore_eos ensures it is still in flight at pause.
        sampling_params = SamplingParams(max_tokens=1000, ignore_eos=True)
        outputs: list[RequestOutput] = []

        async def gen():
            async for out in engine.generate(
                request_id="test-abort-pause",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                outputs.append(out)
            return outputs[-1] if outputs else None

        # Start generation task
        gen_task = asyncio.create_task(gen())

        # Wait for some tokens to be generated
        while len(outputs) < 3:
            await asyncio.sleep(0.01)

        # Pause with abort mode
        await engine.pause_generation(mode="abort")

        # Wait for task to complete (should be aborted)
        final_output = await gen_task

        # Request should be finished (aborted)
        assert final_output is not None
        assert final_output.finished
        assert final_output.outputs[0].finish_reason == "abort"

        # Also test that new requests are blocked while paused, then resume
        assert await engine.is_paused()

        request_completed = False

        async def gen_blocked():
            nonlocal request_completed
            async for out in engine.generate(
                request_id="test-blocked",
                prompt=TEXT_PROMPT,
                sampling_params=SamplingParams(max_tokens=5),
            ):
                pass
            request_completed = True
            return out

        # Start a request (should block)
        gen_task2 = asyncio.create_task(gen_blocked())

        # Wait a bit - request should not have completed
        await asyncio.sleep(0.3)
        assert not request_completed, "Request should be blocked while paused"

        # Resume
        await engine.resume_generation()

        # Now request should complete
        final_output2 = await asyncio.wait_for(gen_task2, timeout=10.0)
        assert request_completed
        assert final_output2.finished
||||
|
||||
@pytest.mark.asyncio
async def test_pause_wait():
    """Test that mode='wait' waits for in-flight requests to complete.

    Starts a short request, pauses with mode="wait" once the first token has
    arrived, and asserts the request finished normally (not aborted) by the
    time pause_generation() returns.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Start a request - use fewer tokens since wait mode waits for completion
        sampling_params = SamplingParams(max_tokens=10, ignore_eos=True)
        got_first_token = asyncio.Event()
        request_completed = False

        async def gen():
            nonlocal request_completed
            async for out in engine.generate(
                request_id="test-wait",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                got_first_token.set()
            request_completed = True
            return out

        # Start generation
        gen_task = asyncio.create_task(gen())

        # Wait for generation to start (event-driven)
        await asyncio.wait_for(got_first_token.wait(), timeout=30.0)

        # Pause with wait mode - should wait for request to finish
        await engine.pause_generation(mode="wait")

        # By now the request should be done (wait mode waits for completion)
        assert request_completed, "Request should have completed during wait"

        final_output = gen_task.result()
        assert final_output.finished
        # Should complete normally, not aborted.
        # FIX: the previous check was `!= "eos"`, which is vacuously true --
        # "eos" is not a valid finish_reason (valid values are "stop",
        # "length", "abort", "error"), so it never verified anything.
        assert final_output.outputs[0].finish_reason != "abort"
||||
|
||||
@pytest.mark.asyncio
async def test_pause_keep_single_request():
    """Test that mode='keep' freezes a single request and resumes with timing gap."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(max_tokens=30, ignore_eos=True)
        # (token_count, timestamp) per output; used to locate the pause gap.
        token_times: list[tuple[int, float]] = []
        pause_duration = 5.0
        pause_token_idx = 0

        async def generator_task():
            """Generate tokens and record timestamps."""
            async for output in engine.generate(
                request_id="test-keep-single",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                token_count = len(output.outputs[0].token_ids)
                token_times.append((token_count, time.monotonic()))
            return output

        async def controller_task():
            """Pause and resume the engine."""
            nonlocal pause_token_idx
            # Wait for some tokens (event-driven, handles slow token generation)
            while len(token_times) < 5:
                await asyncio.sleep(0.01)

            # Pause with keep mode
            await engine.pause_generation(mode="keep")
            pause_token_idx = len(token_times)

            # Sleep while paused
            await asyncio.sleep(pause_duration)

            # Resume
            await engine.resume_generation()

        # Run both tasks with timeout for slow generation
        gen_task = asyncio.create_task(generator_task())
        ctrl_task = asyncio.create_task(controller_task())

        final_output, _ = await asyncio.wait_for(
            asyncio.gather(gen_task, ctrl_task), timeout=60.0
        )

        # Request should complete with all tokens
        assert final_output.finished
        assert len(final_output.outputs[0].token_ids) == 30

        # Check the gap at the recorded pause index matches the pause duration
        pause_gap = (
            token_times[pause_token_idx][1] - token_times[pause_token_idx - 1][1]
        )
        # 20% slack tolerates scheduler/event-loop jitter around the sleep.
        assert pause_gap >= pause_duration * 0.8, (
            f"Expected gap of ~{pause_duration}s after pause, got {pause_gap:.3f}s"
        )
||||
|
||||
@pytest.mark.asyncio
async def test_pause_keep_multi_request():
    """Test that mode='keep' freezes multiple concurrent requests and all resume."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        num_requests = 3
        sampling_params = SamplingParams(max_tokens=10, ignore_eos=True)
        completed_requests: list[str] = []
        any_token_generated = asyncio.Event()

        async def gen_multi(request_id: str):
            async for out in engine.generate(
                request_id=request_id,
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                any_token_generated.set()
            completed_requests.append(request_id)
            return out

        # Start multiple requests
        tasks = [
            asyncio.create_task(gen_multi(f"req-multi-{i}"))
            for i in range(num_requests)
        ]

        # Wait for at least one token across any request (event-driven)
        await asyncio.wait_for(any_token_generated.wait(), timeout=30.0)

        # Pause with keep mode
        await engine.pause_generation(mode="keep")

        # Wait while paused
        await asyncio.sleep(0.5)

        # Resume
        await engine.resume_generation()

        # All requests should complete
        results = await asyncio.wait_for(asyncio.gather(*tasks), timeout=60.0)

        assert len(completed_requests) == num_requests
        for result in results:
            assert result.finished
            assert len(result.outputs[0].token_ids) == 10
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Iterable, Mapping
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.distributed.weight_transfer.base import (
|
||||
@@ -22,6 +22,9 @@ from vllm.tasks import SupportedTask
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.engine import PauseMode
|
||||
|
||||
|
||||
class EngineClient(ABC):
|
||||
"""Protocol class for Clients to Engine"""
|
||||
@@ -158,16 +161,22 @@ class EngineClient(ABC):
|
||||
async def pause_generation(
    self,
    *,
    mode: "PauseMode" = "abort",
    wait_for_inflight_requests: bool = False,
    clear_cache: bool = True,
) -> None:
    """Pause new generation/encoding requests.

    FIX: the previous version documented ``wait_for_inflight_requests`` and
    ``clear_cache`` twice (once as live parameters, once as deprecated);
    keep only the deprecated entries to match the ``mode``-based API.

    Args:
        mode: How to handle in-flight requests:

            - ``"abort"``: Abort all in-flight requests immediately
              and return partial results with "abort" reason (default).
            - ``"wait"``: Wait for in-flight requests to complete.
            - ``"keep"``: Freeze requests in queue; they resume on
              :meth:`resume_generation`.
        wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
        clear_cache: DEPRECATED. Whether to clear KV and prefix caches
            after draining.
    """
    ...
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, FastAPI, HTTPException, Query, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
@@ -14,6 +15,7 @@ from vllm.distributed.weight_transfer.base import (
|
||||
)
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.engine import PauseMode
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -28,24 +30,29 @@ router = APIRouter()
|
||||
@router.post("/pause")
|
||||
async def pause_generation(
|
||||
raw_request: Request,
|
||||
mode: Annotated[PauseMode, Query()] = "abort",
|
||||
wait_for_inflight_requests: bool = Query(False),
|
||||
clear_cache: bool = Query(True),
|
||||
clear_cache: Annotated[bool, Query()] = True,
|
||||
) -> JSONResponse:
|
||||
"""Pause generation requests to allow weight updates.
|
||||
|
||||
Args:
|
||||
wait_for_inflight_requests: When ``True`` waits for in-flight
|
||||
requests to finish before pausing. When ``False`` (default),
|
||||
aborts any in-flight requests immediately.
|
||||
clear_cache: Whether to clear KV/prefix caches after draining.
|
||||
mode: How to handle in-flight requests:
|
||||
- ``"abort"``: Abort all in-flight requests immediately (default).
|
||||
- ``"wait"``: Wait for in-flight requests to complete.
|
||||
- ``"keep"``: Freeze requests in queue; they resume on /resume.
|
||||
wait_for_inflight_requests: DEPRECATED. Use ``mode="wait"`` instead.
|
||||
clear_cache: DEPRECATED. Whether to clear KV/prefix caches after
|
||||
draining. Ignored when mode="keep".
|
||||
"""
|
||||
|
||||
engine = engine_client(raw_request)
|
||||
|
||||
try:
|
||||
await engine.pause_generation(
|
||||
wait_for_inflight_requests=wait_for_inflight_requests,
|
||||
mode=mode,
|
||||
clear_cache=clear_cache,
|
||||
wait_for_inflight_requests=wait_for_inflight_requests,
|
||||
)
|
||||
return JSONResponse(
|
||||
content={"status": "paused"},
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import enum
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
@@ -18,6 +18,12 @@ from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
from vllm.v1.serial_utils import UtilityResult
|
||||
|
||||
# Type for pause_generation mode parameter.
|
||||
# - "abort": Abort all in-flight requests immediately (default).
|
||||
# - "wait": Wait for in-flight requests to complete before pausing.
|
||||
# - "keep": Freeze requests in queue; they resume on resume_generation().
|
||||
PauseMode = Literal["abort", "wait", "keep"]
|
||||
|
||||
# These are possible values of RequestOutput.finish_reason,
|
||||
# so form part of the external API.
|
||||
FINISH_REASON_STRINGS = ("stop", "length", "abort", "error")
|
||||
|
||||
@@ -38,7 +38,7 @@ from vllm.transformers_utils.config import maybe_register_config_serialize_by_va
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.async_utils import cancel_task_threadsafe
|
||||
from vllm.utils.collection_utils import as_list
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine import EngineCoreRequest, PauseMode
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
@@ -170,6 +170,7 @@ class AsyncLLM(EngineClient):
|
||||
# Pause / resume state for async RL workflows.
|
||||
self._pause_cond = asyncio.Condition()
|
||||
self._paused = False
|
||||
self._client_count = client_count
|
||||
|
||||
self.output_handler: asyncio.Task | None = None
|
||||
try:
|
||||
@@ -728,7 +729,8 @@ class AsyncLLM(EngineClient):
|
||||
async def pause_generation(
|
||||
self,
|
||||
*,
|
||||
wait_for_inflight_requests: bool = False,
|
||||
mode: PauseMode = "abort",
|
||||
wait_for_inflight_requests: bool | None = None,
|
||||
clear_cache: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -737,27 +739,52 @@ class AsyncLLM(EngineClient):
|
||||
New generation/encoding requests are blocked until resume.
|
||||
|
||||
Args:
|
||||
wait_for_inflight_requests: When ``True`` waits for in-flight
|
||||
requests to finish before pausing. When ``False`` (default),
|
||||
immediately aborts any in-flight requests.
|
||||
mode: How to handle in-flight requests:
|
||||
- ``"abort"``: Abort all in-flight requests immediately
|
||||
(default).
|
||||
- ``"wait"``: Wait for in-flight requests to complete.
|
||||
- ``"keep"``: Freeze requests in queue; they resume on
|
||||
:meth:`resume_generation`.
|
||||
wait_for_inflight_requests: DEPRECATED: use mode argument.
|
||||
Whether to wait for in-flight requests to complete before pausing.
|
||||
clear_cache: Whether to clear KV cache and prefix cache after
|
||||
draining. Set to ``False`` to preserve cache for faster resume.
|
||||
Default is ``True`` (clear caches).
|
||||
|
||||
"""
|
||||
if wait_for_inflight_requests:
|
||||
warnings.warn(
|
||||
"The `wait_for_inflight_requests` parameter in "
|
||||
"`AsyncLLM.pause_generation()` is deprecated. "
|
||||
"Please use `mode` argument instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
mode = "wait"
|
||||
|
||||
async with self._pause_cond:
|
||||
if self._paused:
|
||||
return
|
||||
self._paused = True
|
||||
if mode == "keep":
|
||||
# Freeze requests in the scheduler - they will resume on
|
||||
# resume_generation().
|
||||
await self.engine_core.pause_scheduler_async()
|
||||
else:
|
||||
if self._client_count > 1:
|
||||
raise NotImplementedError(
|
||||
"pause_generation is not supported with --api-server-count > 1"
|
||||
" when mode is not 'keep'"
|
||||
)
|
||||
async with self._pause_cond:
|
||||
if not self._paused:
|
||||
self._paused = True
|
||||
|
||||
if not wait_for_inflight_requests:
|
||||
request_ids = list(self.output_processor.request_states.keys())
|
||||
if request_ids:
|
||||
await self.abort(request_ids, internal=True)
|
||||
|
||||
# Wait for running requests to drain before clearing cache.
|
||||
if self.output_processor.has_unfinished_requests():
|
||||
await self.output_processor.wait_for_requests_to_drain()
|
||||
if mode == "abort":
|
||||
request_ids = list(self.output_processor.request_states.keys())
|
||||
if request_ids:
|
||||
await self.abort(request_ids, internal=True)
|
||||
elif mode == "wait":
|
||||
if self.output_processor.has_unfinished_requests():
|
||||
await self.output_processor.wait_for_requests_to_drain()
|
||||
else:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
|
||||
# Clear cache
|
||||
if clear_cache:
|
||||
@@ -769,6 +796,7 @@ class AsyncLLM(EngineClient):
|
||||
"""Resume generation after :meth:`pause_generation`."""
|
||||
|
||||
async with self._pause_cond:
|
||||
await self.engine_core.resume_scheduler_async()
|
||||
self._paused = False
|
||||
self._pause_cond.notify_all() # Wake up all waiting requests
|
||||
|
||||
|
||||
@@ -209,6 +209,10 @@ class EngineCore:
|
||||
self.async_scheduling = vllm_config.scheduler_config.async_scheduling
|
||||
|
||||
self.aborts_queue = queue.Queue[list[str]]()
|
||||
|
||||
# Pause state for "keep" mode - freezes requests in queue.
|
||||
self._scheduler_paused = False
|
||||
|
||||
# Mark the startup heap as static so that it's ignored by GC.
|
||||
# Reduces pause times of oldest generation collections.
|
||||
freeze_gc_heap()
|
||||
@@ -322,6 +326,20 @@ class EngineCore:
|
||||
# (i.e. client-aborted vs stop criteria met).
|
||||
self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)
|
||||
|
||||
def pause_scheduler(self) -> None:
    """Pause the scheduler, keeping requests frozen in queue.

    Requests are kept frozen in queue and can be resumed later via
    :meth:`resume_scheduler`.
    """
    # The step functions check this flag and return early without
    # scheduling any work while it is set.
    self._scheduler_paused = True
||||
def resume_scheduler(self) -> None:
    """Resume the scheduler after a pause.

    Resumes processing of frozen requests in the queue; the next step
    call schedules work again once the flag is cleared.
    """
    self._scheduler_paused = False
||||
@contextmanager
|
||||
def log_error_detail(self, scheduler_output: SchedulerOutput):
|
||||
"""Execute the model and log detailed info on failure."""
|
||||
@@ -375,6 +393,10 @@ class EngineCore:
|
||||
was executed.
|
||||
"""
|
||||
|
||||
# If paused, don't schedule any work.
|
||||
if self._scheduler_paused:
|
||||
return {}, False
|
||||
|
||||
# Check for any requests remaining in the scheduler - unfinished,
|
||||
# or finished and not yet removed from the batch.
|
||||
if not self.scheduler.has_requests():
|
||||
@@ -425,6 +447,10 @@ class EngineCore:
|
||||
batch in the job queue is finished.
|
||||
3. Update the scheduler from the output.
|
||||
"""
|
||||
# If paused, don't schedule any work.
|
||||
if self._scheduler_paused:
|
||||
return {}, False
|
||||
|
||||
batch_queue = self.batch_queue
|
||||
assert batch_queue is not None
|
||||
|
||||
@@ -1007,6 +1033,7 @@ class EngineCoreProc(EngineCore):
|
||||
not self.engines_running
|
||||
and not self.scheduler.has_requests()
|
||||
and not self.batch_queue
|
||||
and not self._scheduler_paused
|
||||
):
|
||||
if self.input_queue.empty():
|
||||
# Drain aborts queue; all aborts are also processed via input_queue.
|
||||
|
||||
@@ -105,7 +105,7 @@ class EngineCoreClient(ABC):
|
||||
client_addresses: dict[str, str] | None = None,
|
||||
client_count: int = 1,
|
||||
client_index: int = 0,
|
||||
) -> "MPClient":
|
||||
) -> "AsyncMPClient":
|
||||
parallel_config = vllm_config.parallel_config
|
||||
client_args = (
|
||||
vllm_config,
|
||||
@@ -976,6 +976,16 @@ class AsyncMPClient(MPClient):
|
||||
if request_ids and not self.resources.engine_dead:
|
||||
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
|
||||
|
||||
async def pause_scheduler_async(self) -> None:
    """Pause the scheduler, keeping requests frozen in queue.

    Blocks until the EngineCore acknowledges the pause.
    """
    # Round-trips a utility RPC to EngineCore.pause_scheduler().
    await self.call_utility_async("pause_scheduler")
|
||||
async def resume_scheduler_async(self) -> None:
    """Resume the scheduler after a pause.

    Blocks until the EngineCore acknowledges the resume.
    """
    # Round-trips a utility RPC to EngineCore.resume_scheduler().
    await self.call_utility_async("resume_scheduler")
|
||||
async def profile_async(self, is_start: bool = True) -> None:
    """Start (``is_start=True``) or stop profiling via a utility RPC."""
    await self.call_utility_async("profile", is_start)
|
||||
@@ -1188,6 +1198,18 @@ class DPAsyncMPClient(AsyncMPClient):
|
||||
def get_core_engine_for_request(self, request: EngineCoreRequest):
    # Single core engine in this client; every request routes to it.
    return self.core_engine
|
||||
async def pause_scheduler_async(self) -> None:
    """Pause the scheduler, keeping requests frozen in queue.

    Not supported for data-parallel clients; fails loudly rather than
    pausing only a subset of engines.
    """
    raise NotImplementedError(
        "pause_scheduler_async is not yet supported for data parallel"
    )
|
||||
async def resume_scheduler_async(self) -> None:
    """Resume the scheduler after a pause.

    Not supported for data-parallel clients; mirrors
    ``pause_scheduler_async`` so the pair stays consistent.
    """
    raise NotImplementedError(
        "resume_scheduler_async is not yet supported for data parallel"
    )
|
||||
|
||||
class DPLBAsyncMPClient(DPAsyncMPClient):
|
||||
"""Asyncio-compatible client for multi-proc, multi-engine (data parallel)
|
||||
|
||||
Reference in New Issue
Block a user