[ Frontend ] Multiprocessing for OpenAI Server with zeromq (#6883)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
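At a high level, this change splits the OpenAI-compatible API server and the LLM engine into separate processes that communicate over zeromq, with the serving layer depending only on the AsyncEngineClient interface instead of a concrete AsyncLLMEngine. The sketch below illustrates that process split in miniature; the ipc path, pickle payloads, and REQ/REP socket pair are illustrative stand-ins, not vLLM's actual wire protocol.

    # a minimal sketch of the process split, NOT vLLM's implementation:
    # the engine lives in a worker process and the server talks to it
    # over a zeromq socket (ipc path, pickle framing, REQ/REP pairing
    # are all assumptions made for this example)
    import pickle
    from multiprocessing import Process

    import zmq


    def engine_worker(path: str) -> None:
        """Runs in a separate process and owns the (stand-in) engine."""
        ctx = zmq.Context()
        sock = ctx.socket(zmq.REP)
        sock.bind(path)
        while True:
            method, payload = pickle.loads(sock.recv())
            if method == "shutdown":
                sock.send(pickle.dumps("ok"))
                break
            # a real worker would dispatch to the engine here
            sock.send(pickle.dumps(f"echo:{payload}"))


    def main() -> None:
        path = "ipc:///tmp/engine_rpc.sock"
        worker = Process(target=engine_worker, args=(path,), daemon=True)
        worker.start()

        ctx = zmq.Context()
        sock = ctx.socket(zmq.REQ)
        sock.connect(path)
        sock.send(pickle.dumps(("generate", "hello")))
        print(pickle.loads(sock.recv()))  # -> echo:hello
        sock.send(pickle.dumps(("shutdown", None)))
        sock.recv()
        worker.join()


    if __name__ == "__main__":
        main()

With that picture in mind, the serving-layer diff below is mostly a mechanical rename: every `self.engine` call site becomes a call on the injected `async_engine_client`.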
@@ -8,7 +8,7 @@ from fastapi import Request
 from transformers import PreTrainedTokenizer
 
 from vllm.config import ModelConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          load_chat_template,
                                          parse_chat_message_content)
@@ -39,7 +39,7 @@ class OpenAIServingChat(OpenAIServing):
 
     def __init__(
         self,
-        engine: AsyncLLMEngine,
+        async_engine_client: AsyncEngineClient,
         model_config: ModelConfig,
         served_model_names: List[str],
         response_role: str,
@@ -50,7 +50,7 @@ class OpenAIServingChat(OpenAIServing):
         chat_template: Optional[str],
         return_tokens_as_token_ids: bool = False,
     ):
-        super().__init__(engine=engine,
+        super().__init__(async_engine_client=async_engine_client,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
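The constructor now accepts any implementation of AsyncEngineClient rather than the concrete AsyncLLMEngine. Judging only from the calls visible in this diff, the interface looks roughly like the hypothetical Protocol below; the real AsyncEngineClient in vllm.engine.protocol is wider, and its exact signatures may differ.

    # hypothetical reduction of the interface, inferred from this diff
    # alone; signatures are assumptions, not vllm.engine.protocol's
    # actual definitions
    from typing import Any, AsyncIterator, Optional, Protocol


    class MinimalEngineClient(Protocol):
        async def get_tokenizer(self, lora_request: Optional[Any] = None) -> Any:
            ...

        async def is_tracing_enabled(self) -> bool:
            ...

        # called without await in the diff, so it returns an async
        # iterator directly rather than a coroutine
        def generate(self, inputs: Any, sampling_params: Any,
                     request_id: str) -> AsyncIterator[Any]:
            ...

        async def abort(self, request_id: str) -> None:
            ...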
@@ -89,7 +89,8 @@ class OpenAIServingChat(OpenAIServing):
         ) = self._maybe_get_adapters(request)
 
         model_config = self.model_config
-        tokenizer = await self.engine.get_tokenizer(lora_request)
+        tokenizer = await self.async_engine_client.get_tokenizer(
+            lora_request)
 
         conversation: List[ConversationMessage] = []
         mm_futures: List[Awaitable[MultiModalDataDict]] = []
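Note that get_tokenizer stays a coroutine: behind an RPC-backed client this lookup can be a cross-process round trip rather than a local attribute access. A wrapper could cache the result per adapter, as in this hypothetical sketch (the caching and the `lora_int_id` key are illustrative, not part of the commit):

    # illustrative only: cache tokenizers client-side so repeated
    # requests don't pay an RPC round trip per call
    from typing import Any, Dict, Optional


    class TokenizerCachingClient:
        def __init__(self, inner: Any) -> None:
            self._inner = inner
            self._cache: Dict[Optional[int], Any] = {}

        async def get_tokenizer(self, lora_request: Optional[Any] = None) -> Any:
            key = getattr(lora_request, "lora_int_id", None)  # assumed key
            if key not in self._cache:
                self._cache[key] = await self._inner.get_tokenizer(lora_request)
            return self._cache[key]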
@@ -161,7 +162,8 @@ class OpenAIServingChat(OpenAIServing):
         if mm_data is not None:
             engine_inputs["multi_modal_data"] = mm_data
 
-        is_tracing_enabled = await self.engine.is_tracing_enabled()
+        is_tracing_enabled = (
+            await self.async_engine_client.is_tracing_enabled())
         trace_headers = None
         if is_tracing_enabled and raw_request:
             trace_headers = extract_trace_headers(raw_request.headers)
@@ -169,7 +171,7 @@ class OpenAIServingChat(OpenAIServing):
                 and contains_trace_headers(raw_request.headers)):
             log_tracing_disabled_warning()
 
-        result_generator = self.engine.generate(
+        result_generator = self.async_engine_client.generate(
             engine_inputs,
             sampling_params,
             request_id,
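generate is consumed as an async generator, so an RPC-backed client has to stream results across the process boundary rather than return one reply. One hedged way to do that with pyzmq's asyncio sockets is sketched below; the DEALER socket, pickle framing, and end-of-stream sentinel are assumptions for this example, not the commit's actual protocol.

    # a sketch, not vLLM's implementation: stream pickled results from
    # the engine process and yield them until a sentinel frame arrives
    import pickle
    from typing import Any, AsyncIterator

    import zmq
    import zmq.asyncio

    DONE = b"__DONE__"  # assumed end-of-stream sentinel


    class StreamingRPCClient:
        def __init__(self, path: str = "ipc:///tmp/engine_stream.sock") -> None:
            self._ctx = zmq.asyncio.Context()
            self._path = path

        async def generate(self, inputs: Any, sampling_params: Any,
                           request_id: str) -> AsyncIterator[Any]:
            sock = self._ctx.socket(zmq.DEALER)
            sock.connect(self._path)
            try:
                await sock.send(
                    pickle.dumps(("generate", inputs, sampling_params, request_id)))
                while True:
                    frame = await sock.recv()
                    if frame == DONE:
                        break
                    yield pickle.loads(frame)
            finally:
                sock.close()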
@@ -441,7 +443,7 @@ class OpenAIServingChat(OpenAIServing):
         async for res in result_generator:
             if raw_request is not None and await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
-                await self.engine.abort(request_id)
+                await self.async_engine_client.abort(request_id)
                 return self.create_error_response("Client disconnected")
             final_res = res
         assert final_res is not None
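The last hunk's pattern generalizes to any client implementation: stream results, but abort the engine-side request as soon as the HTTP client disconnects so the worker process stops spending cycles on it. A hedged standalone version, assuming a client object with the generate/abort surface sketched earlier:

    # illustrative consumer mirroring the diff above, not vLLM's code
    from typing import Any, Optional


    async def collect_final_result(client: Any, engine_inputs: Any,
                                   sampling_params: Any, request_id: str,
                                   raw_request: Optional[Any] = None) -> Optional[Any]:
        final_res = None
        async for res in client.generate(engine_inputs, sampling_params,
                                         request_id):
            if raw_request is not None and await raw_request.is_disconnected():
                # Stop engine-side work for a client that will never
                # read the reply.
                await client.abort(request_id)
                return None
            final_res = res
        return final_res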