[BugFix] Avoid premature async generator exit and raise all exception variations (#7698)

This commit is contained in:
Nick Hill
2024-08-21 11:45:55 -04:00
committed by GitHub
parent dd3fa0e430
commit c75363fbc0
2 changed files with 101 additions and 21 deletions

View File

@@ -2,8 +2,8 @@ import asyncio
import time
from dataclasses import dataclass
from functools import partial
from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union)
import torch
from typing_extensions import assert_never
@@ -85,9 +85,8 @@ class AsyncStream:
def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
Exception]) -> None:
if self._finished:
return
self._queue.put_nowait(item)
if not self._finished:
self._queue.put_nowait(item)
def finish(
self,
@@ -96,7 +95,7 @@ class AsyncStream:
if not self._finished:
self._finished = True
self._queue.put_nowait(
exception if exception is not None else STOP_ITERATION)
exception if self._is_raisable(exception) else STOP_ITERATION)
@property
def finished(self) -> bool:
@@ -106,9 +105,9 @@ class AsyncStream:
self
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
try:
while not self._finished:
while True:
result = await self._queue.get()
if isinstance(result, Exception):
if self._is_raisable(result):
if result == STOP_ITERATION:
return
raise result
@@ -117,6 +116,12 @@ class AsyncStream:
self._cancel(self.request_id)
raise asyncio.CancelledError from None
@staticmethod
def _is_raisable(value: Any) -> bool:
    """Return True if *value* may be used as the operand of ``raise``.

    A value qualifies when it is either an instance of ``BaseException``
    or a class that derives from ``BaseException``.

    Args:
        value: Any object.

    Returns:
        True for exception instances and exception classes, else False.
    """
    if isinstance(value, BaseException):
        return True
    # issubclass() raises TypeError when its first argument is not a class,
    # so confirm that value is a class before asking about its bases.
    return isinstance(value, type) and issubclass(value, BaseException)
class RequestTracker:
"""Synchronous abstraction for tracking requests."""