Make AsyncLLMEngine more robust & fix batched abort (#969)
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com> Co-authored-by: Avnish Narayan <38871737+avnishn@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, List, Optional, Type, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
@@ -14,6 +14,28 @@ from vllm.sampling_params import SamplingParams
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AsyncEngineDeadError(RuntimeError):
    """Raised when the async engine's background loop has stopped.

    Subclasses RuntimeError, so existing ``except RuntimeError`` handlers
    still catch it.
    """
|
||||
|
||||
|
||||
def _raise_exception_on_finish(task: asyncio.Task,
                               request_tracker: "RequestTracker") -> None:
    """Done-callback that turns any end of ``task`` into a loud failure.

    Cancellation is treated as a normal shutdown and returns silently.
    Any other outcome (the task raised, or it returned normally) is
    converted into an AsyncEngineDeadError, handed to
    ``request_tracker.propagate_exception`` and then re-raised.

    Args:
        task: The finished task whose result is inspected.
        request_tracker: Tracker that receives the resulting exception.

    Raises:
        AsyncEngineDeadError: Whenever the task ended for any reason other
            than cancellation.
    """
    unexpected_finish = ("Task finished unexpectedly. This should never "
                         "happen! Please open an issue on Github.")
    finish_with_cause = (unexpected_finish +
                         " See stack trace above for the actual cause.")
    try:
        try:
            task.result()
        except asyncio.CancelledError:
            # Cancelled on purpose: nothing to report.
            return
        except Exception as cause:
            # Chain the real failure so it appears in the traceback.
            raise AsyncEngineDeadError(finish_with_cause) from cause
        else:
            # The task must never return on its own.
            raise AsyncEngineDeadError(unexpected_finish)
    except Exception as fatal:
        # Hand the error to the tracker before re-raising it.
        request_tracker.propagate_exception(fatal)
        raise fatal
|
||||
|
||||
|
||||
class AsyncStream:
|
||||
"""A stream of RequestOutputs for a request that can be
|
||||
iterated over asynchronously."""
|
||||
@@ -43,15 +65,90 @@ class AsyncStream:
|
||||
result = await self._queue.get()
|
||||
if result is StopIteration:
|
||||
raise StopAsyncIteration
|
||||
elif isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
|
||||
|
||||
def _raise_exception_on_finish(task: asyncio.Task) -> None:
    """Done-callback that always raises once ``task`` has finished.

    Args:
        task: The finished task whose result is inspected.

    Raises:
        RuntimeError: Always. When the task itself raised an ``Exception``,
            that exception is chained as the cause.
    """
    unexpected_finish = "Task finished unexpectedly."
    try:
        task.result()
    except Exception as cause:
        raise RuntimeError(unexpected_finish) from cause
    else:
        # Returning normally is just as unexpected as failing.
        raise RuntimeError(unexpected_finish)
|
||||
class RequestTracker:
    """Synchronous abstraction for tracking requests.

    New requests and finished/aborted request ids are buffered in queues.
    get_new_and_finished_requests() drains both queues in one batch; a
    request's stream is only registered (and becomes visible through
    ``in``) at that point.
    """

    def __init__(self) -> None:
        # Request id -> stream, for requests already handed over by
        # get_new_and_finished_requests().
        self._request_streams: Dict[str, AsyncStream] = {}
        # Ids of finished or aborted requests awaiting the next drain.
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        # (stream, add_request kwargs) pairs awaiting the next drain.
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()

    def __contains__(self, item):
        """Return True if ``item`` is the id of a registered stream."""
        return item in self._request_streams

    def propagate_exception(self, exc: Exception) -> None:
        """Propagate an exception to all request streams."""
        for stream in self._request_streams.values():
            stream.put(exc)

    def process_request_output(self,
                               request_output: "RequestOutput",
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine.

        The output is put on the stream registered for its request id. If
        the output is marked finished, the request is also scheduled for
        removal via abort_request().

        Args:
            request_output: Output object exposing ``request_id`` and
                ``finished``.
            verbose: If True, log when a request finishes.
        """
        request_id = request_output.request_id

        stream = self._request_streams.get(request_id)
        if stream is None:
            # No stream is registered under this id (for example it was
            # already removed after an abort). There is nobody to deliver
            # the output to, and indexing the dict would raise a KeyError
            # into the caller, so the output is dropped instead.
            return

        stream.put(request_output)
        if request_output.finished:
            if verbose:
                logger.info(f"Finished request {request_id}.")
            self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> "AsyncStream":
        """Add a request to be sent to the engine on the next background
        loop iteration.

        Args:
            request_id: The unique id of the request.
            **engine_add_request_kwargs: Extra keyword arguments queued
                together with ``request_id`` for the engine.

        Returns:
            The stream created for this request.

        Raises:
            KeyError: If a stream is already registered for ``request_id``.
        """
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))
        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during next background loop iteration.

        The id is always queued as finished, even when no stream is
        registered for it yet, so a request still waiting in the
        new-request queue is dropped on the next drain.

        Args:
            request_id: The unique id of the request.
            verbose: If True, log the abort.
        """
        if verbose:
            logger.info(f"Aborted request {request_id}.")

        self._finished_requests.put_nowait(request_id)

        stream = self._request_streams.get(request_id)
        if stream is None or stream.finished:
            # The request has already finished or been aborted.
            return

        stream.finish()

    def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine.

        Returns:
            A ``(new_requests, finished_requests)`` pair: the keyword
            argument dicts of the newly registered requests, and the set
            of request ids that finished or were aborted.
        """
        new_requests: List[dict] = []
        finished_requests: Set[str] = set()

        # Drain finished ids first so that a request aborted before it was
        # ever registered can be recognised below.
        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        return new_requests, finished_requests
|
||||
|
||||
|
||||
class _AsyncLLMEngine(LLMEngine):
|
||||
@@ -150,16 +247,15 @@ class AsyncLLMEngine:
|
||||
self.log_requests = log_requests
|
||||
self.engine = self._init_engine(*args, **kwargs)
|
||||
|
||||
# Request id -> stream.
|
||||
self.request_streams: Dict[str, AsyncStream] = {}
|
||||
self.finished_requests: asyncio.Queue[str] = asyncio.Queue()
|
||||
self.request_tracker: RequestTracker = RequestTracker()
|
||||
self.background_loop = None
|
||||
if start_engine_loop:
|
||||
self.start_background_loop()
|
||||
|
||||
@property
def is_running(self) -> bool:
    """Whether the background loop task exists and has not finished."""
    # A task object that is already done (crashed, cancelled or returned)
    # must not count as running, so checking for None alone is not enough.
    return (self.background_loop is not None
            and not self.background_loop.done())
|
||||
|
||||
def start_background_loop(self) -> None:
|
||||
"""Start the background loop."""
|
||||
@@ -167,7 +263,9 @@ class AsyncLLMEngine:
|
||||
raise RuntimeError("Background loop is already running.")
|
||||
self.background_loop = asyncio.get_event_loop().create_task(
|
||||
self.run_engine_loop())
|
||||
self.background_loop.add_done_callback(_raise_exception_on_finish)
|
||||
self.background_loop.add_done_callback(
|
||||
partial(_raise_exception_on_finish,
|
||||
request_tracker=self.request_tracker))
|
||||
|
||||
def _init_engine(self, *args,
|
||||
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
|
||||
@@ -181,6 +279,21 @@ class AsyncLLMEngine:
|
||||
|
||||
async def engine_step(self):
|
||||
"""Kick the engine to process the waiting requests."""
|
||||
|
||||
new_requests, finished_requests = (
|
||||
self.request_tracker.get_new_and_finished_requests())
|
||||
|
||||
for new_request in new_requests:
|
||||
# Add the request into the vLLM engine's waiting queue.
|
||||
# TODO: Maybe add add_request_batch to reduce Ray overhead
|
||||
if self.engine_use_ray:
|
||||
await self.engine.add_request.remote(**new_request)
|
||||
else:
|
||||
self.engine.add_request(**new_request)
|
||||
|
||||
if finished_requests:
|
||||
await self._engine_abort(finished_requests)
|
||||
|
||||
if self.engine_use_ray:
|
||||
request_outputs = await self.engine.step.remote()
|
||||
else:
|
||||
@@ -188,20 +301,8 @@ class AsyncLLMEngine:
|
||||
|
||||
# Put the outputs into the corresponding streams.
|
||||
for request_output in request_outputs:
|
||||
request_id = request_output.request_id
|
||||
self.request_streams[request_id].put(request_output)
|
||||
if request_output.finished:
|
||||
if self.log_requests:
|
||||
logger.info(f"Finished request {request_id}.")
|
||||
self.request_streams[request_id].finish()
|
||||
self.finished_requests.put_nowait(request_id)
|
||||
|
||||
finished_request = set()
|
||||
while not self.finished_requests.empty():
|
||||
finished_request.add(self.finished_requests.get_nowait())
|
||||
await self._engine_abort(finished_request)
|
||||
for request_id in finished_request:
|
||||
del self.request_streams[request_id]
|
||||
self.request_tracker.process_request_output(
|
||||
request_output, verbose=self.log_requests)
|
||||
|
||||
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||
if self.engine_use_ray:
|
||||
@@ -228,25 +329,19 @@ class AsyncLLMEngine:
|
||||
f"sampling params: {sampling_params}, "
|
||||
f"prompt token ids: {prompt_token_ids}.")
|
||||
|
||||
if request_id in self.request_streams:
|
||||
raise KeyError(f"Request {request_id} already exists.")
|
||||
stream = AsyncStream(request_id)
|
||||
self.request_streams[request_id] = stream
|
||||
if not self.is_running:
|
||||
raise AsyncEngineDeadError(
|
||||
"Background loop is not running. If it was running, "
|
||||
"inspect the output to find the stacktrace of the "
|
||||
"error that caused the background loop to stop "
|
||||
"(AsyncEngineDeadError).")
|
||||
|
||||
# Add the request into the vLLM engine's waiting queue.
|
||||
if self.engine_use_ray:
|
||||
await self.engine.add_request.remote(
|
||||
request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
else:
|
||||
self.engine.add_request(request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
stream = self.request_tracker.add_request(
|
||||
request_id,
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time)
|
||||
|
||||
return stream
|
||||
|
||||
@@ -300,6 +395,13 @@ class AsyncLLMEngine:
|
||||
Args:
|
||||
request_id: The unique id of the request.
|
||||
"""
|
||||
if not self.is_running:
|
||||
raise AsyncEngineDeadError(
|
||||
"Background loop is not running. If it was running, "
|
||||
"inspect the output to find the stacktrace of the "
|
||||
"error that caused the background loop to stop "
|
||||
"(AsyncEngineDeadError).")
|
||||
|
||||
return self._abort(request_id)
|
||||
|
||||
def _abort(self, request_id: str) -> None:
|
||||
@@ -311,16 +413,8 @@ class AsyncLLMEngine:
|
||||
Args:
|
||||
request_id: The unique id of the request.
|
||||
"""
|
||||
if request_id not in self.request_streams or self.request_streams[
|
||||
request_id].finished:
|
||||
# The request has already finished or been aborted.
|
||||
return
|
||||
|
||||
if self.log_requests:
|
||||
logger.info(f"Aborted request {request_id}.")
|
||||
|
||||
self.request_streams[request_id].finish()
|
||||
self.finished_requests.put_nowait(request_id)
|
||||
self.request_tracker.abort_request(request_id,
|
||||
verbose=self.log_requests)
|
||||
|
||||
async def get_model_config(self) -> ModelConfig:
|
||||
"""Get the model configuration of the vLLM engine."""
|
||||
|
||||
Reference in New Issue
Block a user