[Frontend] Rerank API (Jina- and Cohere-compatible API) (#12376)

Signed-off-by: Kyle Mistele <kyle@mistele.com>
This commit is contained in:
Kyle Mistele
2025-01-26 20:58:45 -06:00
committed by GitHub
parent 72bac73067
commit 0034b09ceb
9 changed files with 552 additions and 11 deletions

View File

@@ -56,6 +56,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
RerankRequest, RerankResponse,
ScoreRequest, ScoreResponse,
TokenizeRequest,
TokenizeResponse,
@@ -68,6 +69,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
@@ -306,6 +308,10 @@ def score(request: Request) -> Optional[OpenAIServingScores]:
return request.app.state.openai_serving_scores
def rerank(request: Request) -> Optional[JinaAIServingRerank]:
return request.app.state.jinaai_serving_reranking
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
@@ -502,6 +508,40 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)
@router.post("/rerank")
@with_cancellation
async def do_rerank(request: RerankRequest, raw_request: Request):
handler = rerank(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Rerank (Score) API")
generator = await handler.do_rerank(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, RerankResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@router.post("/v1/rerank")
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
logger.warning(
"To indicate that the rerank API is not part of the standard OpenAI"
" API, we have located it at `/rerank`. Please update your client"
"accordingly. (Note: Conforms to JinaAI rerank API)")
return await do_rerank(request, raw_request)
@router.post("/v2/rerank")
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
return await do_rerank(request, raw_request)
TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
"generate": {
"messages": (ChatCompletionRequest, create_chat_completion),
@@ -512,7 +552,10 @@ TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
"default": (EmbeddingCompletionRequest, create_embedding),
},
"score": {
"default": (ScoreRequest, create_score),
"default": (RerankRequest, do_rerank)
},
"rerank": {
"default": (RerankRequest, do_rerank)
},
"reward": {
"messages": (PoolingChatRequest, create_pooling),
@@ -759,6 +802,12 @@ async def init_app_state(
state.openai_serving_models,
request_logger=request_logger
) if model_config.task == "score" else None
state.jinaai_serving_reranking = JinaAIServingRerank(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger
) if model_config.task == "score" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,