[Frontend] Rerank API (Jina- and Cohere-compatible API) (#12376)
Signed-off-by: Kyle Mistele <kyle@mistele.com>
This commit is contained in:
@@ -56,6 +56,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
RerankRequest, RerankResponse,
|
||||
ScoreRequest, ScoreResponse,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse,
|
||||
@@ -68,6 +69,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
|
||||
from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
|
||||
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
@@ -306,6 +308,10 @@ def score(request: Request) -> Optional[OpenAIServingScores]:
|
||||
return request.app.state.openai_serving_scores
|
||||
|
||||
|
||||
def rerank(request: Request) -> Optional[JinaAIServingRerank]:
    """Return the rerank serving handler stored on the application state.

    The value is whatever was assigned to
    ``app.state.jinaai_serving_reranking`` and may be ``None``.
    """
    app_state = request.app.state
    return app_state.jinaai_serving_reranking
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
@@ -502,6 +508,40 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
|
||||
return await create_score(request, raw_request)
|
||||
|
||||
|
||||
# Primary rerank route: resolves the rerank handler from the app state,
# delegates to it, and wraps the outcome in a JSONResponse.
@router.post("/rerank")
@with_cancellation
async def do_rerank(request: RerankRequest, raw_request: Request):
    rerank_handler = rerank(raw_request)
    if rerank_handler is None:
        # No rerank handler was configured for the served model.
        return base(raw_request).create_error_response(
            message="The model does not support Rerank (Score) API")

    outcome = await rerank_handler.do_rerank(request, raw_request)

    if isinstance(outcome, RerankResponse):
        return JSONResponse(content=outcome.model_dump())
    if isinstance(outcome, ErrorResponse):
        # Propagate the handler's error code as the HTTP status.
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)

    # Both declared result types were handled above.
    assert_never(outcome)
|
||||
|
||||
|
||||
# Alias of the rerank route served under `/v1/rerank`. It logs a warning
# pointing clients at `/rerank` and then delegates to `do_rerank`.
@router.post("/v1/rerank")
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
    # Adjacent string literals are concatenated verbatim, so each fragment
    # must carry its own separating space ("client " + "accordingly").
    logger.warning(
        "To indicate that the rerank API is not part of the standard OpenAI"
        " API, we have located it at `/rerank`. Please update your client "
        "accordingly. (Note: Conforms to JinaAI rerank API)")

    return await do_rerank(request, raw_request)
|
||||
|
||||
|
||||
# Alias of the rerank route served under `/v2/rerank`; it forwards the
# request unchanged to `do_rerank`.
@router.post("/v2/rerank")
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
    response = await do_rerank(request, raw_request)
    return response
|
||||
|
||||
|
||||
TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
|
||||
"generate": {
|
||||
"messages": (ChatCompletionRequest, create_chat_completion),
|
||||
@@ -512,7 +552,10 @@ TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
|
||||
"default": (EmbeddingCompletionRequest, create_embedding),
|
||||
},
|
||||
"score": {
|
||||
"default": (ScoreRequest, create_score),
|
||||
"default": (RerankRequest, do_rerank)
|
||||
},
|
||||
"rerank": {
|
||||
"default": (RerankRequest, do_rerank)
|
||||
},
|
||||
"reward": {
|
||||
"messages": (PoolingChatRequest, create_pooling),
|
||||
@@ -759,6 +802,12 @@ async def init_app_state(
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger
|
||||
) if model_config.task == "score" else None
|
||||
state.jinaai_serving_reranking = JinaAIServingRerank(
|
||||
engine_client,
|
||||
model_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger
|
||||
) if model_config.task == "score" else None
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
engine_client,
|
||||
model_config,
|
||||
|
||||
@@ -1018,6 +1018,52 @@ class ScoreRequest(OpenAIBaseModel):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
|
||||
# Request body of the rerank endpoints: one query scored against a list of
# candidate documents.
class RerankRequest(OpenAIBaseModel):
    # Name of the served model to use.
    model: str
    # Text every document is scored against.
    query: str
    # Candidate documents to rank.
    documents: List[str]
    # Maximum number of results to return. The rerank serving handler
    # treats any value <= 0 as "return every document".
    top_n: int = Field(default=0)
    # Optional cap on the tokenized prompt length; must be >= 1 when given.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-rerank-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-rerank-pooling-params

    # doc: begin-rerank-extra-params
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    # doc: end-rerank-extra-params

    def to_pooling_params(self) -> PoolingParams:
        """Build the PoolingParams for this request from `additional_data`."""
        return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
|
||||
# Wrapper around the text of a single document in a rerank result.
class RerankDocument(BaseModel):
    # The document text.
    text: str
|
||||
|
||||
|
||||
# One scored document in a rerank response.
class RerankResult(BaseModel):
    # Position of the document in the request's `documents` list.
    index: int
    # The document that was scored.
    document: RerankDocument
    # Score assigned by the model; results are sorted on it, highest first.
    relevance_score: float
|
||||
|
||||
|
||||
# Token accounting for a rerank request.
class RerankUsage(BaseModel):
    # Sum of the prompt token counts over all scored query/document pairs.
    total_tokens: int
|
||||
|
||||
|
||||
# Response body returned by the rerank endpoints.
class RerankResponse(OpenAIBaseModel):
    # Identifier of the rerank request.
    id: str
    # Model name echoed from the request.
    model: str
    # Token usage for the request.
    usage: RerankUsage
    # Scored documents.
    results: List[RerankResult]
|
||||
|
||||
|
||||
class CompletionLogProbs(OpenAIBaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||
|
||||
@@ -26,7 +26,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DetokenizeRequest,
|
||||
EmbeddingChatRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
ErrorResponse, ScoreRequest,
|
||||
ErrorResponse, RerankRequest,
|
||||
ScoreRequest,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
@@ -204,9 +205,9 @@ class OpenAIServing:
|
||||
token_num = len(input_ids)
|
||||
|
||||
# Note: EmbeddingRequest, ScoreRequest and RerankRequest don't have max_tokens
|
||||
if isinstance(
|
||||
request,
|
||||
(EmbeddingChatRequest, EmbeddingCompletionRequest, ScoreRequest)):
|
||||
if isinstance(request,
|
||||
(EmbeddingChatRequest, EmbeddingCompletionRequest,
|
||||
ScoreRequest, RerankRequest)):
|
||||
|
||||
operation = "score" if isinstance(request, ScoreRequest) \
|
||||
else "embedding generation"
|
||||
|
||||
206
vllm/entrypoints/openai/serving_rerank.py
Normal file
206
vllm/entrypoints/openai/serving_rerank.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
|
||||
RerankRequest, RerankResponse,
|
||||
RerankResult, RerankUsage)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import make_async, merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class JinaAIServingRerank(OpenAIServing):
    """Serving handler for the Jina/Cohere-style rerank API.

    Each document of a request is paired with the query, scored by a
    cross-encoder model, and returned ordered by relevance score.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
    ) -> None:
        """Forward all arguments unchanged to ``OpenAIServing``."""
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         models=models,
                         request_logger=request_logger)

    async def do_rerank(
        self,
        request: RerankRequest,
        raw_request: Optional[Request] = None
    ) -> Union[RerankResponse, ErrorResponse]:
        """
        Rerank API based on JinaAI's rerank API; implements the same
        API interface. Designed for compatibility with off-the-shelf
        tooling, since this is a common standard for reranking APIs

        See example client implementations at
        https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
        numerous clients use this standard.

        Args:
            request: The parsed rerank request (query, documents, top_n...).
            raw_request: The underlying HTTP request, if any; used for the
                request id and for trace headers.

        Returns:
            A ``RerankResponse`` on success. An ``ErrorResponse`` when the
            model check fails, when preprocessing or scheduling raises
            ``ValueError``, or when the request is cancelled.

        Raises:
            NotImplementedError: If a prompt adapter is requested.
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        model_name = request.model
        request_id = f"rerank-{self._base_request_id(raw_request)}"
        truncate_prompt_tokens = request.truncate_prompt_tokens
        query = request.query
        documents = request.documents
        request_prompts: List[str] = []
        engine_prompts: List[TokensPrompt] = []
        # A non-positive top_n means "return every document".
        top_n = request.top_n if request.top_n > 0 else len(documents)

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for scoring models")

            if isinstance(tokenizer, MistralTokenizer):
                raise ValueError(
                    "MistralTokenizer not supported for cross-encoding")

            if not self.model_config.is_cross_encoder:
                raise ValueError("Model is not cross encoder.")

            if truncate_prompt_tokens is not None and \
                    truncate_prompt_tokens > self.max_model_len:
                raise ValueError(
                    f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
                    f"is greater than max_model_len ({self.max_model_len})."
                    f" Please, select a smaller truncation size.")

            # The tokenization options and the async tokenizer wrapper are
            # the same for every document, so build them once per request.
            tokenization_kwargs: Dict[str, Any] = {}
            if truncate_prompt_tokens is not None:
                tokenization_kwargs["truncation"] = True
                tokenization_kwargs["max_length"] = truncate_prompt_tokens

            tokenize_async = make_async(tokenizer.__call__,
                                        executor=self._tokenizer_executor)

            for doc in documents:
                # Human-readable form of the pair, used for validation
                # and request logging.
                request_prompt = f"{query}{tokenizer.sep_token}{doc}"

                prompt_inputs = await tokenize_async(text=query,
                                                     text_pair=doc,
                                                     **tokenization_kwargs)

                input_ids = prompt_inputs["input_ids"]
                text_token_prompt = \
                    self._validate_input(request, input_ids, request_prompt)
                engine_prompt = TokensPrompt(
                    prompt_token_ids=text_token_prompt["prompt_token_ids"],
                    token_type_ids=prompt_inputs.get("token_type_ids"))

                request_prompts.append(request_prompt)
                engine_prompts.append(engine_prompt)

        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []

        try:
            pooling_params = request.to_pooling_params()

            # Trace headers depend only on the HTTP request, not on the
            # individual document, so resolve them once.
            trace_headers = (None if raw_request is None else await
                             self._get_trace_headers(raw_request.headers))

            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(*generators)

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[PoolingRequestOutput]]
        final_res_batch = [None] * num_prompts

        try:
            # Results are stored by prompt index so that position i in the
            # batch corresponds to documents[i].
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[PoolingRequestOutput],
                                           final_res_batch)

            response = self.request_output_to_rerank_response(
                final_res_batch_checked, request_id, model_name, documents,
                top_n)
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response

    def request_output_to_rerank_response(
            self, final_res_batch: List[PoolingRequestOutput], request_id: str,
            model_name: str, documents: List[str],
            top_n: int) -> RerankResponse:
        """
        Convert the output of do_rerank to a RerankResponse

        Args:
            final_res_batch: One pooling output per document, in the same
                order as ``documents``.
            request_id: Identifier placed in the response.
            model_name: Model name placed in the response.
            documents: The documents that were scored.
            top_n: Number of results to keep when smaller than
                ``len(documents)``.

        Returns:
            A response whose results are sorted by descending relevance
            score and whose usage is the total prompt token count.
        """
        results: List[RerankResult] = []
        num_prompt_tokens = 0
        for idx, final_res in enumerate(final_res_batch):
            classify_res = ScoringRequestOutput.from_base(final_res)

            results.append(
                RerankResult(
                    index=idx,
                    document=RerankDocument(text=documents[idx]),
                    relevance_score=classify_res.outputs.score,
                ))
            num_prompt_tokens += len(final_res.prompt_token_ids)

        # sort by relevance, then return the top n if set
        results.sort(key=lambda x: x.relevance_score, reverse=True)
        if top_n < len(documents):
            results = results[:top_n]

        return RerankResponse(
            id=request_id,
            model=model_name,
            results=results,
            usage=RerankUsage(total_tokens=num_prompt_tokens))
|
||||
Reference in New Issue
Block a user