[Misc] Minimum requirements for SageMaker compatibility (#11576)
This commit is contained in:
@@ -16,7 +16,7 @@ from http import HTTPStatus
|
||||
from typing import AsyncIterator, Optional, Set, Tuple
|
||||
|
||||
import uvloop
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi import APIRouter, FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
@@ -44,11 +44,15 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
EmbeddingResponseData,
|
||||
ErrorResponse,
|
||||
LoadLoraAdapterRequest,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
ScoreRequest, ScoreResponse,
|
||||
TokenizeRequest,
|
||||
@@ -310,6 +314,12 @@ async def health(raw_request: Request) -> Response:
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.api_route("/ping", methods=["GET", "POST"])
async def ping(raw_request: Request) -> Response:
    """SageMaker-required ping endpoint; delegates to the health check."""
    health_response = await health(raw_request)
    return health_response
|
||||
|
||||
|
||||
@router.post("/tokenize")
|
||||
@with_cancellation
|
||||
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
@@ -483,6 +493,54 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
||||
return await create_score(request, raw_request)
|
||||
|
||||
|
||||
# Dispatch table used by the `/invocations` endpoint.
# Outer key: the model task. Inner key: "messages" for bodies that carry a
# "messages" field, "default" for everything else. Each value is a
# (request model, handler coroutine) pair.
TASK_HANDLERS = {
    # Text generation: chat-style vs. plain completion requests.
    "generate": {
        "messages": (ChatCompletionRequest, create_chat_completion),
        "default": (CompletionRequest, create_completion),
    },
    # Embeddings: both request shapes go through the same handler.
    "embed": {
        "messages": (EmbeddingChatRequest, create_embedding),
        "default": (EmbeddingCompletionRequest, create_embedding),
    },
    # Scoring has no "messages" variant.
    "score": {
        "default": (ScoreRequest, create_score),
    },
    # Reward and classification are both served by the pooling handler.
    "reward": {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    },
    "classify": {
        "messages": (PoolingChatRequest, create_pooling),
        "default": (PoolingCompletionRequest, create_pooling),
    },
}
|
||||
|
||||
|
||||
@router.post("/invocations")
async def invocations(raw_request: Request):
    """
    For SageMaker, routes requests to other handlers based on model `task`.

    The JSON body is validated against the request model registered in
    `TASK_HANDLERS` for the server's task. A body containing a "messages"
    key uses the task's "messages" entry when one is registered; every
    other body (and any task without a "messages" entry, e.g. "score")
    uses the "default" entry.

    Raises:
        HTTPException: status 400 when the body is not valid JSON, is not
            a JSON object, the task is not supported by this endpoint, or
            the body fails validation against the selected request model.
    """
    try:
        body = await raw_request.json()
    except ValueError as e:
        # json.JSONDecodeError (and UnicodeDecodeError) subclass ValueError;
        # report malformed payloads as a client error instead of a 500.
        raise HTTPException(status_code=400,
                            detail=f"JSON decode error: {e}") from e

    task = raw_request.app.state.task

    if task not in TASK_HANDLERS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported task: '{task}' for '/invocations'. "
            f"Expected one of {set(TASK_HANDLERS.keys())}")

    if not isinstance(body, dict):
        # Request models are built from JSON objects; a bare list, string
        # or number cannot be routed or validated.
        raise HTTPException(status_code=400,
                            detail="Request body must be a JSON object")

    handler_config = TASK_HANDLERS[task]
    # Not every task registers a "messages" variant; fall back to the
    # default pair rather than failing with a KeyError.
    if "messages" in body and "messages" in handler_config:
        request_model, handler = handler_config["messages"]
    else:
        request_model, handler = handler_config["default"]

    # this is required since we lose the FastAPI automatic casting
    try:
        request = request_model.model_validate(body)
    except ValueError as e:
        # pydantic's ValidationError is a ValueError subclass.
        raise HTTPException(status_code=400, detail=str(e)) from e

    return await handler(request, raw_request)
|
||||
|
||||
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
logger.warning(
|
||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||
@@ -687,6 +745,7 @@ def init_app_state(
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
)
|
||||
state.task = model_config.task
|
||||
|
||||
|
||||
def create_server_socket(addr: Tuple[str, int]) -> socket.socket:
|
||||
|
||||
Reference in New Issue
Block a user