[BugFix] Overhaul async request cancellation (#7111)
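
Prior to this change, AsyncStream implemented __aiter__/__anext__
directly, and AsyncLLMEngine._process_request() had to wrap every
stream in a try/except to abort the request when the consumer failed
or was cancelled. This overhaul turns the stream into an async
generator (AsyncStream.generator()) that owns a cancel callback: when
the consuming task is cancelled, the resulting GeneratorExit fires the
callback and the request is aborted in the engine on the next
background-loop iteration. The tracker's finished-request queue is
renamed to _aborted_requests to reflect what it now carries, and a
shared STOP_ITERATION sentinel replaces the per-call
StopAsyncIteration instances.

A rough sketch of the consumer-side behavior this enables (a minimal
illustration, not code from this diff; assumes `engine` is a started
AsyncLLMEngine and `params` is a SamplingParams):

    import asyncio

    async def consume(engine, params):
        # generate() now returns the stream's async generator, so
        # plain iteration pulls RequestOutputs off the internal queue.
        async for output in engine.generate("Hello", params, "req-0"):
            print(output.outputs[0].text)

    async def main(engine, params):
        task = asyncio.create_task(consume(engine, params))
        await asyncio.sleep(0.5)
        # Cancelling the consumer closes the generator; the
        # GeneratorExit handler fires the stream's cancel callback,
        # which aborts "req-0" inside the engine instead of leaking it.
        task.cancel()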
@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
+from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
                     Optional, Set, Tuple, Type, Union)
 
 from transformers import PreTrainedTokenizer
@@ -62,12 +62,16 @@ def _log_task_completion(task: asyncio.Task,
             "actual cause.") from e
 
 
+STOP_ITERATION = Exception()  # Sentinel
+
+
 class AsyncStream:
     """A stream of RequestOutputs or EmbeddingRequestOutputs for a request
-    that can be iterated over asynchronously."""
+    that can be iterated over asynchronously via an async generator."""
 
-    def __init__(self, request_id: str) -> None:
+    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
         self.request_id = request_id
+        self._cancel = cancel
         self._queue: asyncio.Queue = asyncio.Queue()
         self._finished = False
 
@@ -77,22 +81,30 @@ class AsyncStream:
             return
         self._queue.put_nowait(item)
 
-    def finish(self) -> None:
-        self._queue.put_nowait(StopAsyncIteration())
-        self._finished = True
+    def finish(self, cancelled: bool = False) -> None:
+        if not self._finished:
+            self._finished = True
+            self._queue.put_nowait(
+                asyncio.CancelledError if cancelled else STOP_ITERATION)
 
     @property
     def finished(self) -> bool:
         return self._finished
 
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]:
-        result = await self._queue.get()
-        if isinstance(result, Exception):
-            raise result
-        return result
+    async def generator(
+        self
+    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+        try:
+            while not self._finished:
+                result = await self._queue.get()
+                if isinstance(result, Exception):
+                    if result == STOP_ITERATION:
+                        return
+                    raise result
+                yield result
+        except GeneratorExit:
+            self._cancel(self.request_id)
+            raise asyncio.CancelledError from None
 
 
 class RequestTracker:
@@ -100,7 +112,7 @@ class RequestTracker:
 
     def __init__(self) -> None:
         self._request_streams: Dict[str, AsyncStream] = {}
-        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
+        self._aborted_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        self.new_requests_event = asyncio.Event()
@@ -131,15 +143,21 @@ class RequestTracker:
                                verbose: bool = False) -> None:
         """Process a request output from the engine."""
         request_id = request_output.request_id
+        finished = request_output.finished
 
+        if finished:
+            stream = self._request_streams.pop(request_id, None)
+        else:
+            stream = self._request_streams.get(request_id)
         # Guard against a KeyError which can occur if the request was aborted
         # while the output was generated
-        if (stream := self._request_streams.get(request_id)) is not None:
+        if stream is not None:
             stream.put(request_output)
-            if request_output.finished:
-                if verbose:
-                    logger.info("Finished request %s.", request_id)
-                self.abort_request(request_id)
+            if finished:
+                stream.finish()
+
+        if verbose and finished:
+            logger.info("Finished request %s.", request_id)
 
     def process_exception(self,
                           request_id: str,
@@ -162,7 +180,8 @@ class RequestTracker:
         if request_id in self._request_streams:
             raise KeyError(f"Request {request_id} already exists.")
 
-        stream = AsyncStream(request_id)
+        abort_request = partial(self.abort_request, verbose=verbose)
+        stream = AsyncStream(request_id, abort_request)
         self._new_requests.put_nowait((stream, {
             "request_id": request_id,
             **engine_add_request_kwargs
@@ -175,36 +194,36 @@ class RequestTracker:
 
         return stream
 
-    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
+    def abort_request(self,
+                      request_id: str,
+                      *,
+                      cancelled: bool = False,
+                      verbose: bool = False) -> None:
         """Abort a request during next background loop iteration."""
         if verbose:
             logger.info("Aborted request %s.", request_id)
 
-        self._finished_requests.put_nowait(request_id)
+        self._aborted_requests.put_nowait(request_id)
 
-        if request_id not in self._request_streams or self._request_streams[
-                request_id].finished:
-            # The request has already finished or been aborted.
-            return
-
-        self._request_streams[request_id].finish()
+        stream = self._request_streams.pop(request_id, None)
+        if stream is not None:
+            stream.finish(cancelled=cancelled)
 
-    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
+    def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
         """Get the new requests and finished requests to be
         sent to the engine."""
         new_requests: List[Dict] = []
         finished_requests: Set[str] = set()
 
-        while not self._finished_requests.empty():
-            request_id = self._finished_requests.get_nowait()
+        while not self._aborted_requests.empty():
+            request_id = self._aborted_requests.get_nowait()
             finished_requests.add(request_id)
-            self._request_streams.pop(request_id, None)
 
         while not self._new_requests.empty():
             stream, new_request = self._new_requests.get_nowait()
             if stream.request_id in finished_requests:
                 # The request has already been aborted.
-                stream.finish()
+                stream.finish(cancelled=True)
                 continue
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
@@ -556,8 +575,8 @@ class AsyncLLMEngine:
 
         Returns True if there are in-progress requests."""
 
-        new_requests, finished_requests = (
-            self._request_tracker.get_new_and_finished_requests())
+        new_requests, aborted_requests = (
+            self._request_tracker.get_new_and_aborted_requests())
 
         for new_request in new_requests:
             # Add the request into the vLLM engine's waiting queue.
@@ -576,8 +595,8 @@ class AsyncLLMEngine:
                     verbose=self.log_requests,
                 )
 
-        if finished_requests:
-            await self._engine_abort(finished_requests)
+        if aborted_requests:
+            await self._engine_abort(aborted_requests)
 
         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()  # type: ignore
@@ -666,6 +685,8 @@ class AsyncLLMEngine:
                     raise
                 await asyncio.sleep(0)
 
+    # This method does not need to be async, but kept that way
+    # for backwards compatibility.
     async def add_request(
         self,
         request_id: str,
@@ -675,7 +696,7 @@ class AsyncLLMEngine:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncStream:
+    ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
         if not self.is_running:
             if self.start_engine_loop:
                 self.start_background_loop()
@@ -686,20 +707,17 @@ class AsyncLLMEngine:
                     "error that caused the background loop to stop "
                     "(AsyncEngineDeadError).")
 
-        if arrival_time is None:
-            arrival_time = time.time()
-
         stream = self._request_tracker.add_request(
             request_id,
             verbose=self.log_requests,
             inputs=inputs,
             params=params,
-            arrival_time=arrival_time,
+            arrival_time=arrival_time or time.time(),
             lora_request=lora_request,
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request)
 
-        return stream
+        return stream.generator()
 
     async def generate(
         self,
@@ -709,7 +727,7 @@ class AsyncLLMEngine:
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
-    ) -> AsyncIterator[RequestOutput]:
+    ) -> AsyncGenerator[RequestOutput, None]:
         """Generate outputs for a request.
 
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -774,7 +792,7 @@ class AsyncLLMEngine:
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in self._process_request(
+        async for output in await self.add_request(
             request_id,
             inputs,
             sampling_params,
@@ -791,7 +809,7 @@ class AsyncLLMEngine:
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
-    ) -> AsyncIterator[EmbeddingRequestOutput]:
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
         """Generate outputs for a request from an embedding model.
 
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -852,7 +870,7 @@ class AsyncLLMEngine:
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in self._process_request(
+        async for output in await self.add_request(
             request_id,
             inputs,
             pooling_params,
@@ -861,37 +879,6 @@ class AsyncLLMEngine:
         ):
             yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
 
-    async def _process_request(
-        self,
-        request_id: str,
-        inputs: PromptInputs,
-        params: Union[SamplingParams, PoolingParams],
-        *,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Mapping[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
-        """Common logic to process requests with SamplingParams or
-        PoolingParams."""
-        arrival_time = time.time()
-
-        stream = await self.add_request(
-            request_id,
-            inputs,
-            params,
-            arrival_time=arrival_time,
-            lora_request=lora_request,
-            trace_headers=trace_headers,
-            prompt_adapter_request=prompt_adapter_request,
-        )
-
-        try:
-            async for request_output in stream:
-                yield request_output
-        except (Exception, asyncio.CancelledError) as e:
-            self._abort(request_id)
-            raise e
-
     async def abort(self, request_id: str) -> None:
         """Abort a request.
 
@@ -920,6 +907,7 @@ class AsyncLLMEngine:
             request_id: The unique id of the request.
         """
         self._request_tracker.abort_request(request_id,
+                                            cancelled=True,
                                             verbose=self.log_requests)
 
     async def get_model_config(self) -> ModelConfig: