[Refactor] [4/N] Move VLLM_SERVER_DEV endpoints into the serve directory (#30749)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -17,21 +17,20 @@ from argparse import Namespace
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Any, Literal
|
||||
from typing import Annotated, Any
|
||||
|
||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||
import pydantic
|
||||
import uvloop
|
||||
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request
|
||||
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
from starlette.datastructures import URL, Headers, MutableHeaders, State
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.anthropic.protocol import (
|
||||
@@ -639,97 +638,6 @@ async def create_translations(
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
if envs.VLLM_SERVER_DEV_MODE:
|
||||
logger.warning(
|
||||
"SECURITY WARNING: Development endpoints are enabled! "
|
||||
"This should NOT be used in production!"
|
||||
)
|
||||
|
||||
PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)
|
||||
|
||||
@router.get("/server_info")
|
||||
async def show_server_info(
|
||||
raw_request: Request,
|
||||
config_format: Annotated[Literal["text", "json"], Query()] = "text",
|
||||
):
|
||||
vllm_config: VllmConfig = raw_request.app.state.vllm_config
|
||||
server_info = {
|
||||
"vllm_config": str(vllm_config)
|
||||
if config_format == "text"
|
||||
else PydanticVllmConfig.dump_python(vllm_config, mode="json", fallback=str)
|
||||
# fallback=str is needed to handle e.g. torch.dtype
|
||||
}
|
||||
return JSONResponse(content=server_info)
|
||||
|
||||
@router.post("/reset_prefix_cache")
|
||||
async def reset_prefix_cache(
|
||||
raw_request: Request,
|
||||
reset_running_requests: bool = Query(default=False),
|
||||
reset_external: bool = Query(default=False),
|
||||
):
|
||||
"""
|
||||
Reset the local prefix cache.
|
||||
|
||||
Optionally, if the query parameter `reset_external=true`
|
||||
also resets the external (connector-managed) prefix cache.
|
||||
|
||||
Note that we currently do not check if the prefix cache
|
||||
is successfully reset in the API server.
|
||||
|
||||
Example:
|
||||
POST /reset_prefix_cache?reset_external=true
|
||||
"""
|
||||
logger.info("Resetting prefix cache...")
|
||||
|
||||
await engine_client(raw_request).reset_prefix_cache(
|
||||
reset_running_requests, reset_external
|
||||
)
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/reset_mm_cache")
|
||||
async def reset_mm_cache(raw_request: Request):
|
||||
"""
|
||||
Reset the multi-modal cache. Note that we currently do not check if the
|
||||
multi-modal cache is successfully reset in the API server.
|
||||
"""
|
||||
logger.info("Resetting multi-modal cache...")
|
||||
await engine_client(raw_request).reset_mm_cache()
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/collective_rpc")
|
||||
async def collective_rpc(raw_request: Request):
|
||||
try:
|
||||
body = await raw_request.json()
|
||||
except json.JSONDecodeError as e:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail=f"JSON decode error: {e}",
|
||||
) from e
|
||||
method = body.get("method")
|
||||
if method is None:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.BAD_REQUEST.value,
|
||||
detail="Missing 'method' in request body",
|
||||
)
|
||||
# For security reason, only serialized string args/kwargs are passed.
|
||||
# User-defined `method` is responsible for deserialization if needed.
|
||||
args: list[str] = body.get("args", [])
|
||||
kwargs: dict[str, str] = body.get("kwargs", {})
|
||||
timeout: float | None = body.get("timeout")
|
||||
results = await engine_client(raw_request).collective_rpc(
|
||||
method=method, timeout=timeout, args=tuple(args), kwargs=kwargs
|
||||
)
|
||||
if results is None:
|
||||
return Response(status_code=200)
|
||||
response: list[Any] = []
|
||||
for result in results:
|
||||
if result is None or isinstance(result, dict | list):
|
||||
response.append(result)
|
||||
else:
|
||||
response.append(str(result))
|
||||
return JSONResponse(content={"results": response})
|
||||
|
||||
|
||||
def load_log_config(log_config_file: str | None) -> dict | None:
|
||||
if not log_config_file:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user