Add health check, make async Engine more robust (#3015)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
@@ -1,8 +1,9 @@
 import asyncio
+import os
 import time
 from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator)
+                    Union, AsyncIterator, Callable)
 
 from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
@@ -14,28 +15,31 @@ from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 
 logger = init_logger(__name__)
+ENGINE_ITERATION_TIMEOUT_S = int(
+    os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60"))
 
 
 class AsyncEngineDeadError(RuntimeError):
     pass
 
 
-def _raise_exception_on_finish(task: asyncio.Task,
-                               request_tracker: "RequestTracker") -> None:
+def _raise_exception_on_finish(
+        task: asyncio.Task, error_callback: Callable[[Exception],
+                                                     None]) -> None:
     msg = ("Task finished unexpectedly. This should never happen! "
            "Please open an issue on Github.")
+
+    exception = None
     try:
-        try:
-            task.result()
-        except asyncio.CancelledError:
-            return
-        except Exception as exc:
-            raise AsyncEngineDeadError(
-                msg + " See stack trace above for the actual cause.") from exc
+        task.result()
+        # NOTE: This will be thrown if task exits normally (which it should not)
         raise AsyncEngineDeadError(msg)
-    except Exception as exc:
-        request_tracker.propagate_exception(exc)
-        raise exc
+    except Exception as e:
+        exception = e
+        logger.error("Engine background task failed", exc_info=e)
+        error_callback(exception)
+        raise AsyncEngineDeadError(
+            msg + " See stack trace above for the actual cause.") from e
 
 
 class AsyncStream:
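The refactor above turns the tracker-specific failure hook into a plain `Callable[[Exception], None]`, wired in with `functools.partial` via `Task.add_done_callback`. A minimal, self-contained sketch of that pattern follows; `on_finish` and `crashing_loop` are illustrative names, not part of vLLM:

```python
import asyncio
from functools import partial


def on_finish(task: asyncio.Task, error_callback) -> None:
    # Invoked once the background task completes, for any reason.
    try:
        task.result()  # Re-raises the task's exception, if there was one.
        raise RuntimeError("Background loop exited; it should run forever.")
    except asyncio.CancelledError:
        return  # Cancellation during shutdown is expected.
    except Exception as e:
        error_callback(e)  # Let the owner record the failure first.
        raise


async def main() -> None:
    async def crashing_loop() -> None:
        raise ValueError("simulated engine failure")

    task = asyncio.create_task(crashing_loop())
    task.add_done_callback(
        partial(on_finish, error_callback=lambda e: print("errored:", e)))
    await asyncio.sleep(0.1)  # Give the done callback a chance to run.


asyncio.run(main())
```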
@@ -78,13 +82,13 @@ class RequestTracker:
         self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                 dict]] = asyncio.Queue()
-        self.new_requests_event = None
+        self.new_requests_event = asyncio.Event()
 
     def __contains__(self, item):
         return item in self._request_streams
 
-    def init_event(self):
-        self.new_requests_event = asyncio.Event()
+    def __len__(self) -> int:
+        return len(self._request_streams)
 
     def propagate_exception(self,
                             exc: Exception,
@@ -93,9 +97,11 @@ class RequestTracker:
         (all if request_id is None)."""
         if request_id is not None:
             self._request_streams[request_id].put(exc)
+            self.abort_request(request_id)
         else:
-            for stream in self._request_streams.values():
+            for rid, stream in self._request_streams.items():
                 stream.put(exc)
+                self.abort_request(rid)
 
     def process_request_output(self,
                                request_output: RequestOutput,
@@ -172,12 +178,15 @@ class RequestTracker:
             self._request_streams[stream.request_id] = stream
             new_requests.append(new_request)
 
-        self.new_requests_event.clear()
-
         return new_requests, finished_requests
 
     async def wait_for_new_requests(self):
-        await self.new_requests_event.wait()
+        if not self.has_new_requests():
+            await self.new_requests_event.wait()
+        self.new_requests_event.clear()
 
     def has_new_requests(self):
         return not self._new_requests.empty()
 
 
 class _AsyncLLMEngine(LLMEngine):
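Together, the RequestTracker hunks implement a standard queue-plus-event wakeup: producers enqueue and `set()`; the consumer checks the queue before waiting, so a wakeup that fired while it was busy stepping the engine is never lost, and only the waiter `clear()`s. A stripped-down sketch of the protocol (`ToyTracker` is illustrative, not vLLM's class):

```python
import asyncio
from typing import Any


class ToyTracker:
    """Illustrative queue + event wakeup, mirroring the pattern above."""

    def __init__(self) -> None:
        # Constructed inside the running loop so the Event is bound to it.
        self._new_requests: asyncio.Queue = asyncio.Queue()
        self.new_requests_event = asyncio.Event()

    def add_request(self, item: Any) -> None:
        self._new_requests.put_nowait(item)
        self.new_requests_event.set()  # Wake the consumer if it is waiting.

    def has_new_requests(self) -> bool:
        return not self._new_requests.empty()

    async def wait_for_new_requests(self) -> None:
        # Skip the wait if work arrived while we were busy; the set() has
        # already happened, and clearing first would lose that wakeup.
        if not self.has_new_requests():
            await self.new_requests_event.wait()
        self.new_requests_event.clear()
```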
@@ -285,6 +294,10 @@ class _AsyncLLMEngine(LLMEngine):
         all_outputs = await asyncio.gather(*coros)
         return all_outputs
 
+    async def check_health_async(self):
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
+
 
 class AsyncLLMEngine:
     """An asynchronous wrapper for LLMEngine.
@@ -335,27 +348,48 @@ class AsyncLLMEngine:
         # collected
         self._background_loop_unshielded = None
         self.start_engine_loop = start_engine_loop
-        self._request_tracker = RequestTracker()
+        self._request_tracker: Optional[RequestTracker] = None
+        self._errored_with: Optional[BaseException] = None
 
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
-                and not self.background_loop.done())
+                and not self._background_loop_unshielded.done())
 
+    @property
+    def is_stopped(self) -> bool:
+        return self.errored or (self.background_loop is not None
+                                and self._background_loop_unshielded.done())
+
+    @property
+    def errored(self) -> bool:
+        return self._errored_with is not None
+
+    def set_errored(self, exc: Exception) -> None:
+        self._errored_with = exc
+
+    def _error_callback(self, exc: Exception) -> None:
+        self.set_errored(exc)
+        self._request_tracker.propagate_exception(exc)
+
     def get_tokenizer(self):
         return self.engine.tokenizer.tokenizer
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
+        if self.errored:
+            raise AsyncEngineDeadError(
+                "Background loop has errored already.") from self._errored_with
         if self.is_running:
             raise RuntimeError("Background loop is already running.")
-        self._request_tracker.init_event()
+        # Initialize the RequestTracker here so it uses the right event loop.
+        self._request_tracker = RequestTracker()
 
         self._background_loop_unshielded = asyncio.get_event_loop(
         ).create_task(self.run_engine_loop())
         self._background_loop_unshielded.add_done_callback(
             partial(_raise_exception_on_finish,
-                    request_tracker=self._request_tracker))
+                    error_callback=self._error_callback))
         self.background_loop = asyncio.shield(self._background_loop_unshielded)
 
     def _init_engine(self, *args,
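`background_loop` remains an `asyncio.shield` around the real task: a caller that awaits it and gets cancelled no longer tears down the engine loop, while `is_running`/`is_stopped` consult the unshielded task for the truth. A tiny demo of that asymmetry (toy names only, nothing vLLM-specific):

```python
import asyncio


async def main() -> None:
    async def loop_body() -> None:
        await asyncio.sleep(10)

    real_task = asyncio.create_task(loop_body())
    shielded = asyncio.shield(real_task)

    # Cancelling the shielded wrapper does NOT cancel the inner task.
    shielded.cancel()
    await asyncio.sleep(0.01)
    print("inner cancelled?", real_task.cancelled())  # False

    # Cancelling the unshielded task is what actually stops the loop.
    real_task.cancel()
    await asyncio.sleep(0.01)
    print("inner cancelled?", real_task.cancelled())  # True


asyncio.run(main())
```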
@@ -423,12 +457,23 @@ class AsyncLLMEngine:
         self.engine.abort_request(request_ids)
 
     async def run_engine_loop(self):
-        # Initialize the RequestTracker here so it uses the right event loop.
         has_requests_in_progress = False
         while True:
             if not has_requests_in_progress:
+                logger.debug("Waiting for new requests...")
                 await self._request_tracker.wait_for_new_requests()
-            has_requests_in_progress = await self.engine_step()
+                logger.debug("Got new requests!")
+
+            # Abort if iteration takes too long due to unrecoverable errors
+            # (eg. NCCL timeouts).
+            try:
+                has_requests_in_progress = await asyncio.wait_for(
+                    self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
+            except asyncio.TimeoutError as exc:
+                logger.error(
+                    "Engine iteration timed out. This should never happen!")
+                self.set_errored(exc)
+                raise
             await asyncio.sleep(0)
 
     async def add_request(
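Wrapping each iteration in `asyncio.wait_for` converts a hung step (e.g. a NCCL collective that never returns) into an `asyncio.TimeoutError` after `VLLM_ENGINE_ITERATION_TIMEOUT_S` seconds (default 60), so the loop errors out loudly instead of stalling silently. A hedged sketch of the same watchdog shape, where `step` is a stand-in for `engine_step`:

```python
import asyncio

ITERATION_TIMEOUT_S = 60  # Stand-in for ENGINE_ITERATION_TIMEOUT_S.


async def guarded_loop(step) -> None:
    """Run `step()` forever, failing fast if one iteration hangs."""
    while True:
        try:
            await asyncio.wait_for(step(), ITERATION_TIMEOUT_S)
        except asyncio.TimeoutError:
            # wait_for cancels the timed-out inner task before raising,
            # so the hung iteration is not left running in the background.
            raise
        await asyncio.sleep(0)  # Yield to other coroutines each round.
```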
@@ -647,3 +692,19 @@ class AsyncLLMEngine:
             await self.engine.do_log_stats.remote()
         else:
             self.engine.do_log_stats()
+
+    async def check_health(self):
+        """Raises an error if engine is unhealthy."""
+        t = time.perf_counter()
+        logger.debug("Starting health check...")
+        if self.is_stopped:
+            raise AsyncEngineDeadError("Background loop is stopped.")
+
+        if self.engine_use_ray:
+            try:
+                await self.engine.check_health.remote()
+            except ray.exceptions.RayActorError as e:
+                raise RuntimeError("Engine is dead.") from e
+        else:
+            await self.engine.check_health_async()
+        logger.debug(f"Health check took {time.perf_counter()-t}s")
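With `check_health` in place, a serving front-end can expose a liveness probe that maps a healthy engine to 200 and a dead one to 500. The FastAPI wiring below is only an illustration of how a caller might use the new API; the actual server-side hookup is outside this file:

```python
from fastapi import FastAPI, Response

app = FastAPI()
engine = ...  # Assumed: an initialized AsyncLLMEngine instance.


@app.get("/health")
async def health() -> Response:
    """200 while the engine loop is alive, 500 once it has died."""
    try:
        await engine.check_health()
    except Exception:
        # AsyncEngineDeadError and RuntimeError("Engine is dead.") land here.
        return Response(status_code=500)
    return Response(status_code=200)
```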