Make AsyncLLMEngine more robust & fix batched abort (#969)

Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Co-authored-by: Avnish Narayan <38871737+avnishn@users.noreply.github.com>
Antoni Baum authored on 2023-09-07 13:43:45 -07:00 (committed by GitHub)
parent 7a9c20c715, commit c07ece5ca4
7 changed files with 345 additions and 55 deletions

vllm/engine/async_llm_engine.py

@@ -1,7 +1,7 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -14,6 +14,28 @@ from vllm.sampling_params import SamplingParams
 
 logger = init_logger(__name__)
 
 
+class AsyncEngineDeadError(RuntimeError):
+    pass
+
+
+def _raise_exception_on_finish(task: asyncio.Task,
+                               request_tracker: "RequestTracker") -> None:
+    msg = ("Task finished unexpectedly. This should never happen! "
+           "Please open an issue on Github.")
+    try:
+        try:
+            task.result()
+        except asyncio.CancelledError:
+            return
+        except Exception as exc:
+            raise AsyncEngineDeadError(
+                msg + " See stack trace above for the actual cause.") from exc
+        raise AsyncEngineDeadError(msg)
+    except Exception as exc:
+        request_tracker.propagate_exception(exc)
+        raise exc
+
+
 class AsyncStream:
     """A stream of RequestOutputs for a request that can be
     iterated over asynchronously."""
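
For context, here is a minimal, self-contained sketch (not part of the diff) of how the new done-callback behaves: a crash in the background task is wrapped in AsyncEngineDeadError and pushed to the tracker before being re-raised. DummyTracker and on_engine_loop_done are hypothetical stand-ins for the RequestTracker and callback introduced in this commit.

import asyncio
from functools import partial

class AsyncEngineDeadError(RuntimeError):
    pass

class DummyTracker:
    # Hypothetical stand-in: the real tracker puts exc on every open stream.
    def propagate_exception(self, exc: Exception) -> None:
        print(f"propagated to streams: {exc!r}")

def on_engine_loop_done(task: asyncio.Task, request_tracker) -> None:
    try:
        try:
            task.result()
        except asyncio.CancelledError:
            return  # a deliberate shutdown is not an error
        except Exception as exc:
            raise AsyncEngineDeadError("Task finished unexpectedly.") from exc
        raise AsyncEngineDeadError("Task finished unexpectedly.")
    except Exception as exc:
        request_tracker.propagate_exception(exc)  # unblock all waiters first
        raise exc                                 # then surface it in the loop

async def main():
    async def crashing_engine_loop():
        raise ValueError("worker died")

    task = asyncio.get_running_loop().create_task(crashing_engine_loop())
    # add_done_callback passes only the task, so partial pre-binds the tracker.
    task.add_done_callback(
        partial(on_engine_loop_done, request_tracker=DummyTracker()))
    await asyncio.sleep(0)  # let the task crash and the callback fire

asyncio.run(main())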
@@ -43,15 +65,90 @@
         result = await self._queue.get()
         if result is StopIteration:
             raise StopAsyncIteration
+        elif isinstance(result, Exception):
+            raise result
         return result
 
 
-def _raise_exception_on_finish(task: asyncio.Task) -> None:
-    try:
-        task.result()
-    except Exception as e:
-        raise RuntimeError("Task finished unexpectedly.") from e
-    raise RuntimeError("Task finished unexpectedly.")
+class RequestTracker:
+    """Synchronous abstraction for tracking requests."""
+
+    def __init__(self) -> None:
+        self._request_streams: Dict[str, AsyncStream] = {}
+        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
+        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
+                                                dict]] = asyncio.Queue()
+
+    def __contains__(self, item):
+        return item in self._request_streams
+
+    def propagate_exception(self, exc: Exception) -> None:
+        """Propagate an exception to all request streams."""
+        for stream in self._request_streams.values():
+            stream.put(exc)
+
+    def process_request_output(self,
+                               request_output: RequestOutput,
+                               *,
+                               verbose: bool = False) -> None:
+        """Process a request output from the engine."""
+        request_id = request_output.request_id
+
+        self._request_streams[request_id].put(request_output)
+        if request_output.finished:
+            if verbose:
+                logger.info(f"Finished request {request_id}.")
+            self.abort_request(request_id)
+
+    def add_request(self, request_id: str,
+                    **engine_add_request_kwargs) -> AsyncStream:
+        """Add a request to be sent to the engine on the next background
+        loop iteration."""
+        if request_id in self._request_streams:
+            raise KeyError(f"Request {request_id} already exists.")
+
+        stream = AsyncStream(request_id)
+        self._new_requests.put_nowait((stream, {
+            "request_id": request_id,
+            **engine_add_request_kwargs
+        }))
+        return stream
+
+    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
+        """Abort a request during next background loop iteration."""
+        if verbose:
+            logger.info(f"Aborted request {request_id}.")
+
+        self._finished_requests.put_nowait(request_id)
+
+        if request_id not in self._request_streams or self._request_streams[
+                request_id].finished:
+            # The request has already finished or been aborted.
+            return
+
+        self._request_streams[request_id].finish()
+
+    def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
+        """Get the new requests and finished requests to be
+        sent to the engine."""
+        new_requests: List[dict] = []
+        finished_requests: Set[str] = set()
+
+        while not self._finished_requests.empty():
+            request_id = self._finished_requests.get_nowait()
+            finished_requests.add(request_id)
+            self._request_streams.pop(request_id, None)
+
+        while not self._new_requests.empty():
+            stream, new_request = self._new_requests.get_nowait()
+            if stream.request_id in finished_requests:
+                # The request has already been aborted.
+                stream.finish()
+                continue
+            self._request_streams[stream.request_id] = stream
+            new_requests.append(new_request)
+
+        return new_requests, finished_requests
+
 
 class _AsyncLLMEngine(LLMEngine):
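
The interesting race this tracker resolves is an abort that arrives before the background loop has handed the request to the engine at all. A condensed, runnable sketch of just the two-queue logic (Stream and MiniTracker are simplified stand-ins, not vLLM's classes):

import asyncio

class Stream:
    def __init__(self, request_id):
        self.request_id = request_id
        self.finished = False

    def finish(self):
        self.finished = True

class MiniTracker:
    def __init__(self):
        self._request_streams = {}
        self._finished_requests = asyncio.Queue()
        self._new_requests = asyncio.Queue()

    def add_request(self, request_id):
        stream = Stream(request_id)
        self._new_requests.put_nowait((stream, {"request_id": request_id}))
        return stream

    def abort_request(self, request_id):
        self._finished_requests.put_nowait(request_id)

    def get_new_and_finished_requests(self):
        new, finished = [], set()
        while not self._finished_requests.empty():
            rid = self._finished_requests.get_nowait()
            finished.add(rid)
            self._request_streams.pop(rid, None)
        while not self._new_requests.empty():
            stream, req = self._new_requests.get_nowait()
            if stream.request_id in finished:
                stream.finish()  # aborted before the engine ever saw it
                continue
            self._request_streams[stream.request_id] = stream
            new.append(req)
        return new, finished

tracker = MiniTracker()
s = tracker.add_request("r1")
tracker.abort_request("r1")       # abort lands before the next loop tick
new, finished = tracker.get_new_and_finished_requests()
print(new, finished, s.finished)  # [] {'r1'} True: never reaches the engine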
@@ -150,16 +247,15 @@ class AsyncLLMEngine:
         self.log_requests = log_requests
         self.engine = self._init_engine(*args, **kwargs)
 
-        # Request id -> stream.
-        self.request_streams: Dict[str, AsyncStream] = {}
-        self.finished_requests: asyncio.Queue[str] = asyncio.Queue()
+        self.request_tracker: RequestTracker = RequestTracker()
         self.background_loop = None
         if start_engine_loop:
             self.start_background_loop()
 
     @property
     def is_running(self) -> bool:
-        return self.background_loop is not None
+        return (self.background_loop is not None
+                and not self.background_loop.done())
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
@@ -167,7 +263,9 @@
             raise RuntimeError("Background loop is already running.")
         self.background_loop = asyncio.get_event_loop().create_task(
             self.run_engine_loop())
-        self.background_loop.add_done_callback(_raise_exception_on_finish)
+        self.background_loop.add_done_callback(
+            partial(_raise_exception_on_finish,
+                    request_tracker=self.request_tracker))
 
     def _init_engine(self, *args,
                      **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
@@ -181,6 +279,21 @@ class AsyncLLMEngine:
 
     async def engine_step(self):
         """Kick the engine to process the waiting requests."""
+        new_requests, finished_requests = (
+            self.request_tracker.get_new_and_finished_requests())
+
+        for new_request in new_requests:
+            # Add the request into the vLLM engine's waiting queue.
+            # TODO: Maybe add add_request_batch to reduce Ray overhead
+            if self.engine_use_ray:
+                await self.engine.add_request.remote(**new_request)
+            else:
+                self.engine.add_request(**new_request)
+
+        if finished_requests:
+            await self._engine_abort(finished_requests)
+
         if self.engine_use_ray:
             request_outputs = await self.engine.step.remote()
         else:
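
Note that the tracker has already filtered out requests that were aborted before reaching the engine, so by the time _engine_abort runs, every id in the set is safe to forward. The background loop that drives this method is not shown in the hunk; roughly (a sketch, not the diff's code, though run_engine_loop is the method name used above), it is just:

import asyncio

async def run_engine_loop(engine_step) -> None:
    # Repeatedly drain new/aborted requests and step the engine; sleep(0)
    # yields control so coroutines waiting on the request streams can run.
    while True:
        await engine_step()
        await asyncio.sleep(0)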
@@ -188,20 +301,8 @@
 
         # Put the outputs into the corresponding streams.
         for request_output in request_outputs:
-            request_id = request_output.request_id
-            self.request_streams[request_id].put(request_output)
-            if request_output.finished:
-                if self.log_requests:
-                    logger.info(f"Finished request {request_id}.")
-                self.request_streams[request_id].finish()
-                self.finished_requests.put_nowait(request_id)
-
-        finished_request = set()
-        while not self.finished_requests.empty():
-            finished_request.add(self.finished_requests.get_nowait())
-        await self._engine_abort(finished_request)
-        for request_id in finished_request:
-            del self.request_streams[request_id]
+            self.request_tracker.process_request_output(
+                request_output, verbose=self.log_requests)
 
     async def _engine_abort(self, request_ids: Iterable[str]):
         if self.engine_use_ray:
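
This is also where the "batched abort" from the commit title comes into play: _engine_abort forwards the whole set of finished ids in one call, which presumes the synchronous engine's abort accepts an iterable (that companion change is in another file of this commit, not shown here). A rough sketch of that shape, illustrative only and not the actual engine code:

from typing import Iterable, Set, Union

def abort_request(request_id: Union[str, Iterable[str]],
                  scheduled: Set[str]) -> None:
    # Accept one id or many; unknown ids are ignored, so aborting a request
    # the engine never received (see the tracker above) is harmless.
    if isinstance(request_id, str):
        request_id = (request_id, )
    for rid in set(request_id):
        scheduled.discard(rid)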
@@ -228,25 +329,19 @@
                         f"sampling params: {sampling_params}, "
                         f"prompt token ids: {prompt_token_ids}.")
 
-        if request_id in self.request_streams:
-            raise KeyError(f"Request {request_id} already exists.")
-        stream = AsyncStream(request_id)
-        self.request_streams[request_id] = stream
+        if not self.is_running:
+            raise AsyncEngineDeadError(
+                "Background loop is not running. If it was running, "
+                "inspect the output to find the stacktrace of the "
+                "error that caused the background loop to stop "
+                "(AsyncEngineDeadError).")
 
-        # Add the request into the vLLM engine's waiting queue.
-        if self.engine_use_ray:
-            await self.engine.add_request.remote(
-                request_id,
-                prompt,
-                sampling_params,
-                prompt_token_ids=prompt_token_ids,
-                arrival_time=arrival_time)
-        else:
-            self.engine.add_request(request_id,
-                                    prompt,
-                                    sampling_params,
-                                    prompt_token_ids=prompt_token_ids,
-                                    arrival_time=arrival_time)
+        stream = self.request_tracker.add_request(
+            request_id,
+            prompt=prompt,
+            sampling_params=sampling_params,
+            prompt_token_ids=prompt_token_ids,
+            arrival_time=arrival_time)
 
         return stream
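
Callers do not touch the tracker directly; the public generate API consumes the returned stream, and a dead background loop now raises AsyncEngineDeadError instead of hanging. A hypothetical caller (the model name, prompt, request id, and sampling settings are made up for illustration):

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

async def consume() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    # generate() wraps add_request() and iterates the AsyncStream; a crash
    # in the background loop surfaces here as AsyncEngineDeadError.
    async for output in engine.generate("Hello,",
                                        SamplingParams(max_tokens=16),
                                        request_id="req-0"):
        print(output.outputs[0].text)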
@@ -300,6 +395,13 @@
         Args:
             request_id: The unique id of the request.
         """
+        if not self.is_running:
+            raise AsyncEngineDeadError(
+                "Background loop is not running. If it was running, "
+                "inspect the output to find the stacktrace of the "
+                "error that caused the background loop to stop "
+                "(AsyncEngineDeadError).")
+
         return self._abort(request_id)
 
     def _abort(self, request_id: str) -> None:
@@ -311,16 +413,8 @@
         Args:
             request_id: The unique id of the request.
         """
-        if request_id not in self.request_streams or self.request_streams[
-                request_id].finished:
-            # The request has already finished or been aborted.
-            return
-
-        if self.log_requests:
-            logger.info(f"Aborted request {request_id}.")
-        self.request_streams[request_id].finish()
-        self.finished_requests.put_nowait(request_id)
+        self.request_tracker.abort_request(request_id,
+                                           verbose=self.log_requests)
 
     async def get_model_config(self) -> ModelConfig:
         """Get the model configuration of the vLLM engine."""