[BugFix] Avoid premature async generator exit and raise all exception variations (#7698)
This commit is contained in:
@@ -1,14 +1,19 @@
|
||||
import asyncio
|
||||
import os
|
||||
from asyncio import CancelledError
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
|
||||
from vllm.outputs import RequestOutput as RealRequestOutput
|
||||
|
||||
from ..conftest import cleanup
|
||||
from ..utils import wait_for_gpu_memory_to_clear
|
||||
|
||||
|
||||
@@ -118,15 +123,38 @@ async def test_new_requests_event():
|
||||
os.environ.pop("VLLM_ALLOW_ENGINE_USE_RAY")
|
||||
|
||||
|
||||
def test_asyncio_run():
|
||||
def start_engine():
|
||||
wait_for_gpu_memory_to_clear(
|
||||
devices=list(range(torch.cuda.device_count())),
|
||||
threshold_bytes=2 * 2**30,
|
||||
timeout_s=60,
|
||||
)
|
||||
|
||||
engine = AsyncLLMEngine.from_engine_args(
|
||||
AsyncEngineArgs(model="facebook/opt-125m"))
|
||||
return AsyncLLMEngine.from_engine_args(
|
||||
AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="module")
|
||||
async def async_engine():
|
||||
engine = await asyncio.get_event_loop().run_in_executor(executor=None,
|
||||
func=start_engine)
|
||||
try:
|
||||
yield engine
|
||||
finally:
|
||||
engine.shutdown_background_loop()
|
||||
del engine
|
||||
await asyncio.sleep(0.1)
|
||||
cleanup()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def should_do_global_cleanup_after_test(request) -> bool:
|
||||
# So we can share the async engine fixture between these tests
|
||||
return False
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_asyncio_run(async_engine):
|
||||
|
||||
async def run(prompt: str):
|
||||
sampling_params = SamplingParams(
|
||||
@@ -134,17 +162,64 @@ def test_asyncio_run():
|
||||
max_tokens=32,
|
||||
)
|
||||
|
||||
async for output in engine.generate(prompt,
|
||||
sampling_params,
|
||||
request_id=prompt):
|
||||
async for output in async_engine.generate(prompt,
|
||||
sampling_params,
|
||||
request_id=prompt):
|
||||
final_output = output
|
||||
return final_output
|
||||
|
||||
async def generate():
|
||||
return await asyncio.gather(
|
||||
run("test0"),
|
||||
run("test1"),
|
||||
)
|
||||
|
||||
results = asyncio.run(generate())
|
||||
results = await asyncio.gather(
|
||||
run("test0"),
|
||||
run("test1"),
|
||||
)
|
||||
assert len(results) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_cancellation(async_engine):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
min_tokens=10,
|
||||
max_tokens=10,
|
||||
)
|
||||
|
||||
i = 0
|
||||
with pytest.raises(CancelledError):
|
||||
async for output in async_engine.generate("test2",
|
||||
sampling_params,
|
||||
request_id="test2"):
|
||||
assert not output.finished
|
||||
i += 1
|
||||
if i == 5:
|
||||
await async_engine.abort("test2")
|
||||
|
||||
assert i == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_delayed_generator(async_engine):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
min_tokens=10,
|
||||
max_tokens=10,
|
||||
)
|
||||
|
||||
stream = async_engine.generate("test3",
|
||||
sampling_params,
|
||||
request_id="test3")
|
||||
i = 0
|
||||
final_output: Optional[RealRequestOutput] = None
|
||||
async for output in stream:
|
||||
final_output = output
|
||||
if i == 0:
|
||||
# wait for generation to complete before consuming
|
||||
# the remaining messages
|
||||
await asyncio.sleep(1)
|
||||
if i < 9:
|
||||
assert not output.finished
|
||||
i += 1
|
||||
|
||||
assert i == 10
|
||||
assert final_output is not None
|
||||
assert len(final_output.outputs[0].token_ids) == 10
|
||||
assert final_output.finished
|
||||
|
||||
Reference in New Issue
Block a user