[ Frontend ] Multiprocessing for OpenAI Server with zeromq (#6883)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
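The core of this change: the OpenAI-compatible frontend no longer owns the AsyncLLMEngine. The engine moves into a separate process, and the frontend reaches it through a ZeroMQ-based RPC client (AsyncEngineRPCClient) talking to an RPC server (run_rpc_server). The sketch below is a minimal illustration of that process-split pattern, not vLLM's actual RPC protocol; the echo handler, message strings, and fixed port are invented for the example.

# Hypothetical stand-alone sketch: a "backend" process owning the expensive
# object, and an asyncio frontend talking to it over a ZeroMQ REQ/REP pair.
import asyncio
from multiprocessing import Process

import zmq
import zmq.asyncio


def run_backend(port: int) -> None:
    """Runs in the child process; stands in for run_rpc_server + the engine."""
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REP)
    sock.bind(f"tcp://127.0.0.1:{port}")
    while True:
        request = sock.recv_string()
        if request == "SHUTDOWN":
            sock.send_string("BYE")
            break
        # A real server would dispatch to the engine here.
        sock.send_string(f"echo: {request}")
    sock.close()
    ctx.term()


async def main() -> None:
    port = 5570  # arbitrary here; vLLM picks one via get_open_port(envs.VLLM_RPC_PORT)
    backend = Process(target=run_backend, args=(port,))
    backend.start()

    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.REQ)
    sock.connect(f"tcp://127.0.0.1:{port}")
    try:
        await sock.send_string("health_check")
        print(await sock.recv_string())  # -> "echo: health_check"
        await sock.send_string("SHUTDOWN")
        await sock.recv_string()
    finally:
        sock.close()
        ctx.term()
        backend.join()


if __name__ == "__main__":
    asyncio.run(main())

In the diff, AsyncEngineRPCClient plays the role of the frontend socket and conforms to the AsyncEngineClient protocol, so the serving classes do not care whether they talk to an in-process engine or to the RPC client.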
@@ -5,7 +5,8 @@ import re
 import signal
 from contextlib import asynccontextmanager
 from http import HTTPStatus
-from typing import Optional, Set
+from multiprocessing import Process
+from typing import AsyncIterator, Set
 
 import fastapi
 import uvicorn
@@ -17,8 +18,10 @@ from prometheus_client import make_asgi_app
 from starlette.routing import Mount
 
+import vllm.envs as envs
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 # yapf conflicts with isort for this block
@@ -31,6 +34,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               EmbeddingRequest, ErrorResponse,
                                               TokenizeRequest,
                                               TokenizeResponse)
+from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
+from vllm.entrypoints.openai.rpc.server import run_rpc_server
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -39,12 +44,12 @@ from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, get_open_port
 from vllm.version import __version__ as VLLM_VERSION
 
 TIMEOUT_KEEP_ALIVE = 5  # seconds
 
-engine: AsyncLLMEngine
+async_engine_client: AsyncEngineClient
 engine_args: AsyncEngineArgs
 openai_serving_chat: OpenAIServingChat
 openai_serving_completion: OpenAIServingCompletion
@@ -56,13 +61,22 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
 _running_tasks: Set[asyncio.Task] = set()
 
 
+def model_is_embedding(model_name: str) -> bool:
+    return ModelConfig(model=model_name,
+                       tokenizer=model_name,
+                       tokenizer_mode="auto",
+                       trust_remote_code=False,
+                       seed=0,
+                       dtype="float16").embedding_mode
+
+
 @asynccontextmanager
 async def lifespan(app: fastapi.FastAPI):
 
     async def _force_log():
         while True:
             await asyncio.sleep(10)
-            await engine.do_log_stats()
+            await async_engine_client.do_log_stats()
 
     if not engine_args.disable_log_stats:
         task = asyncio.create_task(_force_log())
@@ -72,6 +86,52 @@ async def lifespan(app: fastapi.FastAPI):
     yield
 
 
+@asynccontextmanager
+async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
+    # Context manager to handle async_engine_client lifecycle
+    # Ensures everything is shutdown and cleaned up on error/exit
+    global engine_args
+    engine_args = AsyncEngineArgs.from_cli_args(args)
+
+    # Backend itself still global for the silly lil' health handler
+    global async_engine_client
+
+    # If manually triggered or embedding model, use AsyncLLMEngine in process.
+    # TODO: support embedding model via RPC.
+    if (model_is_embedding(args.model)
+            or args.disable_frontend_multiprocessing):
+        async_engine_client = AsyncLLMEngine.from_engine_args(
+            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
+        yield async_engine_client
+        return
+
+    # Otherwise, use the multiprocessing AsyncLLMEngine.
+    else:
+        # Start RPCServer in separate process (holds the AsyncLLMEngine).
+        port = get_open_port(envs.VLLM_RPC_PORT)
+        rpc_server_process = Process(target=run_rpc_server,
+                                     args=(engine_args,
+                                           UsageContext.OPENAI_API_SERVER,
+                                           port))
+        rpc_server_process.start()
+
+        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
+        async_engine_client = AsyncEngineRPCClient(port)
+        await async_engine_client.setup()
+
+        try:
+            yield async_engine_client
+        finally:
+            # Ensure rpc server process was terminated
+            rpc_server_process.terminate()
+
+            # Close all open connections to the backend
+            async_engine_client.close()
+
+            # Wait for server process to join
+            rpc_server_process.join()
+
+
 router = APIRouter()
 
 
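build_async_engine_client wraps the backend lifecycle in an async context manager, so the RPC server process is terminated, its connections closed, and the process joined even when the frontend exits with an error. A hedged sketch of that start/cleanup shape with hypothetical names (backend_process stands in for the Process running run_rpc_server):

import asyncio
import time
from contextlib import asynccontextmanager
from multiprocessing import Process
from typing import AsyncIterator


def fake_backend(port: int) -> None:
    # Stand-in for run_rpc_server(engine_args, usage_context, port).
    time.sleep(60)


@asynccontextmanager
async def backend_process(port: int) -> AsyncIterator[Process]:
    proc = Process(target=fake_backend, args=(port,))
    proc.start()
    try:
        yield proc
    finally:
        proc.terminate()  # mirrors rpc_server_process.terminate()
        proc.join()       # mirrors rpc_server_process.join()


async def main() -> None:
    async with backend_process(5570) as proc:
        print("backend pid:", proc.pid)
        raise RuntimeError("frontend crashed")  # cleanup in finally still runs


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except RuntimeError:
        pass

run_server consumes the real context manager the same way (see the final hunk below), so a crash in the HTTP layer cannot leak the engine process.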
@@ -86,7 +146,7 @@ def mount_metrics(app: fastapi.FastAPI):
 @router.get("/health")
 async def health() -> Response:
     """Health check."""
-    await openai_serving_chat.engine.check_health()
+    await async_engine_client.check_health()
     return Response(status_code=200)
 
 
@@ -215,8 +275,8 @@ def build_app(args):
 
 
 async def build_server(
+    async_engine_client: AsyncEngineClient,
     args,
-    llm_engine: Optional[AsyncLLMEngine] = None,
     **uvicorn_kwargs,
 ) -> uvicorn.Server:
     app = build_app(args)
@@ -226,14 +286,7 @@ async def build_server(
     else:
         served_model_names = [args.model]
 
-    global engine, engine_args
-
-    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = (llm_engine
-              if llm_engine is not None else AsyncLLMEngine.from_engine_args(
-                  engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
-
-    model_config = await engine.get_model_config()
+    model_config = await async_engine_client.get_model_config()
 
     if args.disable_log_requests:
         request_logger = None
@@ -246,7 +299,7 @@ async def build_server(
     global openai_serving_tokenization
 
     openai_serving_chat = OpenAIServingChat(
-        engine,
+        async_engine_client,
         model_config,
         served_model_names,
         args.response_role,
@@ -257,7 +310,7 @@
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     )
     openai_serving_completion = OpenAIServingCompletion(
-        engine,
+        async_engine_client,
         model_config,
         served_model_names,
         lora_modules=args.lora_modules,
@@ -266,13 +319,13 @@
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     )
     openai_serving_embedding = OpenAIServingEmbedding(
-        engine,
+        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
-        engine,
+        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
@@ -304,32 +357,39 @@ async def build_server(
     return uvicorn.Server(config)
 
 
-async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
+async def run_server(args, **uvicorn_kwargs) -> None:
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)
 
-    server = await build_server(
-        args,
-        llm_engine,
-        **uvicorn_kwargs,
-    )
+    shutdown_task = None
+    async with build_async_engine_client(args) as async_engine_client:
 
-    loop = asyncio.get_running_loop()
+        server = await build_server(
+            async_engine_client,
+            args,
+            **uvicorn_kwargs,
+        )
 
-    server_task = loop.create_task(server.serve())
+        loop = asyncio.get_running_loop()
 
-    def signal_handler() -> None:
-        # prevents the uvicorn signal handler to exit early
-        server_task.cancel()
+        server_task = loop.create_task(server.serve())
 
-    loop.add_signal_handler(signal.SIGINT, signal_handler)
-    loop.add_signal_handler(signal.SIGTERM, signal_handler)
+        def signal_handler() -> None:
+            # prevents the uvicorn signal handler to exit early
+            server_task.cancel()
 
-    try:
-        await server_task
-    except asyncio.CancelledError:
-        print("Gracefully stopping http server")
-        await server.shutdown()
+        loop.add_signal_handler(signal.SIGINT, signal_handler)
+        loop.add_signal_handler(signal.SIGTERM, signal_handler)
+
+        try:
+            await server_task
+        except asyncio.CancelledError:
+            logger.info("Gracefully stopping http server")
+            shutdown_task = server.shutdown()
+
+    if shutdown_task:
+        # NB: Await server shutdown only after the backend context is exited
+        await shutdown_task
 
 
 if __name__ == "__main__":
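The NB comment above captures the shutdown ordering this change relies on: server.shutdown() is called inside the async with block, which only creates the coroutine; it is awaited after the block exits, so the engine backend is torn down before the HTTP shutdown completes. A toy illustration of that ordering, with invented names:

import asyncio
from contextlib import asynccontextmanager


@asynccontextmanager
async def backend():
    print("backend up")
    try:
        yield
    finally:
        print("backend torn down")


async def http_shutdown() -> None:
    print("http shutdown finished")


async def main() -> None:
    shutdown_task = None
    async with backend():
        # Creates the coroutine but does not await it yet.
        shutdown_task = http_shutdown()
    if shutdown_task:
        await shutdown_task
    # Prints: backend up / backend torn down / http shutdown finished


if __name__ == "__main__":
    asyncio.run(main())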
||||