[Core] Move pause and resume functions into engine (#34125)
Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Aaron Hao <ahao@anyscale.com> Signed-off-by: hao-aaron <ahao@anyscale.com> Signed-off-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -12,6 +12,7 @@ from vllm import SamplingParams
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
@@ -181,3 +182,145 @@ async def test_load(
|
||||
assert slogger.finished_req_count > NUM_REQUESTS // (DP_SIZE + 1), (
|
||||
f"requests are imbalanced: {stats_loggers}"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DP Pause/Resume Tests
|
||||
# =============================================================================
|
||||
|
||||
DP_PAUSE_MODEL = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
DP_PAUSE_PROMPT = "This is a test of data parallel pause"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dp_pause_resume_basic():
|
||||
"""Pausing from the client (one call) pauses all DP ranks; resume clears it."""
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip("DP pause tests use mp backend only")
|
||||
with ExitStack() as after:
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=DP_PAUSE_MODEL,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
|
||||
data_parallel_size=DP_SIZE,
|
||||
data_parallel_backend="mp",
|
||||
)
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
assert not await engine.is_paused()
|
||||
await engine.pause_generation(mode="abort")
|
||||
assert await engine.is_paused()
|
||||
await engine.resume_generation()
|
||||
assert not await engine.is_paused()
|
||||
|
||||
# Engine still works after resume
|
||||
sampling_params = SamplingParams(max_tokens=5)
|
||||
async for out in engine.generate(
|
||||
request_id="after-resume",
|
||||
prompt=DP_PAUSE_PROMPT,
|
||||
sampling_params=sampling_params,
|
||||
):
|
||||
pass
|
||||
assert out.finished
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dp_pause_abort():
|
||||
"""Pause with abort from one client aborts in-flight requests on all DP ranks."""
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip("DP pause tests use mp backend only")
|
||||
with ExitStack() as after:
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=DP_PAUSE_MODEL,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
|
||||
data_parallel_size=DP_SIZE,
|
||||
data_parallel_backend="mp",
|
||||
)
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
# Start several requests so they are distributed across ranks
|
||||
sampling_params = SamplingParams(max_tokens=500, ignore_eos=True)
|
||||
num_requests = 4
|
||||
outputs_by_id: dict[str, list[RequestOutput]] = {}
|
||||
|
||||
async def gen(rid: str):
|
||||
out_list: list[RequestOutput] = []
|
||||
outputs_by_id[rid] = out_list
|
||||
async for out in engine.generate(
|
||||
request_id=rid,
|
||||
prompt=DP_PAUSE_PROMPT,
|
||||
sampling_params=sampling_params,
|
||||
):
|
||||
out_list.append(out)
|
||||
return out_list[-1] if out_list else None
|
||||
|
||||
tasks = [asyncio.create_task(gen(f"req-{i}")) for i in range(num_requests)]
|
||||
# Wait for some tokens on at least one request
|
||||
while not any(len(o) >= 2 for o in outputs_by_id.values()):
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
await engine.pause_generation(mode="abort")
|
||||
|
||||
finals = await asyncio.gather(*tasks)
|
||||
for i, final in enumerate(finals):
|
||||
assert final is not None, f"req-{i} had no output"
|
||||
assert final.finished
|
||||
assert final.outputs[0].finish_reason == "abort"
|
||||
|
||||
assert await engine.is_paused()
|
||||
await engine.resume_generation()
|
||||
assert not await engine.is_paused()
|
||||
|
||||
# New request completes after resume
|
||||
async for out in engine.generate(
|
||||
request_id="after-abort",
|
||||
prompt=DP_PAUSE_PROMPT,
|
||||
sampling_params=SamplingParams(max_tokens=5),
|
||||
):
|
||||
pass
|
||||
assert out.finished
|
||||
assert not engine.output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dp_pause_keep_then_resume():
|
||||
"""Pause with keep queues new requests; resume allows them to run."""
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip("DP pause tests use mp backend only")
|
||||
with ExitStack() as after:
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=DP_PAUSE_MODEL,
|
||||
enforce_eager=True,
|
||||
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
|
||||
data_parallel_size=DP_SIZE,
|
||||
data_parallel_backend="mp",
|
||||
)
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
await engine.pause_generation(mode="keep")
|
||||
assert await engine.is_paused()
|
||||
|
||||
request_done = asyncio.Event()
|
||||
|
||||
async def gen():
|
||||
async for out in engine.generate(
|
||||
request_id="queued-keep",
|
||||
prompt=DP_PAUSE_PROMPT,
|
||||
sampling_params=SamplingParams(max_tokens=5),
|
||||
):
|
||||
pass
|
||||
request_done.set()
|
||||
return out
|
||||
|
||||
task = asyncio.create_task(gen())
|
||||
await asyncio.sleep(0.2)
|
||||
assert not request_done.is_set()
|
||||
|
||||
await engine.resume_generation()
|
||||
final = await asyncio.wait_for(task, timeout=10.0)
|
||||
assert final.finished
|
||||
assert not await engine.is_paused()
|
||||
|
||||
Reference in New Issue
Block a user