Add health check, make async Engine more robust (#3015)

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
Antoni Baum
2024-03-04 14:01:40 -08:00
committed by GitHub
parent 22de45235c
commit ff578cae54
4 changed files with 138 additions and 65 deletions

View File

@@ -1,8 +1,9 @@
import asyncio
import os
import time
from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator)
Union, AsyncIterator, Callable)
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig
@@ -14,28 +15,31 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
# Module-level logger, obtained from the project's init_logger helper.
logger = init_logger(__name__)
# Upper bound, in seconds, for a single engine iteration. Read once at import
# time from the VLLM_ENGINE_ITERATION_TIMEOUT_S environment variable
# (default "60"); the engine loop passes it to asyncio.wait_for as the timeout.
ENGINE_ITERATION_TIMEOUT_S = int(
os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60"))
class AsyncEngineDeadError(RuntimeError):
    """Signals that the async engine's background loop is dead or unusable.

    Raised when the background task finishes or fails, when a new loop is
    requested after a previous failure, and by the health check when the loop
    has stopped.
    """
def _raise_exception_on_finish(
        task: asyncio.Task,
        error_callback: Callable[[Exception], None]) -> None:
    """Done-callback for the engine background task.

    The background task is expected to run forever, so any completion is
    treated as a fatal engine error.

    Args:
        task: The finished background task.
        error_callback: Called with the exception that ended the task (or
            with the ``AsyncEngineDeadError`` synthesized when the task
            returned normally) before this function raises.

    Raises:
        AsyncEngineDeadError: Always, unless the task was cancelled. The
            original failure is chained as ``__cause__``.
    """
    msg = ("Task finished unexpectedly. This should never happen! "
           "Please open an issue on Github.")

    try:
        task.result()
        # NOTE: This will be thrown if task exits normally (which it should not)
        raise AsyncEngineDeadError(msg)
    except asyncio.CancelledError:
        # Cancellation is a deliberate stop, not an engine failure. It must be
        # handled before the generic handler because CancelledError derives
        # from Exception on older Python versions.
        return
    except Exception as e:
        logger.error("Engine background task failed", exc_info=e)
        error_callback(e)
        raise AsyncEngineDeadError(
            msg + " See stack trace above for the actual cause.") from e
class AsyncStream:
@@ -78,13 +82,13 @@ class RequestTracker:
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
self.new_requests_event = None
self.new_requests_event = asyncio.Event()
def __contains__(self, item):
    """Report whether ``item`` is a key of the tracked request streams."""
    tracked_streams = self._request_streams
    return item in tracked_streams
def init_event(self):
    """Install a fresh, unset ``asyncio.Event`` as ``new_requests_event``."""
    fresh_event = asyncio.Event()
    self.new_requests_event = fresh_event
def __len__(self) -> int:
    """Return the number of request streams currently tracked."""
    tracked_streams = self._request_streams
    return len(tracked_streams)
def propagate_exception(self,
exc: Exception,
@@ -93,9 +97,11 @@ class RequestTracker:
(all if request_id is None)."""
if request_id is not None:
self._request_streams[request_id].put(exc)
self.abort_request(request_id)
else:
for stream in self._request_streams.values():
for rid, stream in self._request_streams.items():
stream.put(exc)
self.abort_request(rid)
def process_request_output(self,
request_output: RequestOutput,
@@ -172,12 +178,15 @@ class RequestTracker:
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
self.new_requests_event.clear()
return new_requests, finished_requests
async def wait_for_new_requests(self):
    """Block until at least one new request is available, then reset the event.

    If requests are already queued this returns without waiting. The event is
    always cleared on the way out so that the next call only wakes up for
    requests that arrive afterwards.
    """
    # Only wait when nothing is queued; waiting unconditionally would stall
    # on requests that were added before this coroutine started waiting.
    if not self.has_new_requests():
        await self.new_requests_event.wait()
    self.new_requests_event.clear()
def has_new_requests(self):
    """Return True when the new-request queue holds at least one entry."""
    queue_is_empty = self._new_requests.empty()
    return not queue_is_empty
class _AsyncLLMEngine(LLMEngine):
@@ -285,6 +294,10 @@ class _AsyncLLMEngine(LLMEngine):
all_outputs = await asyncio.gather(*coros)
return all_outputs
async def check_health_async(self):
    """Check engine health by delegating to the dead-actor check.

    Any error raised by ``_check_if_any_actor_is_dead`` propagates to the
    caller.
    """
    self._check_if_any_actor_is_dead()
class AsyncLLMEngine:
"""An asynchronous wrapper for LLMEngine.
@@ -335,27 +348,48 @@ class AsyncLLMEngine:
# collected
self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()
self._request_tracker: Optional[RequestTracker] = None
self._errored_with: Optional[BaseException] = None
@property
def is_running(self) -> bool:
    """True while a background loop has been started and has not finished.

    Completion is judged on the unshielded task, the same object that
    ``is_stopped`` inspects, so the two properties cannot disagree.
    """
    return (self.background_loop is not None
            and not self._background_loop_unshielded.done())
@property
def is_stopped(self) -> bool:
    """True when the engine has errored or its started background task is done."""
    has_errored = self.errored
    if has_errored:
        return has_errored
    return (self.background_loop is not None
            and self._background_loop_unshielded.done())
@property
def errored(self) -> bool:
    """True once an exception has been recorded in ``_errored_with``."""
    recorded_error = self._errored_with
    return recorded_error is not None
def set_errored(self, exc: Exception) -> None:
    """Record ``exc`` as the error that broke the engine."""
    setattr(self, "_errored_with", exc)
def _error_callback(self, exc: Exception) -> None:
    """Record ``exc`` on the engine, then hand it to the request tracker."""
    # Mark the engine as errored first, then propagate to the tracker.
    self.set_errored(exc)
    tracker = self._request_tracker
    tracker.propagate_exception(exc)
def get_tokenizer(self):
    """Return the inner ``tokenizer`` attribute of the engine's tokenizer."""
    engine_tokenizer = self.engine.tokenizer
    return engine_tokenizer.tokenizer
def start_background_loop(self) -> None:
    """Start the background loop.

    Creates a new ``RequestTracker``, schedules ``run_engine_loop`` as a task
    on the current event loop, and registers ``_raise_exception_on_finish``
    as its done-callback so failures reach ``_error_callback``.

    Raises:
        AsyncEngineDeadError: If a previous loop already errored; the
            recorded error is chained as the cause.
        RuntimeError: If a background loop is already running.
    """
    if self.errored:
        raise AsyncEngineDeadError(
            "Background loop has errored already.") from self._errored_with
    if self.is_running:
        raise RuntimeError("Background loop is already running.")
    # Initialize the RequestTracker here so it uses the right event loop.
    self._request_tracker = RequestTracker()

    self._background_loop_unshielded = asyncio.get_event_loop(
    ).create_task(self.run_engine_loop())
    self._background_loop_unshielded.add_done_callback(
        partial(_raise_exception_on_finish,
                error_callback=self._error_callback))
    # Callers hold the shielded wrapper; the raw task is kept separately so
    # its completion state can be inspected directly.
    self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, *args,
@@ -423,12 +457,23 @@ class AsyncLLMEngine:
self.engine.abort_request(request_ids)
async def run_engine_loop(self):
    """Drive the engine forever, one ``engine_step`` per iteration.

    While no requests are in progress the loop waits on the request tracker
    for new work. Each step is bounded by ``ENGINE_ITERATION_TIMEOUT_S``.

    Raises:
        asyncio.TimeoutError: If a single engine step exceeds the timeout.
            The engine is marked as errored before the exception propagates.
    """
    has_requests_in_progress = False
    while True:
        if not has_requests_in_progress:
            logger.debug("Waiting for new requests...")
            await self._request_tracker.wait_for_new_requests()
            logger.debug("Got new requests!")

        # Abort if iteration takes too long due to unrecoverable errors
        # (eg. NCCL timeouts).
        try:
            has_requests_in_progress = await asyncio.wait_for(
                self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
        except asyncio.TimeoutError as exc:
            logger.error(
                "Engine iteration timed out. This should never happen!")
            self.set_errored(exc)
            raise
        # Yield control so other coroutines can run between iterations.
        await asyncio.sleep(0)
async def add_request(
@@ -647,3 +692,19 @@ class AsyncLLMEngine:
await self.engine.do_log_stats.remote()
else:
self.engine.do_log_stats()
async def check_health(self):
    """Run an engine health check, raising if the engine is unhealthy.

    Raises:
        AsyncEngineDeadError: If the background loop is stopped.
        RuntimeError: If the Ray-hosted engine raises ``RayActorError``.
    """
    t = time.perf_counter()
    logger.debug("Starting health check...")
    if self.is_stopped:
        raise AsyncEngineDeadError("Background loop is stopped.")

    if not self.engine_use_ray:
        # In-process engine: call its async health check directly.
        await self.engine.check_health_async()
    else:
        # Ray-hosted engine: invoke the check remotely.
        try:
            await self.engine.check_health.remote()
        except ray.exceptions.RayActorError as e:
            raise RuntimeError("Engine is dead.") from e
    logger.debug(f"Health check took {time.perf_counter()-t}s")