Fix/async chat serving (#2727)

This commit is contained in:
Sebastian Schoennenbeck
2024-05-03 20:04:14 +02:00
committed by GitHub
parent 7e65477e5e
commit f8e7adda21
5 changed files with 73 additions and 21 deletions

View File

@@ -2,7 +2,7 @@ import asyncio
import json
from dataclasses import dataclass
from http import HTTPStatus
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
from pydantic import Field
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -29,8 +29,11 @@ class LoRAModulePath:
class OpenAIServing:
def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]]):
def __init__(self,
engine: AsyncLLMEngine,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]],
await_post_init: Optional[Awaitable[Any]] = None):
self.engine = engine
self.served_model_names = served_model_names
if lora_modules is None:
@@ -56,12 +59,12 @@ class OpenAIServing:
if event_loop is not None and event_loop.is_running():
# If the current is instanced by Ray Serve,
# there is already a running event loop
event_loop.create_task(self._post_init())
event_loop.create_task(self._post_init(await_post_init))
else:
# When using single vLLM without engine_use_ray
asyncio.run(self._post_init())
asyncio.run(self._post_init(await_post_init))
async def _post_init(self):
async def _post_init(self, await_post_init):
engine_model_config = await self.engine.get_model_config()
self.max_model_len = engine_model_config.max_model_len
@@ -73,6 +76,9 @@ class OpenAIServing:
trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left")
if await_post_init is not None:
await await_post_init
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [