Remove AsyncLLMEngine busy loop, shield background task (#1059)

This commit is contained in:
Antoni Baum
2023-09-17 00:29:08 -07:00
committed by GitHub
parent e3e79e9e8a
commit ff36139ffc
4 changed files with 154 additions and 18 deletions

View File

@@ -1,7 +1,8 @@
import asyncio
import time
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union)
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -78,14 +79,24 @@ class RequestTracker:
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
self.new_requests_event = None
def __contains__(self, item):
    """Report whether ``item`` is present in the tracked request streams."""
    tracked_streams = self._request_streams
    return item in tracked_streams
def propagate_exception(self, exc: Exception) -> None:
    """Deliver ``exc`` to every request stream currently tracked."""
    active_streams = self._request_streams.values()
    for request_stream in active_streams:
        request_stream.put(exc)
def init_event(self):
    """Create a fresh ``asyncio.Event`` and store it as
    ``new_requests_event``."""
    event = asyncio.Event()
    self.new_requests_event = event
def propagate_exception(self,
                        exc: Exception,
                        request_id: Optional[str] = None) -> None:
    """Propagate an exception to request streams.

    Args:
        exc: The exception to put on the stream(s).
        request_id: If given, only the stream registered under this id
            receives ``exc``. If None, every tracked stream receives it.
    """
    if request_id is None:
        for stream in self._request_streams.values():
            stream.put(exc)
        return

    # The id may not be tracked. Skip it rather than raising KeyError,
    # which would replace the exception we are trying to report.
    stream = self._request_streams.get(request_id)
    if stream is not None:
        stream.put(exc)
def process_request_output(self,
request_output: RequestOutput,
@@ -112,6 +123,9 @@ class RequestTracker:
"request_id": request_id,
**engine_add_request_kwargs
}))
self.new_requests_event.set()
return stream
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@@ -148,8 +162,13 @@ class RequestTracker:
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
self.new_requests_event.clear()
return new_requests, finished_requests
async def wait_for_new_requests(self):
    """Suspend the caller until ``new_requests_event`` is set."""
    new_requests_signal = self.new_requests_event
    await new_requests_signal.wait()
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
@@ -251,9 +270,13 @@ class AsyncLLMEngine:
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)
self.request_tracker: RequestTracker = RequestTracker()
self.background_loop = None
# We need to keep a reference to the unshielded
# task as well, to prevent it from being garbage
# collected.
self._background_loop_unshielded = None
self.start_engine_loop = start_engine_loop
self._request_tracker = RequestTracker()
@property
def is_running(self) -> bool:
@@ -264,11 +287,14 @@ class AsyncLLMEngine:
"""Start the background loop."""
if self.is_running:
raise RuntimeError("Background loop is already running.")
self.background_loop = asyncio.get_event_loop().create_task(
self.run_engine_loop())
self.background_loop.add_done_callback(
self._request_tracker.init_event()
self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop())
self._background_loop_unshielded.add_done_callback(
partial(_raise_exception_on_finish,
request_tracker=self.request_tracker))
request_tracker=self._request_tracker))
self.background_loop = asyncio.shield(self._background_loop_unshielded)
def _init_engine(self, *args,
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
@@ -280,11 +306,13 @@ class AsyncLLMEngine:
engine_class = ray.remote(num_gpus=1)(self._engine_class).remote
return engine_class(*args, **kwargs)
async def engine_step(self):
"""Kick the engine to process the waiting requests."""
async def engine_step(self) -> bool:
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
new_requests, finished_requests = (
self.request_tracker.get_new_and_finished_requests())
self._request_tracker.get_new_and_finished_requests())
for new_request in new_requests:
# Add the request into the vLLM engine's waiting queue.
@@ -304,9 +332,11 @@ class AsyncLLMEngine:
# Put the outputs into the corresponding streams.
for request_output in request_outputs:
self.request_tracker.process_request_output(
self._request_tracker.process_request_output(
request_output, verbose=self.log_requests)
return len(request_outputs) > 0
async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
await self.engine.abort_request.remote(request_ids)
@@ -314,8 +344,12 @@ class AsyncLLMEngine:
self.engine.abort_request(request_ids)
async def run_engine_loop(self):
    """Drive the engine forever, idling while there is nothing to do.

    Each iteration runs exactly one ``engine_step``. If the previous step
    reported no in-progress requests (or no step has run yet), the loop
    first waits on the request tracker for new requests before stepping.
    """
    has_requests_in_progress = False
    while True:
        if not has_requests_in_progress:
            # Nothing in flight: wait until the tracker signals that new
            # requests have arrived instead of stepping an idle engine.
            await self._request_tracker.wait_for_new_requests()
        has_requests_in_progress = await self.engine_step()
        # Yield to the event loop so other tasks can run between steps.
        await asyncio.sleep(0)
async def add_request(
@@ -350,7 +384,7 @@ class AsyncLLMEngine:
"error that caused the background loop to stop "
"(AsyncEngineDeadError).")
stream = self.request_tracker.add_request(
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
sampling_params=sampling_params,
@@ -428,8 +462,8 @@ class AsyncLLMEngine:
Args:
request_id: The unique id of the request.
"""
self.request_tracker.abort_request(request_id,
verbose=self.log_requests)
self._request_tracker.abort_request(request_id,
verbose=self.log_requests)
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""