[BugFix] Fix clean shutdown issues (#8492)
@@ -4,16 +4,20 @@ import inspect
 import multiprocessing
 import os
 import re
+import signal
 import tempfile
 from argparse import Namespace
 from contextlib import asynccontextmanager
+from functools import partial
 from http import HTTPStatus
 from typing import AsyncIterator, Optional, Set

+import uvloop
 from fastapi import APIRouter, FastAPI, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
+from starlette.datastructures import State
 from starlette.routing import Mount
 from typing_extensions import assert_never

@@ -54,12 +58,6 @@ from vllm.version import __version__ as VLLM_VERSION

 TIMEOUT_KEEP_ALIVE = 5  # seconds

-async_engine_client: AsyncEngineClient
-engine_args: AsyncEngineArgs
-openai_serving_chat: OpenAIServingChat
-openai_serving_completion: OpenAIServingCompletion
-openai_serving_embedding: OpenAIServingEmbedding
-openai_serving_tokenization: OpenAIServingTokenization
 prometheus_multiproc_dir: tempfile.TemporaryDirectory

 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
@@ -83,18 +81,28 @@ def model_is_embedding(model_name: str, trust_remote_code: bool,

 @asynccontextmanager
 async def lifespan(app: FastAPI):
+    try:
+        if app.state.log_stats:
+            async_engine_client = app.state.engine_client

-    async def _force_log():
-        while True:
-            await asyncio.sleep(10)
-            await async_engine_client.do_log_stats()
+            async def _force_log():
+                while True:
+                    await asyncio.sleep(10)
+                    await async_engine_client.do_log_stats()

-    if not engine_args.disable_log_stats:
-        task = asyncio.create_task(_force_log())
-        _running_tasks.add(task)
-        task.add_done_callback(_running_tasks.remove)
-
-    yield
+            task = asyncio.create_task(_force_log())
+            _running_tasks.add(task)
+            task.add_done_callback(_running_tasks.remove)
+        else:
+            task = None
+        try:
+            yield
+        finally:
+            if task is not None:
+                task.cancel()
+    finally:
+        # Ensure app state including engine ref is gc'd
+        del app.state


 @asynccontextmanager
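The reworked lifespan above starts the periodic stats task only when `app.state.log_stats` is set and always cancels it on shutdown. A minimal, self-contained sketch of the same pattern (illustrative only, not vLLM code; `print` stands in for `do_log_stats()`):

import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI

# Strong references so the background task is not garbage-collected mid-run.
_running_tasks: set = set()


@asynccontextmanager
async def lifespan(app: FastAPI):
    if getattr(app.state, "log_stats", False):

        async def _force_log():
            while True:
                await asyncio.sleep(10)
                print("stats tick")  # stand-in for do_log_stats()

        task = asyncio.create_task(_force_log())
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)
    else:
        task = None
    try:
        yield
    finally:
        # Runs on shutdown regardless of how the server exits.
        if task is not None:
            task.cancel()


app = FastAPI(lifespan=lifespan)
app.state.log_stats = True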
@@ -103,16 +111,10 @@ async def build_async_engine_client(

     # Context manager to handle async_engine_client lifecycle
     # Ensures everything is shutdown and cleaned up on error/exit
-    global engine_args
     engine_args = AsyncEngineArgs.from_cli_args(args)

-    # Backend itself still global for the silly lil' health handler
-    global async_engine_client
-
     async with build_async_engine_client_from_engine_args(
             engine_args, args.disable_frontend_multiprocessing) as engine:
-
-        async_engine_client = engine  # type: ignore[assignment]
         yield engine


@@ -134,12 +136,22 @@ async def build_async_engine_client_from_engine_args(
     if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
                            engine_args.quantization, engine_args.revision)
             or disable_frontend_multiprocessing):
-        engine_client = AsyncLLMEngine.from_engine_args(
-            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-        try:
-            yield engine_client
-        finally:
-            engine_client.shutdown_background_loop()
+        engine_config = engine_args.create_engine_config()
+        uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
+                           "uses_ray", False)
+
+        build_engine = partial(AsyncLLMEngine.from_engine_args,
+                               engine_args=engine_args,
+                               engine_config=engine_config,
+                               usage_context=UsageContext.OPENAI_API_SERVER)
+        if uses_ray:
+            # Must run in main thread with ray for its signal handlers to work
+            engine_client = build_engine()
+        else:
+            engine_client = await asyncio.get_running_loop().run_in_executor(
+                None, build_engine)
+
+        yield engine_client
         return

     # Otherwise, use the multiprocessing AsyncLLMEngine.
@@ -241,16 +253,36 @@ def mount_metrics(app: FastAPI):
     app.routes.append(metrics_route)


+def chat(request: Request) -> OpenAIServingChat:
+    return request.app.state.openai_serving_chat
+
+
+def completion(request: Request) -> OpenAIServingCompletion:
+    return request.app.state.openai_serving_completion
+
+
+def tokenization(request: Request) -> OpenAIServingTokenization:
+    return request.app.state.openai_serving_tokenization
+
+
+def embedding(request: Request) -> OpenAIServingEmbedding:
+    return request.app.state.openai_serving_embedding
+
+
+def engine_client(request: Request) -> AsyncEngineClient:
+    return request.app.state.engine_client
+
+
 @router.get("/health")
-async def health() -> Response:
+async def health(raw_request: Request) -> Response:
     """Health check."""
-    await async_engine_client.check_health()
+    await engine_client(raw_request).check_health()
     return Response(status_code=200)


 @router.post("/tokenize")
-async def tokenize(request: TokenizeRequest):
-    generator = await openai_serving_tokenization.create_tokenize(request)
+async def tokenize(request: TokenizeRequest, raw_request: Request):
+    generator = await tokenization(raw_request).create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
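These accessors replace the module-level serving globals: every handler now reaches its dependencies through `request.app.state`, so deleting `app.state` in `lifespan` drops the last server-side reference to the engine. A minimal sketch of the pattern outside vLLM (names here are illustrative):

from fastapi import FastAPI, Request, Response

app = FastAPI()
app.state.engine_client = object()  # stand-in for the real AsyncEngineClient


def engine_client(request: Request):
    # Resolve the dependency per request instead of via a module-level global.
    return request.app.state.engine_client


@app.get("/health")
async def health(raw_request: Request) -> Response:
    # The real handler awaits engine_client(raw_request).check_health() here.
    _ = engine_client(raw_request)
    return Response(status_code=200)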
@@ -261,8 +293,8 @@ async def tokenize(request: TokenizeRequest):


 @router.post("/detokenize")
-async def detokenize(request: DetokenizeRequest):
-    generator = await openai_serving_tokenization.create_detokenize(request)
+async def detokenize(request: DetokenizeRequest, raw_request: Request):
+    generator = await tokenization(raw_request).create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -273,8 +305,8 @@ async def detokenize(request: DetokenizeRequest):


 @router.get("/v1/models")
-async def show_available_models():
-    models = await openai_serving_completion.show_available_models()
+async def show_available_models(raw_request: Request):
+    models = await completion(raw_request).show_available_models()
     return JSONResponse(content=models.model_dump())


@@ -288,7 +320,7 @@ async def show_version():
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):

-    generator = await openai_serving_chat.create_chat_completion(
+    generator = await chat(raw_request).create_chat_completion(
         request, raw_request)

     if isinstance(generator, ErrorResponse):
@@ -303,7 +335,7 @@ async def create_chat_completion(request: ChatCompletionRequest,

 @router.post("/v1/completions")
 async def create_completion(request: CompletionRequest, raw_request: Request):
-    generator = await openai_serving_completion.create_completion(
+    generator = await completion(raw_request).create_completion(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -316,7 +348,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):

 @router.post("/v1/embeddings")
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
-    generator = await openai_serving_embedding.create_embedding(
+    generator = await embedding(raw_request).create_embedding(
         request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
@@ -333,16 +365,16 @@ if envs.VLLM_TORCH_PROFILER_DIR:
         "used for local development!")

     @router.post("/start_profile")
-    async def start_profile():
+    async def start_profile(raw_request: Request):
         logger.info("Starting profiler...")
-        await async_engine_client.start_profile()
+        await engine_client(raw_request).start_profile()
         logger.info("Profiler started.")
         return Response(status_code=200)

     @router.post("/stop_profile")
-    async def stop_profile():
+    async def stop_profile(raw_request: Request):
         logger.info("Stopping profiler...")
-        await async_engine_client.stop_profile()
+        await engine_client(raw_request).stop_profile()
         logger.info("Profiler stopped.")
         return Response(status_code=200)

@@ -353,13 +385,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
         "This should ONLY be used for local development!")

     @router.post("/v1/load_lora_adapter")
-    async def load_lora_adapter(request: LoadLoraAdapterRequest):
-        response = await openai_serving_chat.load_lora_adapter(request)
+    async def load_lora_adapter(request: LoadLoraAdapterRequest,
+                                raw_request: Request):
+        response = await chat(raw_request).load_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)

-        response = await openai_serving_completion.load_lora_adapter(request)
+        response = await completion(raw_request).load_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)
@@ -367,13 +400,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
         return Response(status_code=200, content=response)

     @router.post("/v1/unload_lora_adapter")
-    async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
-        response = await openai_serving_chat.unload_lora_adapter(request)
+    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
+                                  raw_request: Request):
+        response = await chat(raw_request).unload_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)

-        response = await openai_serving_completion.unload_lora_adapter(request)
+        response = await completion(raw_request).unload_lora_adapter(request)
         if isinstance(response, ErrorResponse):
             return JSONResponse(content=response.model_dump(),
                                 status_code=response.code)
@@ -398,7 +432,8 @@ def build_app(args: Namespace) -> FastAPI:

     @app.exception_handler(RequestValidationError)
     async def validation_exception_handler(_, exc):
-        err = openai_serving_chat.create_error_response(message=str(exc))
+        chat = app.state.openai_serving_chat
+        err = chat.create_error_response(message=str(exc))
         return JSONResponse(err.model_dump(),
                             status_code=HTTPStatus.BAD_REQUEST)

@@ -430,30 +465,26 @@ def build_app(args: Namespace) -> FastAPI:
     return app


-async def init_app(
+def init_app_state(
     async_engine_client: AsyncEngineClient,
+    model_config: ModelConfig,
+    state: State,
     args: Namespace,
-) -> FastAPI:
-    app = build_app(args)
-
+) -> None:
     if args.served_model_name is not None:
         served_model_names = args.served_model_name
     else:
         served_model_names = [args.model]

-    model_config = await async_engine_client.get_model_config()
-
     if args.disable_log_requests:
         request_logger = None
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)

-    global openai_serving_chat
-    global openai_serving_completion
-    global openai_serving_embedding
-    global openai_serving_tokenization
+    state.engine_client = async_engine_client
+    state.log_stats = not args.disable_log_stats

-    openai_serving_chat = OpenAIServingChat(
+    state.openai_serving_chat = OpenAIServingChat(
         async_engine_client,
         model_config,
         served_model_names,
@@ -465,7 +496,7 @@ async def init_app(
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser)
-    openai_serving_completion = OpenAIServingCompletion(
+    state.openai_serving_completion = OpenAIServingCompletion(
         async_engine_client,
         model_config,
         served_model_names,
@@ -474,13 +505,13 @@ async def init_app(
         request_logger=request_logger,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
     )
-    openai_serving_embedding = OpenAIServingEmbedding(
+    state.openai_serving_embedding = OpenAIServingEmbedding(
         async_engine_client,
         model_config,
         served_model_names,
         request_logger=request_logger,
     )
-    openai_serving_tokenization = OpenAIServingTokenization(
+    state.openai_serving_tokenization = OpenAIServingTokenization(
         async_engine_client,
         model_config,
         served_model_names,
@@ -488,25 +519,31 @@ async def init_app(
         request_logger=request_logger,
         chat_template=args.chat_template,
     )
-    app.root_path = args.root_path
-
-    return app


 async def run_server(args, **uvicorn_kwargs) -> None:
     logger.info("vLLM API server version %s", VLLM_VERSION)
     logger.info("args: %s", args)

+    def signal_handler(*_) -> None:
+        # Interrupt server on sigterm while initializing
+        raise KeyboardInterrupt("terminated")
+
+    signal.signal(signal.SIGTERM, signal_handler)
+
     async with build_async_engine_client(args) as async_engine_client:
         # If None, creation of the client failed and we exit.
         if async_engine_client is None:
             return

-        app = await init_app(async_engine_client, args)
+        app = build_app(args)
+
+        model_config = await async_engine_client.get_model_config()
+        init_app_state(async_engine_client, model_config, app.state, args)

         shutdown_task = await serve_http(
             app,
-            engine=async_engine_client,
+            limit_concurrency=async_engine_client.limit_concurrency,
             host=args.host,
             port=args.port,
             log_level=args.uvicorn_log_level,
@@ -530,4 +567,4 @@ if __name__ == "__main__":
     parser = make_arg_parser(parser)
     args = parser.parse_args()

-    asyncio.run(run_server(args))
+    uvloop.run(run_server(args))
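The final hunks change how the process shuts down: a SIGTERM received while the server is still initializing is converted into `KeyboardInterrupt`, and the server coroutine runs under `uvloop.run` rather than `asyncio.run`. A stripped-down sketch under those assumptions (the sleep stands in for engine construction and `serve_http`):

import asyncio
import signal

import uvloop


async def run_server() -> None:
    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    # Placeholder for build_async_engine_client() + serve_http(); a SIGTERM
    # here is raised as KeyboardInterrupt and unwinds the context managers.
    await asyncio.sleep(60)


if __name__ == "__main__":
    uvloop.run(run_server())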