[BugFix] Overhaul async request cancellation (#7111)
This commit is contained in:
111
vllm/utils.py
111
vllm/utils.py
@@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import contextlib
|
||||
import datetime
|
||||
import enum
|
||||
import gc
|
||||
@@ -11,10 +12,11 @@ import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from asyncio import FIRST_COMPLETED, ensure_future
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache, partial, wraps
|
||||
from platform import uname
|
||||
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
|
||||
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
|
||||
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
|
||||
Union, overload)
|
||||
|
||||
@@ -373,63 +375,74 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
|
||||
return _async_wrapper
|
||||
|
||||
|
||||
class ProducerFinished:
    """Marker type signalling that a producer has no more items.

    Instances carry no data; consumers recognise them with ``isinstance``.
    """
|
||||
async def iterate_with_cancellation(
    iterator: AsyncGenerator[T, None],
    is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[T, None]:
    """Convert async iterator into one that polls the provided function
    at least once per second to check for client cancellation.

    Args:
        iterator: The async generator whose items are forwarded unchanged.
        is_cancelled: Awaitable predicate; it is awaited after every wait
            (which lasts at most one second) and a true result aborts the
            iteration.

    Yields:
        Each item produced by ``iterator``, in order.

    Raises:
        asyncio.CancelledError: If ``is_cancelled()`` reports True.

    However the iteration ends (exhaustion, cancellation, an exception, or
    the consumer closing this generator early), the outstanding
    ``__anext__`` future is cancelled and ``iterator`` is closed.
    """

    # Can use anext() in python >= 3.10
    next_item = ensure_future(iterator.__anext__())
    try:
        while True:
            # Wake up at least once per second so cancellation is noticed
            # even while the wrapped iterator is idle.
            done, _ = await asyncio.wait([next_item], timeout=1)
            if await is_cancelled():
                raise asyncio.CancelledError("client cancelled")
            if not done:
                continue
            try:
                item = next_item.result()
            except StopAsyncIteration:
                # we are done
                return
            # Start fetching the following item before handing this one
            # to the consumer.
            next_item = ensure_future(iterator.__anext__())
            yield item
    finally:
        # Best-effort cleanup: cancel() is a no-op on a finished future,
        # and aclose() may raise if the generator is still mid-step; in
        # that case the cancellation above is what stops it.
        with contextlib.suppress(BaseException):
            next_item.cancel()
            await iterator.aclose()
|
||||
|
||||
|
||||
async def merge_async_iterators(
    *iterators: AsyncGenerator[T, None],
    is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
) -> AsyncGenerator[Tuple[int, T], None]:
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yields the item.

    If ``is_cancelled`` is provided, it is polled at least once per second
    to check for client cancellation. When it is omitted, no polling takes
    place and the merge simply waits for the next available item.

    Args:
        *iterators: The async generators to merge.
        is_cancelled: Optional awaitable predicate; a true result aborts
            the merge.

    Yields:
        ``(i, item)`` pairs, in the order items become available.

    Raises:
        asyncio.CancelledError: If ``is_cancelled()`` reports True.
    """
    # Only wake up periodically when there is a cancellation check to run.
    poll_interval = None if is_cancelled is None else 1

    # Can use anext() in python >= 3.10
    # Maps each in-flight __anext__ future to its (index, iterator) pair.
    awaits = {
        ensure_future(pair[1].__anext__()): pair
        for pair in enumerate(iterators)
    }
    try:
        while awaits:
            done, _ = await asyncio.wait(awaits.keys(),
                                         return_when=FIRST_COMPLETED,
                                         timeout=poll_interval)
            if is_cancelled is not None and await is_cancelled():
                raise asyncio.CancelledError("client cancelled")
            for d in done:
                pair = awaits.pop(d)
                try:
                    item = await d
                    i, it = pair
                    # Re-arm this iterator before yielding its item.
                    awaits[ensure_future(it.__anext__())] = pair
                    yield i, item
                except StopAsyncIteration:
                    # This iterator is exhausted; it is not re-armed.
                    pass
    finally:
        # Cancel any remaining iterators
        for f, (_, it) in awaits.items():
            with contextlib.suppress(BaseException):
                f.cancel()
                await it.aclose()
|
||||
|
||||
|
||||
def get_ip() -> str:
|
||||
|
||||
Reference in New Issue
Block a user