[Frontend] Rerank API (Jina- and Cohere-compatible API) (#12376)
Signed-off-by: Kyle Mistele <kyle@mistele.com>
This commit is contained in:
@@ -56,6 +56,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
RerankRequest, RerankResponse,
|
||||
ScoreRequest, ScoreResponse,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
@@ -68,6 +69,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
|
||||
from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
|
||||
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
@@ -306,6 +308,10 @@ def score(request: Request) -> Optional[OpenAIServingScores]:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
def rerank(request: Request) -> Optional[JinaAIServingRerank]:
|
||||
return request.app.state.jinaai_serving_reranking
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
@@ -502,6 +508,40 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
||||
return await create_score(request, raw_request)
|
||||
|
||||
|
||||
@router.post("/rerank")
|
||||
@with_cancellation
|
||||
async def do_rerank(request: RerankRequest, raw_request: Request):
|
||||
handler = rerank(raw_request)
|
||||
if handler is None:
|
||||
return base(raw_request).create_error_response(
|
||||
message="The model does not support Rerank (Score) API")
|
||||
generator = await handler.do_rerank(request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
elif isinstance(generator, RerankResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/v1/rerank")
|
||||
@with_cancellation
|
||||
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
|
||||
logger.warning(
|
||||
"To indicate that the rerank API is not part of the standard OpenAI"
|
||||
" API, we have located it at `/rerank`. Please update your client"
|
||||
"accordingly. (Note: Conforms to JinaAI rerank API)")
|
||||
|
||||
return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
@router.post("/v2/rerank")
|
||||
@with_cancellation
|
||||
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
|
||||
return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
|
||||
"generate": {
|
||||
"messages": (ChatCompletionRequest, create_chat_completion),
|
||||
@@ -512,7 +552,10 @@ TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
|
||||
"default": (EmbeddingCompletionRequest, create_embedding),
|
||||
},
|
||||
"score": {
|
||||
"default": (ScoreRequest, create_score),
|
||||
"default": (RerankRequest, do_rerank)
|
||||
},
|
||||
"rerank": {
|
||||
"default": (RerankRequest, do_rerank)
|
||||
},
|
||||
"reward": {
|
||||
"messages": (PoolingChatRequest, create_pooling),
|
||||
@@ -759,6 +802,12 @@ async def init_app_state(
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger
|
||||
) if model_config.task == "score" else None
|
||||
state.jinaai_serving_reranking = JinaAIServingRerank(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger
|
||||
) if model_config.task == "score" else None
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
model_config,
|
||||
|
||||
Reference in New Issue
Block a user