[Bugfix] AsyncLLMEngine hangs with asyncio.run (#5654)

This commit is contained in:
zifeitong
2024-06-19 13:57:12 -07:00
committed by GitHub
parent d571ca0108
commit 78687504f7
5 changed files with 271 additions and 47 deletions

View File

@@ -1,5 +1,4 @@
import asyncio
import time
from itertools import cycle
from typing import Dict, List, Optional, Tuple, Union
@@ -7,12 +6,6 @@ import pytest
import ray
import torch
from vllm.utils import is_hip
if (not is_hip()):
from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
nvmlInit)
from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -26,6 +19,7 @@ from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid
from ...conftest import cleanup
from ...utils import wait_for_gpu_memory_to_clear
class AsyncLLM:
@@ -291,38 +285,3 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids
def wait_for_gpu_memory_to_clear(devices: List[int],
threshold_bytes: int,
timeout_s: float = 120) -> None:
# Use nvml instead of pytorch to reduce measurement error from torch cuda
# context.
nvmlInit()
start_time = time.time()
while True:
output: Dict[int, str] = {}
output_raw: Dict[int, float] = {}
for device in devices:
dev_handle = nvmlDeviceGetHandleByIndex(device)
mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
gb_used = mem_info.used / 2**30
output_raw[device] = gb_used
output[device] = f'{gb_used:.02f}'
print('gpu memory used (GB): ', end='')
for k, v in output.items():
print(f'{k}={v}; ', end='')
print('')
dur_s = time.time() - start_time
if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
print(f'Done waiting for free GPU memory on devices {devices=} '
f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
break
if dur_s >= timeout_s:
raise ValueError(f'Memory of devices {devices=} not free after '
f'{dur_s=:.02f} ({threshold_bytes/2**30=})')
time.sleep(5)