[Frontend] Pass API server count to each process (#23717)

commit 6c117cff7d
parent 7ac67ea525
Author: Cyrus Leung
Committer: GitHub
Date: 2025-09-20 01:15:19 +08:00
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

12 changed files with 221 additions and 51 deletions

vllm/entrypoints/openai/api_server.py

@@ -17,13 +17,14 @@ from argparse import Namespace
 from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
 from contextlib import asynccontextmanager
 from http import HTTPStatus
-from typing import Annotated, Any, Callable, Optional
+from typing import Annotated, Any, Callable, Literal, Optional

 import prometheus_client
 import pydantic
 import regex as re
 import uvloop
-from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
+from fastapi import (APIRouter, Depends, FastAPI, Form, HTTPException, Query,
+                     Request)
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, Response, StreamingResponse
@@ -166,6 +167,9 @@ async def build_async_engine_client(
     # Context manager to handle engine_client lifecycle
     # Ensures everything is shutdown and cleaned up on error/exit
     engine_args = AsyncEngineArgs.from_cli_args(args)
+    if client_config:
+        engine_args._api_process_count = client_config.get("client_count", 1)
+        engine_args._api_process_rank = client_config.get("client_index", 0)

     if disable_frontend_multiprocessing is None:
         disable_frontend_multiprocessing = bool(
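Context for the hunk above: `client_config` is the per-process dict the multi-process frontend hands to each API server, and its two keys are copied onto private `AsyncEngineArgs` fields so they later reach `VllmConfig`. A minimal sketch of how a launcher might build one such dict per worker (`make_client_configs` is a hypothetical helper for illustration, not vLLM code):

```python
# Hypothetical helper, not vLLM code: build one client_config per API server
# worker, using the same keys the hunk above reads.
def make_client_configs(api_server_count: int) -> list[dict]:
    return [{
        "client_count": api_server_count,  # -> engine_args._api_process_count
        "client_index": rank,              # -> engine_args._api_process_rank
    } for rank in range(api_server_count)]

for cfg in make_client_configs(4):
    print(cfg)
# {'client_count': 4, 'client_index': 0} ... {'client_count': 4, 'client_index': 3}
```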
@@ -209,8 +213,12 @@ async def build_async_engine_client_from_engine_args(
         from vllm.v1.engine.async_llm import AsyncLLM
         async_llm: Optional[AsyncLLM] = None
-        client_count = client_config.pop("client_count") if client_config else 1
-        client_index = client_config.pop("client_index") if client_config else 0
+
+        # Don't mutate the input client_config
+        client_config = dict(client_config) if client_config else {}
+        client_count = client_config.pop("client_count", 1)
+        client_index = client_config.pop("client_index", 0)
+
         try:
             async_llm = AsyncLLM.from_vllm_config(
                 vllm_config=vllm_config,
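The old lines popped keys straight out of the caller's dict, so whoever passed `client_config` in silently lost `client_count` and `client_index` after the first call. A standalone repro of that pitfall, and the copy-then-pop pattern the new code uses:

```python
# The pitfall: dict.pop mutates the caller's dict.
client_config = {"client_count": 4, "client_index": 2}
count = client_config.pop("client_count") if client_config else 1
print(client_config)  # {'client_index': 2} -- key is gone for later consumers

# The fix: copy first, then pop with defaults so absent keys are harmless.
client_config = {"client_count": 4, "client_index": 2}
local_config = dict(client_config)
count = local_config.pop("client_count", 1)
index = local_config.pop("client_index", 0)
print(client_config)  # unchanged
```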
@@ -956,9 +964,22 @@ if envs.VLLM_SERVER_DEV_MODE:
     logger.warning("SECURITY WARNING: Development endpoints are enabled! "
                    "This should NOT be used in production!")

+    PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)
+
     @router.get("/server_info")
-    async def show_server_info(raw_request: Request):
-        server_info = {"vllm_config": str(raw_request.app.state.vllm_config)}
+    async def show_server_info(
+        raw_request: Request,
+        config_format: Annotated[Literal["text", "json"],
+                                 Query()] = "text",
+    ):
+        vllm_config: VllmConfig = raw_request.app.state.vllm_config
+        server_info = {
+            "vllm_config":
+            str(vllm_config)
+            if config_format == "text" else PydanticVllmConfig.dump_python(
+                vllm_config, mode="json", fallback=str)
+            # fallback=str is needed to handle e.g. torch.dtype
+        }
         return JSONResponse(content=server_info)

     @router.post("/reset_prefix_cache")
@@ -1856,8 +1877,6 @@ async def run_server_worker(listen_address,
     if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
         ToolParserManager.import_tool_parser(args.tool_parser_plugin)

-    server_index = client_config.get("client_index", 0) if client_config else 0
-
     # Load logging config for uvicorn if specified
     log_config = load_log_config(args.log_config_file)
     if log_config is not None:
@@ -1873,7 +1892,8 @@ async def run_server_worker(listen_address,
     vllm_config = await engine_client.get_vllm_config()
     await init_app_state(engine_client, vllm_config, app.state, args)

-    logger.info("Starting vLLM API server %d on %s", server_index,
+    logger.info("Starting vLLM API server %d on %s",
+                vllm_config.parallel_config._api_process_rank,
                 listen_address)
     shutdown_task = await serve_http(
         app,
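Net effect of the last two hunks: `run_server_worker` no longer threads a separate `server_index` through, because the rank now travels inside `VllmConfig` (via the `AsyncEngineArgs` fields set earlier), so any component holding the config can identify its API process. A toy illustration with stand-in objects (in vLLM the fields are populated by the engine, not constructed like this):

```python
from types import SimpleNamespace

# Stand-in for a VllmConfig whose parallel_config carries the new fields.
vllm_config = SimpleNamespace(parallel_config=SimpleNamespace(
    _api_process_rank=2, _api_process_count=4))

print("Starting vLLM API server %d on %s" %
      (vllm_config.parallel_config._api_process_rank, "http://0.0.0.0:8000"))
```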