[Core] Move pause and resume functions into engine (#34125)

Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Signed-off-by: Aaron Hao <ahao@anyscale.com>
Signed-off-by: hao-aaron <ahao@anyscale.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Aaron Hao
2026-02-13 00:15:10 -08:00
committed by GitHub
parent 47e9b63e1a
commit dddbff4624
9 changed files with 621 additions and 136 deletions

View File

@@ -12,6 +12,7 @@ from vllm import SamplingParams
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
@@ -181,3 +182,145 @@ async def test_load(
assert slogger.finished_req_count > NUM_REQUESTS // (DP_SIZE + 1), (
f"requests are imbalanced: {stats_loggers}"
)
# =============================================================================
# DP Pause/Resume Tests
# =============================================================================
DP_PAUSE_MODEL = "hmellor/tiny-random-LlamaForCausalLM"
DP_PAUSE_PROMPT = "This is a test of data parallel pause"
@pytest.mark.asyncio
async def test_dp_pause_resume_basic():
"""Pausing from the client (one call) pauses all DP ranks; resume clears it."""
if current_platform.is_rocm():
pytest.skip("DP pause tests use mp backend only")
with ExitStack() as after:
engine_args = AsyncEngineArgs(
model=DP_PAUSE_MODEL,
enforce_eager=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp",
)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
assert not await engine.is_paused()
await engine.pause_generation(mode="abort")
assert await engine.is_paused()
await engine.resume_generation()
assert not await engine.is_paused()
# Engine still works after resume
sampling_params = SamplingParams(max_tokens=5)
async for out in engine.generate(
request_id="after-resume",
prompt=DP_PAUSE_PROMPT,
sampling_params=sampling_params,
):
pass
assert out.finished
@pytest.mark.asyncio
async def test_dp_pause_abort():
"""Pause with abort from one client aborts in-flight requests on all DP ranks."""
if current_platform.is_rocm():
pytest.skip("DP pause tests use mp backend only")
with ExitStack() as after:
engine_args = AsyncEngineArgs(
model=DP_PAUSE_MODEL,
enforce_eager=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp",
)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
# Start several requests so they are distributed across ranks
sampling_params = SamplingParams(max_tokens=500, ignore_eos=True)
num_requests = 4
outputs_by_id: dict[str, list[RequestOutput]] = {}
async def gen(rid: str):
out_list: list[RequestOutput] = []
outputs_by_id[rid] = out_list
async for out in engine.generate(
request_id=rid,
prompt=DP_PAUSE_PROMPT,
sampling_params=sampling_params,
):
out_list.append(out)
return out_list[-1] if out_list else None
tasks = [asyncio.create_task(gen(f"req-{i}")) for i in range(num_requests)]
# Wait for some tokens on at least one request
while not any(len(o) >= 2 for o in outputs_by_id.values()):
await asyncio.sleep(0.02)
await engine.pause_generation(mode="abort")
finals = await asyncio.gather(*tasks)
for i, final in enumerate(finals):
assert final is not None, f"req-{i} had no output"
assert final.finished
assert final.outputs[0].finish_reason == "abort"
assert await engine.is_paused()
await engine.resume_generation()
assert not await engine.is_paused()
# New request completes after resume
async for out in engine.generate(
request_id="after-abort",
prompt=DP_PAUSE_PROMPT,
sampling_params=SamplingParams(max_tokens=5),
):
pass
assert out.finished
assert not engine.output_processor.has_unfinished_requests()
@pytest.mark.asyncio
async def test_dp_pause_keep_then_resume():
"""Pause with keep queues new requests; resume allows them to run."""
if current_platform.is_rocm():
pytest.skip("DP pause tests use mp backend only")
with ExitStack() as after:
engine_args = AsyncEngineArgs(
model=DP_PAUSE_MODEL,
enforce_eager=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=DP_SIZE,
data_parallel_backend="mp",
)
engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown)
await engine.pause_generation(mode="keep")
assert await engine.is_paused()
request_done = asyncio.Event()
async def gen():
async for out in engine.generate(
request_id="queued-keep",
prompt=DP_PAUSE_PROMPT,
sampling_params=SamplingParams(max_tokens=5),
):
pass
request_done.set()
return out
task = asyncio.create_task(gen())
await asyncio.sleep(0.2)
assert not request_done.is_set()
await engine.resume_generation()
final = await asyncio.wait_for(task, timeout=10.0)
assert final.finished
assert not await engine.is_paused()