[BugFix] Overhaul async request cancellation (#7111)

This commit is contained in:
Nick Hill
2024-08-06 22:21:41 -07:00
committed by GitHub
parent f9a5600649
commit 9a3f49ae07
11 changed files with 222 additions and 222 deletions

View File

@@ -1,5 +1,6 @@
import argparse
import asyncio
import contextlib
import datetime
import enum
import gc
@@ -11,10 +12,11 @@ import tempfile
import threading
import uuid
import warnings
from asyncio import FIRST_COMPLETED, ensure_future
from collections import defaultdict
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
Union, overload)
@@ -373,63 +375,74 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
return _async_wrapper
class ProducerFinished:
pass
async def iterate_with_cancellation(
iterator: AsyncGenerator[T, None],
is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[T, None]:
"""Convert async iterator into one that polls the provided function
at least once per second to check for client cancellation.
"""
# Can use anext() in python >= 3.10
awaits = [ensure_future(iterator.__anext__())]
while True:
done, pending = await asyncio.wait(awaits, timeout=1)
if await is_cancelled():
with contextlib.suppress(BaseException):
awaits[0].cancel()
await iterator.aclose()
raise asyncio.CancelledError("client cancelled")
if done:
try:
item = await awaits[0]
awaits[0] = ensure_future(iterator.__anext__())
yield item
except StopAsyncIteration:
# we are done
return
def merge_async_iterators(
*iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
async def merge_async_iterators(
*iterators: AsyncGenerator[T, None],
is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[Tuple[int, T], None]:
"""Merge multiple asynchronous iterators into a single iterator.
This method handle the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
It also polls the provided function at least once per second to check
for client cancellation.
"""
queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
Exception]] = asyncio.Queue()
producers = len(iterators)
async def producer(i: int, iterator: AsyncIterator[T]):
try:
async for item in iterator:
await queue.put((i, item))
except Exception as e:
await queue.put(e)
# Signal to the consumer that we've finished
await queue.put(ProducerFinished())
_tasks = [
asyncio.create_task(producer(i, iterator))
for i, iterator in enumerate(iterators)
]
async def consumer():
remaining = producers
try:
while remaining or not queue.empty():
# we think there is a race condition here
item = await queue.get()
if isinstance(item, ProducerFinished):
# Signal that a producer finished- not a real item
remaining -= 1
continue
if isinstance(item, Exception):
raise item
yield item
except (Exception, asyncio.CancelledError) as e:
for task in _tasks:
if sys.version_info >= (3, 9):
# msg parameter only supported in Python 3.9+
task.cancel(e)
else:
task.cancel()
raise e
await asyncio.gather(*_tasks)
return consumer()
# Can use anext() in python >= 3.10
awaits = {
ensure_future(pair[1].__anext__()): pair
for pair in enumerate(iterators)
}
try:
while awaits:
done, pending = await asyncio.wait(awaits.keys(),
return_when=FIRST_COMPLETED,
timeout=1)
if await is_cancelled():
raise asyncio.CancelledError("client cancelled")
for d in done:
pair = awaits.pop(d)
try:
item = await d
i, it = pair
awaits[ensure_future(it.__anext__())] = pair
yield i, item
except StopAsyncIteration:
pass
finally:
# Cancel any remaining iterators
for f, (_, it) in awaits.items():
with contextlib.suppress(BaseException):
f.cancel()
await it.aclose()
def get_ip() -> str: