[BugFix] Overhaul async request cancellation (#7111)

Author: Nick Hill
Date: 2024-08-06 22:21:41 -07:00
Committed by: GitHub
Parent: f9a5600649
Commit: 9a3f49ae07

11 changed files with 222 additions and 222 deletions

vllm/engine/async_llm_engine.py

@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
+from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
                     Optional, Set, Tuple, Type, Union)
 
 from transformers import PreTrainedTokenizer
@@ -62,12 +62,16 @@ def _log_task_completion(task: asyncio.Task,
                     "actual cause.") from e
 
 
+STOP_ITERATION = Exception()  # Sentinel
+
+
 class AsyncStream:
     """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
-    that can be iterated over asynchronously."""
+    that can be iterated over asynchronously via an async generator."""
 
-    def __init__(self, request_id: str) -> None:
+    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
         self.request_id = request_id
+        self._cancel = cancel
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False
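Why a plain `Exception()` instance as the sentinel: PEP 525 wraps any `StopIteration`/`StopAsyncIteration` raised inside an async generator in a `RuntimeError`, so the old trick of enqueueing `StopAsyncIteration()` and re-raising it from `__anext__` cannot work once the stream is consumed through an async generator. A dedicated sentinel object, checked by identity and never raised, lets the generator end with a plain `return` instead. A minimal, self-contained sketch of the pattern (not vLLM's exact code):

```python
import asyncio

STOP = Exception()  # sentinel: compared by identity, never raised

async def stream(queue: asyncio.Queue):
    """Yield items from `queue` until the STOP sentinel arrives."""
    while True:
        item = await queue.get()
        if item is STOP:
            return  # clean end: the generator machinery raises StopAsyncIteration
        if isinstance(item, Exception):
            raise item  # propagate a real error to the consumer
        yield item

async def main():
    q: asyncio.Queue = asyncio.Queue()
    for x in (1, 2, 3):
        q.put_nowait(x)
    q.put_nowait(STOP)
    async for item in stream(q):
        print(item)  # 1, 2, 3

asyncio.run(main())
```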
@@ -77,22 +81,30 @@ class AsyncStream:
             return
         self._queue.put_nowait(item)
 
-    def finish(self) -> None:
-        self._queue.put_nowait(StopAsyncIteration())
-        self._finished = True
+    def finish(self, cancelled: bool = False) -> None:
+        if not self._finished:
+            self._finished = True
+            self._queue.put_nowait(
+                asyncio.CancelledError if cancelled else STOP_ITERATION)
 
     @property
     def finished(self) -> bool:
         return self._finished
 
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
-        result = await self._queue.get()
-        if isinstance(result, Exception):
-            raise result
-        return result
+    async def generator(
+        self
+    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+        try:
+            while not self._finished:
+                result = await self._queue.get()
+                if isinstance(result, Exception):
+                    if result == STOP_ITERATION:
+                        return
+                    raise result
+                yield result
+        except GeneratorExit:
+            self._cancel(self.request_id)
+            raise asyncio.CancelledError from None
 
 
 class RequestTracker:
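The `except GeneratorExit` branch is the heart of the new cancellation path. `GeneratorExit` is thrown into an async generator at its suspended `yield` when the generator is closed, whether explicitly via `aclose()` or by the event loop/GC finalizing a generator whose consumer has gone away, which is what a client disconnect looks like to a streaming server. Catching it lets the stream fire the `_cancel` callback so the request is also aborted inside the engine. A standalone sketch of that mechanism, with a hypothetical `abort_in_engine` standing in for the tracker callback:

```python
import asyncio

def abort_in_engine(request_id: str) -> None:
    # Hypothetical stand-in for RequestTracker.abort_request.
    print(f"abort {request_id}")

async def stream_outputs(request_id: str):
    try:
        while True:
            await asyncio.sleep(0.1)  # stand-in for queue.get()
            yield "token"
    except GeneratorExit:
        # Raised at the suspended `yield` when the generator is closed:
        # explicitly via aclose(), or by the event loop / GC finalizing an
        # abandoned generator after its consumer disappears.
        abort_in_engine(request_id)
        raise  # never swallow GeneratorExit

async def main():
    gen = stream_outputs("req-0")
    print(await gen.__anext__())  # consume one output
    await gen.aclose()            # simulate the consumer going away

asyncio.run(main())
```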
@@ -100,7 +112,7 @@ class RequestTracker:
     def __init__(self) -> None:
         self._request_streams: Dict[str, AsyncStream] = {}
-        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
+        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                 dict]] = asyncio.Queue()
         self.new_requests_event = asyncio.Event()
@@ -131,15 +143,21 @@ class RequestTracker:
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
         request_id = request_output.request_id
+        finished = request_output.finished
+
+        if finished:
+            stream = self._request_streams.pop(request_id, None)
+        else:
+            stream = self._request_streams.get(request_id)
 
         # Guard against a KeyError which can occur if the request was aborted
         # while the output was generated
-        if (stream := self._request_streams.get(request_id)) is not None:
+        if stream is not None:
             stream.put(request_output)
-        if request_output.finished:
-            if verbose:
-                logger.info("Finished request %s.", request_id)
-            self.abort_request(request_id)
+            if finished:
+                stream.finish()
+
+        if verbose and finished:
+            logger.info("Finished request %s.", request_id)
 
     def process_exception(self,
                           request_id: str,
@@ -162,7 +180,8 @@ class RequestTracker:
         if request_id in self._request_streams:
             raise KeyError(f"Request {request_id} already exists.")
 
-        stream = AsyncStream(request_id)
+        abort_request = partial(self.abort_request, verbose=verbose)
+        stream = AsyncStream(request_id, abort_request)
         self._new_requests.put_nowait((stream, {
             "request_id": request_id,
             **engine_add_request_kwargs
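`functools.partial` pre-binds the `verbose` keyword so that the object handed to `AsyncStream` matches the `Callable[[str], None]` shape its `cancel` parameter expects; the stream only ever has to supply the request id. A small illustration of that binding (toy function, not vLLM code):

```python
from functools import partial

def abort_request(request_id: str, *, cancelled: bool = False,
                  verbose: bool = False) -> None:
    print(request_id, cancelled, verbose)

cancel = partial(abort_request, verbose=True)  # now usable as Callable[[str], None]
cancel("req-0")  # prints: req-0 False True
```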
@@ -175,36 +194,36 @@ class RequestTracker:
 
         return stream
 
-    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
+    def abort_request(self,
+                      request_id: str,
+                      *,
+                      cancelled: bool = False,
+                      verbose: bool = False) -> None:
         """Abort a request during next background loop iteration."""
         if verbose:
             logger.info("Aborted request %s.", request_id)
 
-        self._finished_requests.put_nowait(request_id)
+        self._aborted_requests.put_nowait(request_id)
 
-        if request_id not in self._request_streams or self._request_streams[
-                request_id].finished:
-            # The request has already finished or been aborted.
-            return
-
-        self._request_streams[request_id].finish()
+        stream = self._request_streams.pop(request_id, None)
+        if stream is not None:
+            stream.finish(cancelled=cancelled)
 
-    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
+    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
         """Get the new requests and finished requests to be
         sent to the engine."""
         new_requests: List[Dict] = []
         finished_requests: Set[str] = set()
 
-        while not self._finished_requests.empty():
-            request_id = self._finished_requests.get_nowait()
+        while not self._aborted_requests.empty():
+            request_id = self._aborted_requests.get_nowait()
             finished_requests.add(request_id)
-            self._request_streams.pop(request_id, None)
 
         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
             if stream.request_id in finished_requests:
                 # The request has already been aborted.
-                stream.finish()
+                stream.finish(cancelled=True)
                 continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
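Renaming `_finished_requests` to `_aborted_requests` reflects what the queue now carries: only ids aborted between background-loop iterations, since normal completion is handled inline in `process_request_output`. The background loop drains it with the usual non-blocking idiom, which is safe here because producers and the consumer all run on the same event loop, so `empty()` cannot race with another consumer. A sketch of that drain idiom:

```python
import asyncio

def drain(q: "asyncio.Queue[str]") -> set:
    """Non-blocking drain; safe when all users share one event loop."""
    items = set()
    while not q.empty():
        items.add(q.get_nowait())
    return items
```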
@@ -556,8 +575,8 @@ class AsyncLLMEngine:
         Returns True if there are in-progress requests."""
 
-        new_requests, finished_requests = (
-            self._request_tracker.get_new_and_finished_requests())
+        new_requests, aborted_requests = (
+            self._request_tracker.get_new_and_aborted_requests())
 
         for new_request in new_requests:
             # Add the request into the vLLM engine's waiting queue.
@@ -576,8 +595,8 @@ class AsyncLLMEngine:
                     verbose=self.log_requests,
                 )
 
-        if finished_requests:
-            await self._engine_abort(finished_requests)
+        if aborted_requests:
+            await self._engine_abort(aborted_requests)
 
         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()  # type: ignore
@@ -666,6 +685,8 @@ class AsyncLLMEngine:
                 raise
             await asyncio.sleep(0)
 
+    # This method does not need to be async, but kept that way
+    # for backwards compatibility.
     async def add_request(
         self,
         request_id: str,
@@ -675,7 +696,7 @@ class AsyncLLMEngine:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncStream:
+    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
         if not self.is_running:
             if self.start_engine_loop:
                 self.start_background_loop()
@@ -686,20 +707,17 @@ class AsyncLLMEngine:
                 "error that caused the background loop to stop "
                 "(AsyncEngineDeadError).")
 
-        if arrival_time is None:
-            arrival_time = time.time()
-
         stream = self._request_tracker.add_request(
             request_id,
             verbose=self.log_requests,
             inputs=inputs,
             params=params,
-            arrival_time=arrival_time,
+            arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request)
 
-        return stream
+        return stream.generator()
 
     async def generate(
         self,
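`add_request` now returns the stream's async generator rather than the `AsyncStream` object, so the coroutine must be awaited first and the result then iterated, which is exactly the `async for ... in await self.add_request(...)` shape used by `generate` and `encode` below. A minimal sketch of that calling convention (toy generator, not vLLM code):

```python
import asyncio
from typing import AsyncGenerator

async def add_request(request_id: str) -> AsyncGenerator[str, None]:
    # A coroutine that *returns* an async generator: no yield here.
    async def gen() -> AsyncGenerator[str, None]:
        for tok in ("a", "b", "c"):
            await asyncio.sleep(0)
            yield f"{request_id}:{tok}"
    return gen()

async def main():
    stream = await add_request("req-0")  # awaiting yields the generator
    async for tok in stream:             # then iterate it
        print(tok)

asyncio.run(main())
```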
@@ -709,7 +727,7 @@ class AsyncLLMEngine:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncIterator[RequestOutput]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.
 
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -774,7 +792,7 @@ class AsyncLLMEngine:
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in self._process_request(
+        async for output in await self.add_request(
             request_id,
             inputs,
             sampling_params,
@@ -791,7 +809,7 @@ class AsyncLLMEngine:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-    ) -> AsyncIterator[EmbeddingRequestOutput]:
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.
 
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -852,7 +870,7 @@ class AsyncLLMEngine:
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in self._process_request(
+        async for output in await self.add_request(
             request_id,
             inputs,
             pooling_params,
@@ -861,37 +879,6 @@ class AsyncLLMEngine:
         ):
             yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
 
-    async def _process_request(
-        self,
-        request_id: str,
-        inputs: PromptInputs,
-        params: Union[SamplingParams, PoolingParams],
-        *,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
-        """Common logic to process requests with SamplingParams or
-        PoolingParams."""
-        arrival_time = time.time()
-
-        stream = await self.add_request(
-            request_id,
-            inputs,
-            params,
-            arrival_time=arrival_time,
-            lora_request=lora_request,
-            trace_headers=trace_headers,
-            prompt_adapter_request=prompt_adapter_request,
-        )
-
-        try:
-            async for request_output in stream:
-                yield request_output
-        except (Exception, asyncio.CancelledError) as e:
-            self._abort(request_id)
-            raise e
-
     async def abort(self, request_id: str) -> None:
         """Abort a request.
@@ -920,6 +907,7 @@ class AsyncLLMEngine:
             request_id: The unique id of the request.
         """
         self._request_tracker.abort_request(request_id,
+                                            cancelled=True,
                                             verbose=self.log_requests)
 
     async def get_model_config(self) -> ModelConfig:
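Taken together, the contract after this change: cancelling whatever task is consuming `generate()` (or dropping the generator) aborts the request inside the engine, and an explicit `abort()` now passes `cancelled=True` so the waiting consumer is woken with `asyncio.CancelledError` rather than a silent clean stop. A hedged usage sketch, with `engine`, `prompt`, and `params` assumed to be constructed elsewhere:

```python
import asyncio

async def handle_request(engine, prompt, params, request_id: str):
    try:
        async for output in engine.generate(prompt, params, request_id):
            ...  # stream the partial output to the client
    except asyncio.CancelledError:
        # The consuming task was cancelled (e.g. client disconnect); with
        # this commit the request is aborted engine-side as well, so no
        # orphaned generation keeps burning GPU time.
        raise
```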