[Frontend] Factor out code for running uvicorn (#6828)

Author: Cyrus Leung
Date: 2024-07-27 09:58:25 +08:00
Committed by: GitHub
Parent: d09b94ca58
Commit: 981b0d5673
4 changed files with 116 additions and 75 deletions
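
The helper this commit factors out, serve_http, lives in the new vllm.server package, whose diff is not shown in this view. The following is a hedged sketch only: it reconstructs what serve_http plausibly consolidates, based on the route logging and graceful-shutdown code that the diff below removes from api_server.py. The actual signature and internals in the commit may differ.

# Hedged sketch of serve_http, reconstructed from the code this diff
# removes from api_server.py; the real vllm.server implementation may
# differ in signature and details.
import asyncio
import signal

import uvicorn
from fastapi import FastAPI

from vllm.logger import init_logger

logger = init_logger(__name__)


async def serve_http(app: FastAPI, **uvicorn_kwargs) -> None:
    # Route logging, moved here from the old build_server.
    logger.info("Available routes are:")
    for route in app.routes:
        if not hasattr(route, 'methods'):
            continue
        methods = ', '.join(route.methods)
        logger.info("Route: %s, Methods: %s", route.path, methods)

    config = uvicorn.Config(app, **uvicorn_kwargs)
    server = uvicorn.Server(config)

    loop = asyncio.get_running_loop()
    server_task = loop.create_task(server.serve())

    def signal_handler() -> None:
        # Cancel the task instead of letting uvicorn's own signal
        # handler exit early.
        server_task.cancel()

    loop.add_signal_handler(signal.SIGINT, signal_handler)
    loop.add_signal_handler(signal.SIGTERM, signal_handler)

    try:
        await server_task
    except asyncio.CancelledError:
        print("Gracefully stopping http server")
        await server.shutdown()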

vllm/entrypoints/openai/api_server.py

@@ -2,14 +2,12 @@ import asyncio
 import importlib
 import inspect
 import re
-import signal
+from argparse import Namespace
 from contextlib import asynccontextmanager
 from http import HTTPStatus
-from typing import Optional, Set
+from typing import Any, Optional, Set
 
-import fastapi
-import uvicorn
-from fastapi import APIRouter, Request
+from fastapi import APIRouter, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -38,6 +36,7 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.entrypoints.openai.serving_tokenization import (
     OpenAIServingTokenization)
 from vllm.logger import init_logger
+from vllm.server import serve_http
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser
 from vllm.version import __version__ as VLLM_VERSION
@@ -57,7 +56,7 @@ _running_tasks: Set[asyncio.Task] = set()
 
 
 @asynccontextmanager
-async def lifespan(app: fastapi.FastAPI):
+async def lifespan(app: FastAPI):
 
     async def _force_log():
         while True:
@@ -75,7 +74,7 @@ async def lifespan(app: fastapi.FastAPI):
 router = APIRouter()
 
 
-def mount_metrics(app: fastapi.FastAPI):
+def mount_metrics(app: FastAPI):
     # Add prometheus asgi middleware to route /metrics requests
     metrics_route = Mount("/metrics", make_asgi_app())
     # Workaround for 307 Redirect for /metrics
@@ -165,8 +164,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
         return JSONResponse(content=generator.model_dump())
 
 
-def build_app(args):
-    app = fastapi.FastAPI(lifespan=lifespan)
+def build_app(args: Namespace) -> FastAPI:
+    app = FastAPI(lifespan=lifespan)
     app.include_router(router)
     app.root_path = args.root_path
 
@@ -214,11 +213,8 @@ def build_app(args):
     return app
 
 
-async def build_server(
-    args,
-    llm_engine: Optional[AsyncLLMEngine] = None,
-    **uvicorn_kwargs,
-) -> uvicorn.Server:
+async def init_app(args: Namespace,
+                   llm_engine: Optional[AsyncLLMEngine] = None) -> FastAPI:
     app = build_app(args)
 
     if args.served_model_name is not None:
@@ -281,14 +277,17 @@ async def build_server(
     )
     app.root_path = args.root_path
 
-    logger.info("Available routes are:")
-    for route in app.routes:
-        if not hasattr(route, 'methods'):
-            continue
-        methods = ', '.join(route.methods)
-        logger.info("Route: %s, Methods: %s", route.path, methods)
+    return app
 
-    config = uvicorn.Config(
+
+async def run_server(args: Namespace,
+                     llm_engine: Optional[AsyncLLMEngine] = None,
+                     **uvicorn_kwargs: Any) -> None:
+    logger.info("vLLM API server version %s", VLLM_VERSION)
+    logger.info("args: %s", args)
+
+    app = await init_app(args, llm_engine)
+    await serve_http(
         app,
         host=args.host,
         port=args.port,
@@ -301,36 +300,6 @@ async def build_server(
         **uvicorn_kwargs,
     )
 
-    return uvicorn.Server(config)
-
-
-async def run_server(args, llm_engine=None, **uvicorn_kwargs) -> None:
-    logger.info("vLLM API server version %s", VLLM_VERSION)
-    logger.info("args: %s", args)
-
-    server = await build_server(
-        args,
-        llm_engine,
-        **uvicorn_kwargs,
-    )
-
-    loop = asyncio.get_running_loop()
-
-    server_task = loop.create_task(server.serve())
-
-    def signal_handler() -> None:
-        # prevents the uvicorn signal handler to exit early
-        server_task.cancel()
-
-    loop.add_signal_handler(signal.SIGINT, signal_handler)
-    loop.add_signal_handler(signal.SIGTERM, signal_handler)
-
-    try:
-        await server_task
-    except asyncio.CancelledError:
-        print("Gracefully stopping http server")
-        await server.shutdown()
-
 
 if __name__ == "__main__":
     # NOTE(simon):
@@ -339,4 +308,5 @@ if __name__ == "__main__":
         description="vLLM OpenAI-Compatible RESTful API server.")
     parser = make_arg_parser(parser)
     args = parser.parse_args()
+
     asyncio.run(run_server(args))
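
With run_server now delegating to serve_http, the server can also be launched in-process rather than via the CLI. A hedged usage sketch follows: the make_arg_parser import path is assumed rather than shown in this diff, and the flag values are illustrative only.

# Hedged usage sketch: starting the OpenAI-compatible server
# programmatically after this refactor. The cli_args import path is an
# assumption; model and port values are illustrative.
import asyncio

from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser

parser = make_arg_parser(
    FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server."))
args = parser.parse_args(["--model", "facebook/opt-125m", "--port", "8000"])

# Extra keyword arguments are forwarded through run_server and serve_http
# to uvicorn.Config, e.g. disabling the access log here.
asyncio.run(run_server(args, access_log=False))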