[Frontend] Add sagemaker_standards dynamic lora adapter and stateful session management decorators to vLLM OpenAI API server (#27892)
Signed-off-by: Zuyi Zhao <zhaozuy@amazon.com> Signed-off-by: Shen Teng <sheteng@amazon.com> Co-authored-by: Shen Teng <sheteng@amazon.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,7 @@ from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||
import prometheus_client
|
||||
import pydantic
|
||||
import regex as re
|
||||
@@ -65,7 +66,6 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
IOProcessorResponse,
|
||||
LoadLoRAAdapterRequest,
|
||||
PoolingBytesResponse,
|
||||
PoolingRequest,
|
||||
PoolingResponse,
|
||||
@@ -82,7 +82,6 @@ from vllm.entrypoints.openai.protocol import (
|
||||
TranscriptionResponse,
|
||||
TranslationRequest,
|
||||
TranslationResponse,
|
||||
UnloadLoRAAdapterRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_classification import ServingClassification
|
||||
@@ -387,13 +386,6 @@ async def get_server_load_metrics(request: Request):
|
||||
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
|
||||
|
||||
|
||||
@router.get("/ping", response_class=Response)
@router.post("/ping", response_class=Response)
async def ping(raw_request: Request) -> Response:
    """Ping check. Endpoint required for SageMaker"""
    # SageMaker probes /ping; we simply delegate to the standard
    # health-check handler and forward its response unchanged.
    health_response = await health(raw_request)
    return health_response
|
||||
|
||||
|
||||
@router.post(
|
||||
"/tokenize",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
@@ -1236,47 +1228,6 @@ INVOCATION_VALIDATORS = [
|
||||
]
|
||||
|
||||
|
||||
@router.post(
    "/invocations",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
async def invocations(raw_request: Request):
    """For SageMaker, routes requests based on the request type."""
    # Parse the JSON payload up front; a malformed body becomes a 400.
    try:
        payload = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}"
        ) from e

    # Keep only the (validator, endpoint) pairs whose backing handler is
    # actually enabled on this server instance.
    candidates = []
    for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS:
        if get_handler(raw_request) is not None:
            candidates.append((validator, endpoint))

    # Dispatch to the first endpoint whose validator accepts the payload;
    # validation failures just mean "try the next request type".
    for validator, endpoint in candidates:
        try:
            parsed = validator.validate_python(payload)
        except pydantic.ValidationError:
            continue
        return await endpoint(parsed, raw_request)

    # Nothing matched: report which request types would have been accepted.
    type_names = []
    for validator, _ in candidates:
        t = validator._type
        type_names.append(t.__name__ if isinstance(t, type) else str(t))
    msg = f"Cannot find suitable handler for request. Expected one of: {type_names}"
    res = base(raw_request).create_error_response(message=msg)
    return JSONResponse(content=res.model_dump(), status_code=res.error.code)
|
||||
|
||||
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
logger.warning_once(
|
||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||
@@ -1304,39 +1255,6 @@ if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
logger.warning(
|
||||
"LoRA dynamic loading & unloading is enabled in the API server. "
|
||||
"This should ONLY be used for local development!"
|
||||
)
|
||||
|
||||
@router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)])
async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request):
    """Dynamically load a LoRA adapter via the serving-models handler."""
    result = await models(raw_request).load_lora_adapter(request)
    # Errors carry their own status code; success returns the handler's
    # confirmation payload with a plain 200.
    if not isinstance(result, ErrorResponse):
        return Response(status_code=200, content=result)
    return JSONResponse(content=result.model_dump(), status_code=result.error.code)
|
||||
|
||||
@router.post(
    "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)]
)
async def unload_lora_adapter(
    request: UnloadLoRAAdapterRequest, raw_request: Request
):
    """Dynamically unload a LoRA adapter via the serving-models handler."""
    result = await models(raw_request).unload_lora_adapter(request)
    # Errors carry their own status code; success returns the handler's
    # confirmation payload with a plain 200.
    if not isinstance(result, ErrorResponse):
        return Response(status_code=200, content=result)
    return JSONResponse(content=result.model_dump(), status_code=result.error.code)
|
||||
|
||||
|
||||
def load_log_config(log_config_file: str | None) -> dict | None:
|
||||
if not log_config_file:
|
||||
return None
|
||||
@@ -1606,6 +1524,20 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
)
|
||||
else:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
logger.warning(
|
||||
"LoRA dynamic loading & unloading is enabled in the API server. "
|
||||
"This should ONLY be used for local development!"
|
||||
)
|
||||
from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes
|
||||
|
||||
register_dynamic_lora_routes(router)
|
||||
|
||||
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
|
||||
|
||||
register_sagemaker_routes(router)
|
||||
|
||||
app.include_router(router)
|
||||
app.root_path = args.root_path
|
||||
|
||||
@@ -1696,6 +1628,8 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
f"Invalid middleware {middleware}. Must be a function or a class."
|
||||
)
|
||||
|
||||
app = sagemaker_standards.bootstrap(app)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user