[Frontend] Pass API server count to each process (#23717)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -17,13 +17,14 @@ from argparse import Namespace
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Any, Callable, Optional
|
||||
from typing import Annotated, Any, Callable, Literal, Optional
|
||||
|
||||
import prometheus_client
|
||||
import pydantic
|
||||
import regex as re
|
||||
import uvloop
|
||||
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
|
||||
from fastapi import (APIRouter, Depends, FastAPI, Form, HTTPException, Query,
|
||||
Request)
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
@@ -166,6 +167,9 @@ async def build_async_engine_client(
|
||||
# Context manager to handle engine_client lifecycle
|
||||
# Ensures everything is shut down and cleaned up on error/exit
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
if client_config:
|
||||
engine_args._api_process_count = client_config.get("client_count", 1)
|
||||
engine_args._api_process_rank = client_config.get("client_index", 0)
|
||||
|
||||
if disable_frontend_multiprocessing is None:
|
||||
disable_frontend_multiprocessing = bool(
|
||||
@@ -209,8 +213,12 @@ async def build_async_engine_client_from_engine_args(
|
||||
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
async_llm: Optional[AsyncLLM] = None
|
||||
client_count = client_config.pop("client_count") if client_config else 1
|
||||
client_index = client_config.pop("client_index") if client_config else 0
|
||||
|
||||
# Don't mutate the input client_config
|
||||
client_config = dict(client_config) if client_config else {}
|
||||
client_count = client_config.pop("client_count", 1)
|
||||
client_index = client_config.pop("client_index", 0)
|
||||
|
||||
try:
|
||||
async_llm = AsyncLLM.from_vllm_config(
|
||||
vllm_config=vllm_config,
|
||||
@@ -956,9 +964,22 @@ if envs.VLLM_SERVER_DEV_MODE:
|
||||
logger.warning("SECURITY WARNING: Development endpoints are enabled! "
|
||||
"This should NOT be used in production!")
|
||||
|
||||
PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)
|
||||
|
||||
@router.get("/server_info")
|
||||
async def show_server_info(raw_request: Request):
|
||||
server_info = {"vllm_config": str(raw_request.app.state.vllm_config)}
|
||||
async def show_server_info(
|
||||
raw_request: Request,
|
||||
config_format: Annotated[Literal["text", "json"],
|
||||
Query()] = "text",
|
||||
):
|
||||
vllm_config: VllmConfig = raw_request.app.state.vllm_config
|
||||
server_info = {
|
||||
"vllm_config":
|
||||
str(vllm_config)
|
||||
if config_format == "text" else PydanticVllmConfig.dump_python(
|
||||
vllm_config, mode="json", fallback=str)
|
||||
# fallback=str is needed to handle e.g. torch.dtype
|
||||
}
|
||||
return JSONResponse(content=server_info)
|
||||
|
||||
@router.post("/reset_prefix_cache")
|
||||
@@ -1856,8 +1877,6 @@ async def run_server_worker(listen_address,
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
server_index = client_config.get("client_index", 0) if client_config else 0
|
||||
|
||||
# Load logging config for uvicorn if specified
|
||||
log_config = load_log_config(args.log_config_file)
|
||||
if log_config is not None:
|
||||
@@ -1873,7 +1892,8 @@ async def run_server_worker(listen_address,
|
||||
vllm_config = await engine_client.get_vllm_config()
|
||||
await init_app_state(engine_client, vllm_config, app.state, args)
|
||||
|
||||
logger.info("Starting vLLM API server %d on %s", server_index,
|
||||
logger.info("Starting vLLM API server %d on %s",
|
||||
vllm_config.parallel_config._api_process_rank,
|
||||
listen_address)
|
||||
shutdown_task = await serve_http(
|
||||
app,
|
||||
|
||||
Reference in New Issue
Block a user