[Frontend] Add /classify endpoint (#17032)
Signed-off-by: Frieda (Jingying) Huang <jingyingfhuang@gmail.com>
@@ -48,6 +48,8 @@ from vllm.entrypoints.openai.cli_args import (log_non_default_args,
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ChatCompletionResponse,
+                                              ClassificationRequest,
+                                              ClassificationResponse,
                                               CompletionRequest,
                                               CompletionResponse,
                                               DetokenizeRequest,
@@ -71,6 +73,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               UnloadLoRAAdapterRequest)
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_classification import (
+    ServingClassification)
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
@@ -373,6 +377,10 @@ def score(request: Request) -> Optional[ServingScores]:
     return request.app.state.openai_serving_scores


+def classify(request: Request) -> Optional[ServingClassification]:
+    return request.app.state.openai_serving_classification
+
+
 def rerank(request: Request) -> Optional[ServingScores]:
     return request.app.state.openai_serving_scores

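
For context, a self-contained sketch of the app.state handler-lookup pattern the new classify() accessor follows (illustrative only; names mirror the diff, but this is not vLLM code):

from typing import Optional

from fastapi import FastAPI, Request

app = FastAPI()
# Populated at startup iff the served model's task is "classify"; None otherwise.
app.state.openai_serving_classification = None


def classify(request: Request) -> Optional[object]:
    # Resolving the handler per request from app.state lets every route be
    # registered unconditionally; an unsupported route sees None and can
    # return a structured error instead of a bare 404.
    return request.app.state.openai_serving_classification
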
@@ -405,6 +413,7 @@ async def get_server_load_metrics(request: Request):
 # - /v1/audio/transcriptions
 # - /v1/embeddings
 # - /pooling
+# - /classify
 # - /score
 # - /v1/score
 # - /rerank
@@ -572,6 +581,27 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
     assert_never(generator)


+@router.post("/classify", dependencies=[Depends(validate_json_request)])
+@with_cancellation
+@load_aware_call
+async def create_classify(request: ClassificationRequest,
+                          raw_request: Request):
+    handler = classify(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support Classification API")
+
+    generator = await handler.create_classify(request, raw_request)
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(content=generator.model_dump(),
+                            status_code=generator.code)
+
+    elif isinstance(generator, ClassificationResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    assert_never(generator)
+
+
 @router.post("/score", dependencies=[Depends(validate_json_request)])
 @with_cancellation
 @load_aware_call
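
A minimal client sketch for the new route, assuming a vLLM server is already running at localhost:8000 with a classification model loaded; the model name is a placeholder, and the request body follows the ClassificationRequest fields imported above:

import requests

resp = requests.post(
    "http://localhost:8000/classify",
    json={
        "model": "my-classifier",  # placeholder: whatever model the server serves
        "input": "vLLM is wonderful!",
    },
)
resp.raise_for_status()
# On success the handler returns a ClassificationResponse serialized via
# model_dump(), e.g. per-label scores for each input.
print(resp.json())
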
@@ -1001,6 +1031,12 @@ async def init_app_state(
         state.openai_serving_models,
         request_logger=request_logger) if model_config.task in (
             "score", "embed", "pooling") else None
+    state.openai_serving_classification = ServingClassification(
+        engine_client,
+        model_config,
+        state.openai_serving_models,
+        request_logger=request_logger,
+    ) if model_config.task == "classify" else None
     state.jinaai_serving_reranking = ServingScores(
         engine_client,
         model_config,
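
Note the conditional wiring: openai_serving_classification is constructed only when model_config.task == "classify", so on any other task the accessor returns None and create_classify falls through to the "model does not support" error above. A sketch of what a client would see against such a server (host and model are placeholders):

import requests

resp = requests.post("http://localhost:8000/classify",
                     json={"model": "my-embedding-model", "input": "hi"})
if resp.status_code != 200:
    # The error payload built by create_error_response(...)
    print(resp.json())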