[Frontend] Add sagemaker_standards dynamic lora adapter and stateful session management decorators to vLLM OpenAI API server (#27892)

Signed-off-by: Zuyi Zhao <zhaozuy@amazon.com>
Signed-off-by: Shen Teng <sheteng@amazon.com>
Co-authored-by: Shen Teng <sheteng@amazon.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Author:   Zuyi Zhao
Date:     2025-11-10 20:57:01 -08:00 (committed by GitHub)
Parent:   8d706cca90
Commit:   bca74e32b7
Changed:  11 files, 1613 insertions(+), 83 deletions(-)
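In short: the SageMaker-specific /ping and /invocations endpoints and the dynamic LoRA load/unload endpoints move out of the monolithic api_server.py module into dedicated registration helpers (vllm.entrypoints.sagemaker.routes and vllm.entrypoints.dynamic_lora) that build_app() wires onto the router, and the finished FastAPI app is wrapped with sagemaker_standards.bootstrap(app).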


@@ -19,6 +19,7 @@ from contextlib import asynccontextmanager
 from http import HTTPStatus
 from typing import Annotated, Any, Literal
 
+import model_hosting_container_standards.sagemaker as sagemaker_standards
 import prometheus_client
 import pydantic
 import regex as re
@@ -65,7 +66,6 @@ from vllm.entrypoints.openai.protocol import (
     ErrorInfo,
     ErrorResponse,
     IOProcessorResponse,
-    LoadLoRAAdapterRequest,
     PoolingBytesResponse,
     PoolingRequest,
     PoolingResponse,
@@ -82,7 +82,6 @@ from vllm.entrypoints.openai.protocol import (
     TranscriptionResponse,
     TranslationRequest,
     TranslationResponse,
-    UnloadLoRAAdapterRequest,
 )
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_classification import ServingClassification
@@ -387,13 +386,6 @@ async def get_server_load_metrics(request: Request):
     return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
 
 
-@router.get("/ping", response_class=Response)
-@router.post("/ping", response_class=Response)
-async def ping(raw_request: Request) -> Response:
-    """Ping check. Endpoint required for SageMaker"""
-    return await health(raw_request)
-
-
 @router.post(
     "/tokenize",
     dependencies=[Depends(validate_json_request)],
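The /ping handler deleted here is re-registered later via register_sagemaker_routes (see the build_app hunk below). That module is not part of this excerpt, so the following is only a minimal sketch of its assumed shape; the real implementation presumably still delegates to the server's health check rather than answering unconditionally:

```python
# Hypothetical sketch of vllm/entrypoints/sagemaker/routes.py; the module is
# not shown in this diff, so the structure below is an assumption.
from fastapi import APIRouter, Request, Response


def register_sagemaker_routes(router: APIRouter) -> None:
    @router.get("/ping", response_class=Response)
    @router.post("/ping", response_class=Response)
    async def ping(raw_request: Request) -> Response:
        """Ping check. Endpoint required for SageMaker."""
        # The deleted api_server.py handler returned `await health(raw_request)`;
        # a self-contained sketch can only answer 200 directly.
        return Response(status_code=200)
```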
@@ -1236,47 +1228,6 @@ INVOCATION_VALIDATORS = [
 ]
 
 
-@router.post(
-    "/invocations",
-    dependencies=[Depends(validate_json_request)],
-    responses={
-        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
-        HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse},
-        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
-    },
-)
-async def invocations(raw_request: Request):
-    """For SageMaker, routes requests based on the request type."""
-    try:
-        body = await raw_request.json()
-    except json.JSONDecodeError as e:
-        raise HTTPException(
-            status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}"
-        ) from e
-
-    valid_endpoints = [
-        (validator, endpoint)
-        for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS
-        if get_handler(raw_request) is not None
-    ]
-
-    for request_validator, endpoint in valid_endpoints:
-        try:
-            request = request_validator.validate_python(body)
-        except pydantic.ValidationError:
-            continue
-
-        return await endpoint(request, raw_request)
-
-    type_names = [
-        t.__name__ if isinstance(t := validator._type, type) else str(t)
-        for validator, _ in valid_endpoints
-    ]
-    msg = f"Cannot find suitable handler for request. Expected one of: {type_names}"
-    res = base(raw_request).create_error_response(message=msg)
-    return JSONResponse(content=res.model_dump(), status_code=res.error.code)
-
-
 if envs.VLLM_TORCH_PROFILER_DIR:
     logger.warning_once(
         "Torch Profiler is enabled in the API server. This should ONLY be "
@@ -1304,39 +1255,6 @@ if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
         return Response(status_code=200)
 
 
-if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
-    logger.warning(
-        "LoRA dynamic loading & unloading is enabled in the API server. "
-        "This should ONLY be used for local development!"
-    )
-
-    @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)])
-    async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request):
-        handler = models(raw_request)
-        response = await handler.load_lora_adapter(request)
-        if isinstance(response, ErrorResponse):
-            return JSONResponse(
-                content=response.model_dump(), status_code=response.error.code
-            )
-
-        return Response(status_code=200, content=response)
-
-    @router.post(
-        "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)]
-    )
-    async def unload_lora_adapter(
-        request: UnloadLoRAAdapterRequest, raw_request: Request
-    ):
-        handler = models(raw_request)
-        response = await handler.unload_lora_adapter(request)
-        if isinstance(response, ErrorResponse):
-            return JSONResponse(
-                content=response.model_dump(), status_code=response.error.code
-            )
-
-        return Response(status_code=200, content=response)
-
-
 def load_log_config(log_config_file: str | None) -> dict | None:
     if not log_config_file:
         return None
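The two LoRA endpoints deleted above come back through register_dynamic_lora_routes from vllm.entrypoints.dynamic_lora, imported in build_app() in the next hunk. That module is not in this excerpt; a plausible sketch, reusing the deleted handler bodies:

```python
# Hypothetical sketch of vllm/entrypoints/dynamic_lora.py, reconstructed from
# the handlers deleted above; the real module is not part of this excerpt.
from fastapi import APIRouter, Depends, Request, Response
from fastapi.responses import JSONResponse

from vllm.entrypoints.openai.protocol import (
    ErrorResponse,
    LoadLoRAAdapterRequest,
    UnloadLoRAAdapterRequest,
)


def register_dynamic_lora_routes(router: APIRouter) -> None:
    # Imported lazily: api_server imports this helper inside build_app(), so a
    # module-level import here would be circular.
    from vllm.entrypoints.openai.api_server import models, validate_json_request

    @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)])
    async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request):
        handler = models(raw_request)
        response = await handler.load_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(
                content=response.model_dump(), status_code=response.error.code
            )
        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)])
    async def unload_lora_adapter(request: UnloadLoRAAdapterRequest, raw_request: Request):
        handler = models(raw_request)
        response = await handler.unload_lora_adapter(request)
        if isinstance(response, ErrorResponse):
            return JSONResponse(
                content=response.model_dump(), status_code=response.error.code
            )
        return Response(status_code=200, content=response)
```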
@@ -1606,6 +1524,20 @@ def build_app(args: Namespace) -> FastAPI:
         )
     else:
         app = FastAPI(lifespan=lifespan)
+
+    if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
+        logger.warning(
+            "LoRA dynamic loading & unloading is enabled in the API server. "
+            "This should ONLY be used for local development!"
+        )
+
+        from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes
+
+        register_dynamic_lora_routes(router)
+
+    from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
+
+    register_sagemaker_routes(router)
     app.include_router(router)
     app.root_path = args.root_path
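With VLLM_ALLOW_RUNTIME_LORA_UPDATING set, the relocated endpoints behave as before the move. A usage sketch (adapter name and path are placeholders; the request fields follow LoadLoRAAdapterRequest and UnloadLoRAAdapterRequest):

```python
# Usage sketch for the dynamic LoRA routes; the server must be started with
# VLLM_ALLOW_RUNTIME_LORA_UPDATING=1.
import requests

BASE = "http://localhost:8000"  # placeholder

load = requests.post(
    f"{BASE}/v1/load_lora_adapter",
    json={"lora_name": "my-adapter", "lora_path": "/adapters/my-adapter"},
)
print(load.status_code, load.text)

unload = requests.post(
    f"{BASE}/v1/unload_lora_adapter",
    json={"lora_name": "my-adapter"},
)
print(unload.status_code, unload.text)
```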
@@ -1696,6 +1628,8 @@ def build_app(args: Namespace) -> FastAPI:
                 f"Invalid middleware {middleware}. Must be a function or a class."
             )
 
+    app = sagemaker_standards.bootstrap(app)
+
     return app
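Note that bootstrap(app) runs after all middleware registration and immediately before build_app() returns, so the sagemaker_standards layer (which supplies the dynamic LoRA adapter and stateful session management decorators named in the title) wraps the fully configured application.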