[Bugfix] Fix request cancellation without polling (#11190)
Some checks failed
Create Release / Create Release (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.10, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.11, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.12, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.9, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.10, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.11, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.12, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.9, 2.4.0) (push) Has been cancelled
Some checks failed
Create Release / Create Release (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.10, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.11, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.12, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (11.8, ubuntu-20.04, 3.9, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.10, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.11, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.12, 2.4.0) (push) Has been cancelled
Create Release / Build Wheel (12.1, ubuntu-20.04, 3.9, 2.4.0) (push) Has been cancelled
This commit is contained in:
@@ -20,7 +20,7 @@ import time
|
||||
import uuid
|
||||
import warnings
|
||||
import weakref
|
||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
|
||||
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Iterable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
@@ -370,72 +370,23 @@ def _next_task(iterator: AsyncGenerator[T, None],
|
||||
return loop.create_task(iterator.__anext__()) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def iterate_with_cancellation(
|
||||
iterator: AsyncGenerator[T, None],
|
||||
is_cancelled: Callable[[], Awaitable[bool]],
|
||||
) -> AsyncGenerator[T, None]:
|
||||
"""Convert async iterator into one that polls the provided function
|
||||
at least once per second to check for client cancellation.
|
||||
"""
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
awaits: List[Future[T]] = [_next_task(iterator, loop)]
|
||||
next_cancel_check: float = 0
|
||||
while True:
|
||||
done, pending = await asyncio.wait(awaits, timeout=1.5)
|
||||
|
||||
# Check for cancellation at most once per second
|
||||
time_now = time.time()
|
||||
if time_now >= next_cancel_check:
|
||||
if await is_cancelled():
|
||||
with contextlib.suppress(BaseException):
|
||||
awaits[0].cancel()
|
||||
await iterator.aclose()
|
||||
raise asyncio.CancelledError("client cancelled")
|
||||
next_cancel_check = time_now + 1
|
||||
|
||||
if done:
|
||||
try:
|
||||
item = await awaits[0]
|
||||
awaits[0] = _next_task(iterator, loop)
|
||||
yield item
|
||||
except StopAsyncIteration:
|
||||
# we are done
|
||||
return
|
||||
|
||||
|
||||
async def merge_async_iterators(
|
||||
*iterators: AsyncGenerator[T, None],
|
||||
is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
|
||||
) -> AsyncGenerator[Tuple[int, T], None]:
|
||||
*iterators: AsyncGenerator[T,
|
||||
None], ) -> AsyncGenerator[Tuple[int, T], None]:
|
||||
"""Merge multiple asynchronous iterators into a single iterator.
|
||||
|
||||
This method handle the case where some iterators finish before others.
|
||||
When it yields, it yields a tuple (i, item) where i is the index of the
|
||||
iterator that yields the item.
|
||||
|
||||
It also optionally polls a provided function at least once per second
|
||||
to check for client cancellation.
|
||||
"""
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)}
|
||||
timeout = None if is_cancelled is None else 1.5
|
||||
next_cancel_check: float = 0
|
||||
try:
|
||||
while awaits:
|
||||
done, pending = await asyncio.wait(awaits.keys(),
|
||||
return_when=FIRST_COMPLETED,
|
||||
timeout=timeout)
|
||||
if is_cancelled is not None:
|
||||
# Check for cancellation at most once per second
|
||||
time_now = time.time()
|
||||
if time_now >= next_cancel_check:
|
||||
if await is_cancelled():
|
||||
raise asyncio.CancelledError("client cancelled")
|
||||
next_cancel_check = time_now + 1
|
||||
done, _ = await asyncio.wait(awaits.keys(),
|
||||
return_when=FIRST_COMPLETED)
|
||||
for d in done:
|
||||
pair = awaits.pop(d)
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user