[Bugfix] AsyncLLMEngine hangs with asyncio.run (#5654)

zifeitong
2024-06-19 13:57:12 -07:00
committed by GitHub
parent d571ca0108
commit 78687504f7
5 changed files with 271 additions and 47 deletions
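As background for the fix, here is a minimal sketch (not part of the commit) of the calling pattern that the new regression test below exercises: the AsyncLLMEngine is constructed before any event loop created by asyncio.run exists, and generation is then driven from inside asyncio.run. The model and sampling parameters mirror the test; the prompt, request id, and main() wrapper are illustrative only.

import asyncio

from vllm import SamplingParams
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine

# Engine built outside any running event loop, as in the test below.
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="facebook/opt-125m"))

async def main():
    params = SamplingParams(temperature=0, max_tokens=32)
    final_output = None
    # generate() is an async generator; keep the last (cumulative) output.
    async for output in engine.generate("test prompt", params,
                                        request_id="demo"):
        final_output = output
    return final_output

# Before this fix, this call could hang instead of returning.
result = asyncio.run(main())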

tests/async_engine/test_async_llm_engine.py

@@ -2,8 +2,12 @@ import asyncio
 from dataclasses import dataclass
 
 import pytest
+import torch
 
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm import SamplingParams
+from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
+
+from ..utils import wait_for_gpu_memory_to_clear
 
 
 @dataclass
@@ -94,3 +98,35 @@ async def test_new_requests_event():
     assert engine.get_model_config() is not None
     assert engine.get_tokenizer() is not None
     assert engine.get_decoding_config() is not None
+
+
+def test_asyncio_run():
+    wait_for_gpu_memory_to_clear(
+        devices=list(range(torch.cuda.device_count())),
+        threshold_bytes=2 * 2**30,
+        timeout_s=60,
+    )
+
+    engine = AsyncLLMEngine.from_engine_args(
+        AsyncEngineArgs(model="facebook/opt-125m"))
+
+    async def run(prompt: str):
+        sampling_params = SamplingParams(
+            temperature=0,
+            max_tokens=32,
+        )
+
+        async for output in engine.generate(prompt,
+                                            sampling_params,
+                                            request_id=prompt):
+            final_output = output
+        return final_output
+
+    async def generate():
+        return await asyncio.gather(
+            run("test0"),
+            run("test1"),
+        )
+
+    results = asyncio.run(generate())
+    assert len(results) == 2