[Bugfix] AsyncLLMEngine hangs with asyncio.run (#5654)
This commit is contained in:
@@ -2,8 +2,12 @@ import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
|
||||
|
||||
from ..utils import wait_for_gpu_memory_to_clear
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -94,3 +98,35 @@ async def test_new_requests_event():
|
||||
assert engine.get_model_config() is not None
|
||||
assert engine.get_tokenizer() is not None
|
||||
assert engine.get_decoding_config() is not None
|
||||
|
||||
|
||||
def test_asyncio_run():
    """Check that AsyncLLMEngine can be driven from a plain ``asyncio.run``.

    An engine is created outside of any running event loop, two generation
    requests are issued concurrently inside ``asyncio.run``, and the test
    asserts that both of them produce a result.
    """
    # NOTE(review): presumably blocks until memory on the listed devices
    # drops below the threshold or the timeout expires -- confirm against
    # the helper in ``..utils``.
    device_ids = list(range(torch.cuda.device_count()))
    wait_for_gpu_memory_to_clear(
        devices=device_ids,
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    async def _last_output(prompt: str):
        """Consume the output stream for ``prompt`` and return its last item."""
        params = SamplingParams(temperature=0, max_tokens=32)
        # The prompt text doubles as the request id.
        stream = engine.generate(prompt, params, request_id=prompt)
        async for item in stream:
            latest = item
        return latest

    async def _run_both():
        """Run both requests concurrently and collect their final outputs."""
        pending = [_last_output("test0"), _last_output("test1")]
        return await asyncio.gather(*pending)

    results = asyncio.run(_run_both())
    assert len(results) == 2
|
||||
|
||||
Reference in New Issue
Block a user