[ Frontend ] Multiprocessing for OpenAI Server with zeromq (#6883)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
Robert Shaw
2024-08-02 21:27:28 -04:00
committed by GitHub
parent 708989341e
commit ed812a73fa
20 changed files with 1567 additions and 101 deletions

View File

@@ -5,7 +5,8 @@ import re
import signal
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Optional, Set
from multiprocessing import Process
from typing import AsyncIterator, Set
import fastapi
import uvicorn
@@ -17,8 +18,10 @@ from prometheus_client import make_asgi_app
from starlette.routing import Mount
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
@@ -31,6 +34,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingRequest, ErrorResponse,
TokenizeRequest,
TokenizeResponse)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -39,12 +44,12 @@ from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
from vllm.utils import FlexibleArgumentParser, get_open_port
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
engine: AsyncLLMEngine
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
@@ -56,13 +61,22 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str) -> bool:
return ModelConfig(model=model_name,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16").embedding_mode
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
await engine.do_log_stats()
await async_engine_client.do_log_stats()
if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())
@@ -72,6 +86,52 @@ async def lifespan(app: fastapi.FastAPI):
yield
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global engine_args
engine_args = AsyncEngineArgs.from_cli_args(args)
# Backend itself still global for the silly lil' health handler
global async_engine_client
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if (model_is_embedding(args.model)
or args.disable_frontend_multiprocessing):
async_engine_client = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
yield async_engine_client
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
else:
# Start RPCServer in separate process (holds the AsyncLLMEngine).
port = get_open_port(envs.VLLM_RPC_PORT)
rpc_server_process = Process(target=run_rpc_server,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
port))
rpc_server_process.start()
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client = AsyncEngineRPCClient(port)
await async_engine_client.setup()
try:
yield async_engine_client
finally:
# Ensure rpc server process was terminated
rpc_server_process.terminate()
# Close all open connections to the backend
async_engine_client.close()
# Wait for server process to join
rpc_server_process.join()
router = APIRouter()
@@ -86,7 +146,7 @@ def mount_metrics(app: fastapi.FastAPI):
@router.get("/health")
async def health() -> Response:
"""Health check."""
await openai_serving_chat.engine.check_health()
await async_engine_client.check_health()
return Response(status_code=200)
@@ -215,8 +275,8 @@ def build_app(args):
async def build_server(
async_engine_client: AsyncEngineClient,
args,
llm_engine: Optional[AsyncLLMEngine] = None,
**uvicorn_kwargs,
) -> uvicorn.Server:
app = build_app(args)
@@ -226,14 +286,7 @@ async def build_server(
else:
served_model_names = [args.model]
global engine, engine_args
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = (llm_engine
if llm_engine is not None else AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER))
model_config = await engine.get_model_config()
model_config = await async_engine_client.get_model_config()
if args.disable_log_requests:
request_logger = None
@@ -246,7 +299,7 @@ async def build_server(
global openai_serving_tokenization
openai_serving_chat = OpenAIServingChat(
engine,
async_engine_client,
model_config,
served_model_names,
args.response_role,
@@ -257,7 +310,7 @@ async def build_server(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_completion = OpenAIServingCompletion(
engine,
async_engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
@@ -266,13 +319,13 @@ async def build_server(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
async_engine_client,
model_config,
served_model_names,
request_logger=request_logger,
)
openai_serving_tokenization = OpenAIServingTokenization(
engine,
async_engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
@@ -304,32 +357,39 @@ async def build_server(
return uvicorn.Server(config)
async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
server = await build_server(
args,
llm_engine,
**uvicorn_kwargs,
)
shutdown_task = None
async with build_async_engine_client(args) as async_engine_client:
loop = asyncio.get_running_loop()
server = await build_server(
async_engine_client,
args,
**uvicorn_kwargs,
)
server_task = loop.create_task(server.serve())
loop = asyncio.get_running_loop()
def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()
server_task = loop.create_task(server.serve())
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()
try:
await server_task
except asyncio.CancelledError:
print("Gracefully stopping http server")
await server.shutdown()
loop.add_signal_handler(signal.SIGINT, signal_handler)
loop.add_signal_handler(signal.SIGTERM, signal_handler)
try:
await server_task
except asyncio.CancelledError:
logger.info("Gracefully stopping http server")
shutdown_task = server.shutdown()
if shutdown_task:
# NB: Await server shutdown only after the backend context is exited
await shutdown_task
if __name__ == "__main__":