[Feat][RL] Pause and Resume with keep requests for single engine (#32351)
Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Aaron Hao <ahao@anyscale.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from contextlib import ExitStack
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -661,3 +662,301 @@ async def collect_outputs(
|
||||
outputs_list.append(output)
|
||||
final_output = output
|
||||
return final_output
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pause/Resume Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pause_resume_basic():
    """Exercise the pause/resume flag on an engine with no in-flight work.

    Checked in order:
    - the engine starts unpaused and resuming an unpaused engine is harmless
    - ``pause_generation(mode="abort")`` sets the flag and a repeated call
      leaves it set
    - ``resume_generation`` clears the flag
    - the "abort", "wait" and "keep" modes can each be paused and resumed
      while the engine is idle
    - overlapping pause/resume calls neither deadlock nor raise
    - a regular request still finishes after all of the above
    """
    with ExitStack() as cleanup:
        with set_default_torch_num_threads(1):
            llm = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        cleanup.callback(llm.shutdown)

        # A fresh engine is unpaused, and a redundant resume keeps it so.
        assert not await llm.is_paused()
        await llm.resume_generation()
        assert not await llm.is_paused()

        # The first pause raises the flag; the second one must be a no-op.
        for _ in range(2):
            await llm.pause_generation(mode="abort")
            assert await llm.is_paused()

        # Resuming lowers the flag again.
        await llm.resume_generation()
        assert not await llm.is_paused()

        # Every mode must be usable while nothing is running.
        for pause_mode in ("abort", "wait", "keep"):
            await llm.pause_generation(mode=pause_mode)
            # Per this test, "keep" only freezes the scheduler and does not
            # set _paused, so the flag is checked for the other modes only.
            assert pause_mode == "keep" or await llm.is_paused()
            await llm.resume_generation()
            assert not await llm.is_paused()

        # Issue pause and resume calls concurrently; the only requirement is
        # that the gather completes without deadlocking or raising.
        overlapping_calls = (
            llm.pause_generation(mode="abort"),
            llm.resume_generation(),
            llm.pause_generation(mode="abort"),
            llm.resume_generation(),
        )
        await asyncio.gather(*overlapping_calls)

        # The interleaving above is unspecified, so force a known end state.
        await llm.resume_generation()
        assert not await llm.is_paused()

        # Finally, the engine must still be able to serve a request.
        params = SamplingParams(max_tokens=5)
        stream = llm.generate(
            request_id="post-cycles",
            prompt=TEXT_PROMPT,
            sampling_params=params,
        )
        async for last_output in stream:
            pass
        assert last_output.finished
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pause_abort():
    """Test that mode='abort' aborts in-flight requests immediately.

    Also checks that a request submitted while the engine is paused does not
    complete until ``resume_generation`` is called, and completes afterwards.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Start a long-running request
        sampling_params = SamplingParams(max_tokens=1000, ignore_eos=True)
        outputs: list[RequestOutput] = []

        async def gen():
            async for out in engine.generate(
                request_id="test-abort-pause",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                outputs.append(out)
            return outputs[-1] if outputs else None

        # Start generation task
        gen_task = asyncio.create_task(gen())

        async def wait_for_outputs(minimum: int) -> None:
            """Poll until at least ``minimum`` outputs have been collected."""
            while len(outputs) < minimum:
                if gen_task.done():
                    # Re-raise a generation failure (if any) instead of
                    # polling forever on a list that can no longer grow.
                    gen_task.result()
                    raise AssertionError(
                        f"Request finished after {len(outputs)} outputs, "
                        f"before the expected {minimum}"
                    )
                await asyncio.sleep(0.01)

        # Wait for some tokens to be generated; bounded so that a stalled
        # engine fails the test rather than hanging it.
        await asyncio.wait_for(wait_for_outputs(3), timeout=30.0)

        # Pause with abort mode
        await engine.pause_generation(mode="abort")

        # Wait for task to complete (should be aborted); bounded like the
        # post-resume wait below so a missed abort cannot hang the test.
        final_output = await asyncio.wait_for(gen_task, timeout=10.0)

        # Request should be finished (aborted)
        assert final_output is not None
        assert final_output.finished
        assert final_output.outputs[0].finish_reason == "abort"

        # Also test that new requests are blocked while paused, then resume
        assert await engine.is_paused()

        request_completed = False

        async def gen_blocked():
            nonlocal request_completed
            async for out in engine.generate(
                request_id="test-blocked",
                prompt=TEXT_PROMPT,
                sampling_params=SamplingParams(max_tokens=5),
            ):
                pass
            request_completed = True
            return out

        # Start a request (should block)
        gen_task2 = asyncio.create_task(gen_blocked())

        # Wait a bit - request should not have completed
        await asyncio.sleep(0.3)
        assert not request_completed, "Request should be blocked while paused"

        # Resume
        await engine.resume_generation()

        # Now request should complete
        final_output2 = await asyncio.wait_for(gen_task2, timeout=10.0)
        assert request_completed
        assert final_output2.finished
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pause_wait():
    """Test that mode='wait' waits for in-flight requests to complete.

    A short request is started, the engine is paused with ``mode="wait"``
    once the first output has arrived, and the request is then expected to
    have run to completion without being aborted.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Start a request - use fewer tokens since wait mode waits for completion
        sampling_params = SamplingParams(max_tokens=10, ignore_eos=True)
        got_first_token = asyncio.Event()
        request_completed = False

        async def gen():
            nonlocal request_completed
            async for out in engine.generate(
                request_id="test-wait",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                got_first_token.set()
            request_completed = True
            return out

        # Start generation
        gen_task = asyncio.create_task(gen())

        # Wait for generation to start (event-driven)
        await asyncio.wait_for(got_first_token.wait(), timeout=30.0)

        # Pause with wait mode - should wait for request to finish
        await engine.pause_generation(mode="wait")

        # By now the request should be done (wait mode waits for completion)
        assert request_completed, "Request should have completed during wait"

        # Awaiting is equivalent to .result() for a finished task, and does
        # not raise InvalidStateError if the task is still being finalized.
        final_output = await gen_task
        assert final_output.finished
        # Should complete normally, not aborted. The comparison must be
        # against "abort" (the value asserted by the abort-mode test);
        # comparing against any other string would let an abort pass.
        assert final_output.outputs[0].finish_reason != "abort"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pause_keep_single_request():
    """Check that mode='keep' freezes one request and later lets it finish.

    One coroutine streams a 30-token request and timestamps every output; a
    second one pauses the engine with ``mode="keep"`` after five outputs,
    sleeps for ``pause_duration`` seconds and resumes. The request must
    still produce all 30 tokens, and the timestamps around the recorded
    pause position must be separated by roughly the pause duration.
    """
    with ExitStack() as cleanup:
        with set_default_torch_num_threads(1):
            llm = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        cleanup.callback(llm.shutdown)

        pause_duration = 5.0
        params = SamplingParams(max_tokens=30, ignore_eos=True)
        # One (cumulative token count, monotonic timestamp) pair per output.
        stamps: list[tuple[int, float]] = []
        # Number of outputs seen at the moment the pause call returned.
        pause_idx = 0

        async def stream_and_stamp():
            """Consume the request and timestamp each output."""
            stream = llm.generate(
                request_id="test-keep-single",
                prompt=TEXT_PROMPT,
                sampling_params=params,
            )
            async for step in stream:
                produced = len(step.outputs[0].token_ids)
                stamps.append((produced, time.monotonic()))
            return step

        async def pause_then_resume():
            """Pause after a few outputs, hold, then resume."""
            nonlocal pause_idx

            # Poll until a handful of outputs have been recorded.
            while len(stamps) < 5:
                await asyncio.sleep(0.01)

            await llm.pause_generation(mode="keep")
            pause_idx = len(stamps)

            # Keep the engine frozen for the full pause window.
            await asyncio.sleep(pause_duration)

            await llm.resume_generation()

        # Drive both coroutines together under a single overall deadline.
        producer = asyncio.create_task(stream_and_stamp())
        controller = asyncio.create_task(pause_then_resume())
        final_output, _ = await asyncio.wait_for(
            asyncio.gather(producer, controller), timeout=60.0
        )

        # No tokens may be lost across the pause.
        assert final_output.finished
        assert len(final_output.outputs[0].token_ids) == 30

        # The output received right after the pause must trail its
        # predecessor by (most of) the pause duration.
        resumed_at = stamps[pause_idx][1]
        frozen_at = stamps[pause_idx - 1][1]
        pause_gap = resumed_at - frozen_at
        assert pause_gap >= pause_duration * 0.8, (
            f"Expected gap of ~{pause_duration}s after pause, got {pause_gap:.3f}s"
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pause_keep_multi_request():
    """Check that mode='keep' freezes several requests and all of them finish.

    Three 10-token requests are started concurrently. Once any of them has
    produced an output the engine is paused with ``mode="keep"``, held for
    half a second and resumed; every request must then finish with exactly
    10 tokens.
    """
    with ExitStack() as cleanup:
        with set_default_torch_num_threads(1):
            llm = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        cleanup.callback(llm.shutdown)

        num_requests = 3
        params = SamplingParams(max_tokens=10, ignore_eos=True)
        first_output_seen = asyncio.Event()
        # Ids of the requests whose stream has been fully consumed.
        done_ids: list[str] = []

        async def gen_multi(request_id: str):
            stream = llm.generate(
                request_id=request_id,
                prompt=TEXT_PROMPT,
                sampling_params=params,
            )
            async for step in stream:
                first_output_seen.set()
            done_ids.append(request_id)
            return step

        # Launch all requests up front so they run concurrently.
        request_ids = [f"req-multi-{i}" for i in range(num_requests)]
        tasks = [asyncio.create_task(gen_multi(rid)) for rid in request_ids]

        # Proceed as soon as any request has produced its first output.
        await asyncio.wait_for(first_output_seen.wait(), timeout=30.0)

        # Freeze, hold briefly, then release.
        await llm.pause_generation(mode="keep")
        await asyncio.sleep(0.5)
        await llm.resume_generation()

        # Every request must run to completion after the resume.
        results = await asyncio.wait_for(asyncio.gather(*tasks), timeout=60.0)

        assert len(done_ids) == num_requests
        for final in results:
            assert final.finished
            assert len(final.outputs[0].token_ids) == 10
|
||||
|
||||
Reference in New Issue
Block a user