[Feat][RL] Pause and Resume with keep requests for single engine (#32351)

Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Aaron Hao <ahao@anyscale.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Aaron Hao
2026-02-06 16:08:58 -08:00
committed by GitHub
parent 4a2d00eafd
commit 89a385d79f
8 changed files with 536 additions and 30 deletions

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from contextlib import ExitStack
from unittest.mock import MagicMock
@@ -661,3 +662,301 @@ async def collect_outputs(
outputs_list.append(output)
final_output = output
return final_output
# =============================================================================
# Pause/Resume Tests
# =============================================================================
@pytest.mark.asyncio
async def test_pause_resume_basic():
    """Test basic pause/resume flag behavior and idempotency.

    Tests:
    - pause_generation sets the paused flag
    - resume_generation clears the paused flag
    - calling pause when already paused is a no-op
    - calling resume when not paused is safe
    - all pause modes work with no requests in flight
    - rapid pause/resume cycles don't break the engine
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Initially not paused
        assert not await engine.is_paused()

        # Resume when not paused should be safe
        await engine.resume_generation()
        assert not await engine.is_paused()

        # Pause sets flag
        await engine.pause_generation(mode="abort")
        assert await engine.is_paused()

        # Pause when already paused is a no-op
        await engine.pause_generation(mode="abort")
        assert await engine.is_paused()

        # Resume clears flag
        await engine.resume_generation()
        assert not await engine.is_paused()

        # Test all modes with no requests in flight. The paused flag is
        # checked for every mode (including "keep"): skipping it for one mode
        # would make the "not paused after resume" check below vacuous for
        # that mode, since the engine already starts out unpaused.
        for mode in ("abort", "wait", "keep"):
            await engine.pause_generation(mode=mode)
            assert await engine.is_paused(), (
                f"engine should report paused after pause_generation(mode={mode!r})"
            )
            await engine.resume_generation()
            assert not await engine.is_paused(), (
                f"engine should report resumed after mode={mode!r}"
            )

        # Concurrent pause/resume race conditions - should not deadlock or raise
        await asyncio.gather(
            engine.pause_generation(mode="abort"),
            engine.resume_generation(),
            engine.pause_generation(mode="abort"),
            engine.resume_generation(),
        )

        # Ensure we end in a known state (the interleaving above is not
        # deterministic, so the engine may be left paused or resumed).
        await engine.resume_generation()
        assert not await engine.is_paused()

        # Engine should still work after all cycles
        sampling_params = SamplingParams(max_tokens=5)
        # Initialised up front so an empty output stream fails the assertion
        # below instead of raising NameError.
        out = None
        async for out in engine.generate(
            request_id="post-cycles",
            prompt=TEXT_PROMPT,
            sampling_params=sampling_params,
        ):
            pass
        assert out is not None, "generate() produced no outputs"
        assert out.finished
@pytest.mark.asyncio
async def test_pause_abort():
    """Test that mode='abort' aborts in-flight requests immediately.

    Also checks that a request submitted while the engine is paused does not
    complete until resume_generation() is called.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Start a long-running request
        sampling_params = SamplingParams(max_tokens=1000, ignore_eos=True)
        outputs: list[RequestOutput] = []

        async def gen():
            async for out in engine.generate(
                request_id="test-abort-pause",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                outputs.append(out)
            return outputs[-1] if outputs else None

        # Start generation task
        gen_task = asyncio.create_task(gen())

        async def wait_for_outputs(minimum: int) -> None:
            """Poll until `minimum` outputs arrived, failing fast on a dead task."""
            while len(outputs) < minimum:
                if gen_task.done():
                    # Re-raise whatever killed the generator; if it ended
                    # cleanly, it finished before the pause could be tested.
                    gen_task.result()
                    raise AssertionError(
                        "Request finished before enough tokens were generated"
                    )
                await asyncio.sleep(0.01)

        # Wait for some tokens to be generated. Bounded so that a stalled or
        # failed generation fails the test instead of hanging it forever.
        await asyncio.wait_for(wait_for_outputs(3), timeout=30.0)

        # Pause with abort mode
        await engine.pause_generation(mode="abort")

        # Wait for task to complete (should be aborted)
        final_output = await asyncio.wait_for(gen_task, timeout=30.0)

        # Request should be finished (aborted)
        assert final_output is not None
        assert final_output.finished
        assert final_output.outputs[0].finish_reason == "abort"

        # Also test that new requests are blocked while paused, then resume
        assert await engine.is_paused()

        request_completed = False

        async def gen_blocked():
            nonlocal request_completed
            async for out in engine.generate(
                request_id="test-blocked",
                prompt=TEXT_PROMPT,
                sampling_params=SamplingParams(max_tokens=5),
            ):
                pass
            request_completed = True
            return out

        # Start a request (should block)
        gen_task2 = asyncio.create_task(gen_blocked())

        # Wait a bit - request should not have completed
        await asyncio.sleep(0.3)
        assert not request_completed, "Request should be blocked while paused"

        # Resume
        await engine.resume_generation()

        # Now request should complete
        final_output2 = await asyncio.wait_for(gen_task2, timeout=10.0)
        assert request_completed
        assert final_output2.finished
@pytest.mark.asyncio
async def test_pause_wait():
    """Test that mode='wait' waits for in-flight requests to complete."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Start a request - use fewer tokens since wait mode waits for completion
        sampling_params = SamplingParams(max_tokens=10, ignore_eos=True)
        got_first_token = asyncio.Event()
        request_completed = False

        async def gen():
            nonlocal request_completed
            async for out in engine.generate(
                request_id="test-wait",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                got_first_token.set()
            request_completed = True
            return out

        # Start generation
        gen_task = asyncio.create_task(gen())

        # Wait for generation to start (event-driven)
        await asyncio.wait_for(got_first_token.wait(), timeout=30.0)

        # Pause with wait mode - should wait for request to finish
        await engine.pause_generation(mode="wait")

        # By now the request should be done (wait mode waits for completion)
        assert request_completed, "Request should have completed during wait"

        # Awaiting (with a bound) rather than calling .result() directly, so
        # the test does not depend on the task object having been finalised
        # at this exact point.
        final_output = await asyncio.wait_for(gen_task, timeout=10.0)
        assert final_output.finished

        # Should complete normally, not aborted. The check must compare
        # against "abort": comparing against any other value would let an
        # aborted request pass this test.
        assert final_output.outputs[0].finish_reason != "abort"
@pytest.mark.asyncio
async def test_pause_keep_single_request():
    """Check that a 'keep' pause stalls one request and that it finishes after resume.

    A generator coroutine timestamps every output; a controller coroutine
    pauses with mode='keep' once five outputs were seen, sleeps, then resumes.
    The gap between the outputs on either side of the pause must be close to
    the sleep duration, and the request must still produce all 30 tokens.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        pause_duration = 5.0
        params = SamplingParams(max_tokens=30, ignore_eos=True)
        # One (cumulative token count, monotonic timestamp) entry per output.
        stamps: list[tuple[int, float]] = []
        # Index into `stamps` of the first output expected after the pause.
        first_after_pause = 0

        async def produce():
            """Drain the request stream, timestamping each output."""
            stream = engine.generate(
                request_id="test-keep-single",
                prompt=TEXT_PROMPT,
                sampling_params=params,
            )
            async for step in stream:
                n_tokens = len(step.outputs[0].token_ids)
                stamps.append((n_tokens, time.monotonic()))
            return step

        async def pause_then_resume():
            """Freeze the engine mid-request, hold it, then let it continue."""
            nonlocal first_after_pause

            # Poll until a handful of outputs arrived; the outer wait_for
            # below bounds how long this may spin.
            while len(stamps) < 5:
                await asyncio.sleep(0.01)

            await engine.pause_generation(mode="keep")
            first_after_pause = len(stamps)

            await asyncio.sleep(pause_duration)
            await engine.resume_generation()

        # Drive both coroutines concurrently under a single overall timeout.
        producer = asyncio.create_task(produce())
        controller = asyncio.create_task(pause_then_resume())
        both = asyncio.gather(producer, controller)
        final_output, _ = await asyncio.wait_for(both, timeout=60.0)

        # The request must run to completion with every token present.
        assert final_output.finished
        assert len(final_output.outputs[0].token_ids) == 30

        # Time between the last output before the pause and the first after it.
        resumed_at = stamps[first_after_pause][1]
        frozen_at = stamps[first_after_pause - 1][1]
        pause_gap = resumed_at - frozen_at
        assert pause_gap >= pause_duration * 0.8, (
            f"Expected gap of ~{pause_duration}s after pause, got {pause_gap:.3f}s"
        )
@pytest.mark.asyncio
async def test_pause_keep_multi_request():
    """Check that a 'keep' pause with several requests in flight lets all of them finish.

    Three requests are started, the engine is paused with mode='keep' once any
    of them has produced an output, held briefly, and resumed. Every request
    must then finish with its full 10 tokens.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        num_requests = 3
        params = SamplingParams(max_tokens=10, ignore_eos=True)
        first_output_seen = asyncio.Event()
        finished_ids: list[str] = []

        async def run_one(request_id: str):
            """Drain one request's stream and record its id once it ends."""
            stream = engine.generate(
                request_id=request_id,
                prompt=TEXT_PROMPT,
                sampling_params=params,
            )
            async for step in stream:
                first_output_seen.set()
            finished_ids.append(request_id)
            return step

        # Launch all requests up front so they are in flight together.
        tasks = []
        for idx in range(num_requests):
            tasks.append(asyncio.create_task(run_one(f"req-multi-{idx}")))

        # Block until any request has emitted at least one output.
        await asyncio.wait_for(first_output_seen.wait(), timeout=30.0)

        # Freeze, hold for a moment, then release.
        await engine.pause_generation(mode="keep")
        await asyncio.sleep(0.5)
        await engine.resume_generation()

        # Every request must run to completion after the resume.
        results = await asyncio.wait_for(asyncio.gather(*tasks), timeout=60.0)

        assert len(finished_ids) == num_requests
        for final in results:
            assert final.finished
            assert len(final.outputs[0].token_ids) == 10