[ Frontend ] Multiprocessing for OpenAI Server with zeromq (#6883)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
Robert Shaw
2024-08-02 21:27:28 -04:00
committed by GitHub
parent 708989341e
commit ed812a73fa
20 changed files with 1567 additions and 101 deletions

View File

@@ -8,7 +8,7 @@ from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
@@ -39,7 +39,7 @@ class OpenAIServingChat(OpenAIServing):
def __init__(
self,
engine: AsyncLLMEngine,
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
@@ -50,7 +50,7 @@ class OpenAIServingChat(OpenAIServing):
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine=engine,
super().__init__(async_engine_client=async_engine_client,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
@@ -89,7 +89,8 @@ class OpenAIServingChat(OpenAIServing):
) = self._maybe_get_adapters(request)
model_config = self.model_config
tokenizer = await self.engine.get_tokenizer(lora_request)
tokenizer = await self.async_engine_client.get_tokenizer(
lora_request)
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
@@ -161,7 +162,8 @@ class OpenAIServingChat(OpenAIServing):
if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.engine.is_tracing_enabled()
is_tracing_enabled = (
await self.async_engine_client.is_tracing_enabled())
trace_headers = None
if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers)
@@ -169,7 +171,7 @@ class OpenAIServingChat(OpenAIServing):
and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning()
result_generator = self.engine.generate(
result_generator = self.async_engine_client.generate(
engine_inputs,
sampling_params,
request_id,
@@ -441,7 +443,7 @@ class OpenAIServingChat(OpenAIServing):
async for res in result_generator:
if raw_request is not None and await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
await self.async_engine_client.abort(request_id)
return self.create_error_response("Client disconnected")
final_res = res
assert final_res is not None