[Frontend] Add /classify endpoint (#17032)

Signed-off-by: Frieda (Jingying) Huang <jingyingfhuang@gmail.com>
This commit is contained in:
Frieda Huang
2025-05-11 03:57:07 -04:00
committed by GitHub
parent d1110f5b5a
commit 9cea90eab4
9 changed files with 972 additions and 173 deletions

View File

@@ -48,6 +48,8 @@ from vllm.entrypoints.openai.cli_args import (log_non_default_args,
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
ClassificationRequest,
ClassificationResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
@@ -71,6 +73,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
UnloadLoRAAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_classification import (
ServingClassification)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import OpenAIServing
@@ -373,6 +377,10 @@ def score(request: Request) -> Optional[ServingScores]:
return request.app.state.openai_serving_scores
def classify(request: Request) -> Optional[ServingClassification]:
    """Fetch the classification handler stored on app state.

    Returns None when the server was started for a model whose task is
    not "classify" (init_app_state leaves the slot as None in that case).
    """
    app_state = request.app.state
    return app_state.openai_serving_classification
def rerank(request: Request) -> Optional[ServingScores]:
    """Fetch the scoring handler used to serve rerank requests.

    Rerank shares the same serving object as /score; None means the
    model does not support scoring.
    """
    app_state = request.app.state
    return app_state.openai_serving_scores
@@ -405,6 +413,7 @@ async def get_server_load_metrics(request: Request):
# - /v1/audio/transcriptions
# - /v1/embeddings
# - /pooling
# - /classify
# - /score
# - /v1/score
# - /rerank
@@ -572,6 +581,27 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
assert_never(generator)
@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest,
                          raw_request: Request):
    """Serve POST /classify by delegating to the classification handler.

    Returns a JSON error when the model's task is not "classify",
    otherwise forwards the request and serializes the handler's result.
    """
    classify_handler = classify(raw_request)
    if classify_handler is None:
        # No classification handler was initialized for this model;
        # surface the failure through the fallback serving object.
        return base(raw_request).create_error_response(
            message="The model does not support Classification API")

    result = await classify_handler.create_classify(request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, ClassificationResponse):
        return JSONResponse(content=result.model_dump())

    # Exhaustiveness guard: any other type is a programming error.
    assert_never(result)
@router.post("/score", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
@@ -1001,6 +1031,12 @@ async def init_app_state(
state.openai_serving_models,
request_logger=request_logger) if model_config.task in (
"score", "embed", "pooling") else None
state.openai_serving_classification = ServingClassification(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if model_config.task == "classify" else None
state.jinaai_serving_reranking = ServingScores(
engine_client,
model_config,