[Core] Move pause and resume functions into engine (#34125)
Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Aaron Hao <ahao@anyscale.com> Signed-off-by: hao-aaron <ahao@anyscale.com> Signed-off-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -708,9 +708,7 @@ async def test_pause_resume_basic():
|
||||
# Test all modes with no requests in flight
|
||||
for mode in ("abort", "wait", "keep"):
|
||||
await engine.pause_generation(mode=mode)
|
||||
# "keep" only freezes the scheduler; it does not set _paused
|
||||
if mode != "keep":
|
||||
assert await engine.is_paused()
|
||||
assert await engine.is_paused()
|
||||
await engine.resume_generation()
|
||||
assert not await engine.is_paused()
|
||||
|
||||
@@ -808,6 +806,53 @@ async def test_pause_abort():
|
||||
assert final_output2.finished
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_pause_then_abort_queued_request():
    """Abort a request that was submitted while the engine is paused.

    The request is issued after pausing (so it is expected to be held in
    _paused_adds_queue), then aborted before generation is resumed. The
    client must still receive a final output marked as aborted, and the
    request must not produce any tokens after resume.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        queued_request_id = "abort-queued-request"
        params = SamplingParams(max_tokens=20, ignore_eos=True)

        # Pause before submitting so the upcoming add is queued, not run.
        await engine.pause_generation(mode="keep")
        assert await engine.is_paused()

        async def drain_stream():
            # Consume the whole output stream and hand back its last item,
            # or None when the stream produced nothing at all.
            received: list[RequestOutput] = []
            stream = engine.generate(
                request_id=queued_request_id,
                prompt=TEXT_PROMPT,
                sampling_params=params,
            )
            async for item in stream:
                received.append(item)
            if not received:
                return None
            return received[-1]

        consumer = asyncio.create_task(drain_stream())

        # Leave time for the request to reach the engine and be parked in
        # _paused_adds_queue before it is aborted.
        await asyncio.sleep(0.2)

        # Abort the request while it is still queued.
        await engine.abort(queued_request_id, internal=False)

        # Resuming lets the engine process the abort and deliver its output.
        await engine.resume_generation()

        last_output = await asyncio.wait_for(consumer, timeout=10.0)
        assert last_output is not None
        assert last_output.finished

        completion = last_output.outputs[0]
        assert completion.finish_reason == "abort"
        # The request never ran, so it must not carry any tokens.
        assert len(completion.token_ids) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_wait():
|
||||
"""Test that mode='wait' waits for in-flight requests to complete."""
|
||||
|
||||
Reference in New Issue
Block a user