[FEATURE] Enables /score endpoint for embedding models (#12846)

Author: Gabriel Marinho
Date: 2025-02-21 03:09:47 -03:00
Committed by: GitHub
Parent: 1cdc88614a
Commit: 1c3c975766
11 changed files with 590 additions and 513 deletions
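This change lets the OpenAI-compatible server answer /score requests for embedding models by embedding text_1 and text_2 separately and returning their cosine similarity; previously the endpoint required a cross-encoder. A minimal request sketch follows, assuming a server already running with an embedding model (the model name, host, and port are placeholders, not taken from this commit):

import requests

# Assumes something like `vllm serve BAAI/bge-base-en-v1.5 --task embed` is running locally.
response = requests.post(
    "http://localhost:8000/score",
    json={
        "model": "BAAI/bge-base-en-v1.5",
        "text_1": "What is the capital of France?",
        "text_2": ["Paris is the capital of France.",
                   "The Nile flows through Egypt."],
    },
)
print(response.json())  # one score per (text_1, text_2) pair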

View File

@@ -7,7 +7,6 @@ from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
Tuple, Type, Union, cast, overload)
import cloudpickle
import torch
import torch.nn as nn
from tqdm import tqdm
from typing_extensions import TypeVar, deprecated
@@ -25,6 +24,8 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_mistral_chat_template,
parse_chat_messages,
resolve_chat_template_content_format)
from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from vllm.logger import init_logger
@@ -1010,40 +1011,25 @@ class LLM:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[ScoringRequestOutput]:
encoded_output = self.encode(
encoded_output: List[PoolingRequestOutput] = self.encode(
text_1 + text_2,
use_tqdm=use_tqdm,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
encoded_output_1 = encoded_output[0:len(text_1)]
encoded_output_2 = encoded_output[len(text_1):]
encoded_output_1: List[PoolingRequestOutput] = encoded_output[
0:len(text_1)]
encoded_output_2: List[PoolingRequestOutput] = encoded_output[
len(text_1):]
if len(encoded_output_1) == 1:
encoded_output_1 = encoded_output_1 * len(encoded_output_2)
output_pairs = [(t1, t2)
for t1, t2 in zip(encoded_output_1, encoded_output_2)]
scores: List[PoolingRequestOutput] = []
scores = []
scorer = torch.nn.CosineSimilarity(0)
for embed_1, embed_2 in output_pairs:
pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data)
if (pad_token_id := getattr(tokenizer, "pad_token_id",
None)) is not None:
tokens = embed_1.prompt_token_ids + [
pad_token_id
] + embed_2.prompt_token_ids
else:
tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids
scores.append(
PoolingRequestOutput(
request_id=f"{embed_1.request_id}_{embed_2.request_id}",
outputs=pair_score,
prompt_token_ids=tokens,
finished=True))
scores = _cosine_similarity(tokenizer=tokenizer,
embed_1=encoded_output_1,
embed_2=encoded_output_2)
items = self.engine_class.validate_outputs(scores,
PoolingRequestOutput)
@@ -1183,12 +1169,7 @@ class LLM:
text_2 = [text_2]
input_text_2: List[str] = [ensure_str(t) for t in text_2]
if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len(input_text_1) == 0:
raise ValueError("At least one text element must be given")
if len(input_text_2) == 0:
raise ValueError("At least one text_pair element must be given")
_validate_score_input_lens(input_text_1, input_text_2)
if self.llm_engine.model_config.is_cross_encoder:
return self._cross_encoding_score(tokenizer, input_text_1,
@@ -1197,7 +1178,6 @@ class LLM:
lora_request,
prompt_adapter_request)
else:
return self._embedding_score(
tokenizer,
input_text_1, # type: ignore[arg-type]
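The offline hunk above moves the embedding-based scoring of LLM.score into the shared _cosine_similarity and _validate_score_input_lens helpers. A usage sketch for that path, with an assumed embedding model name:

from vllm import LLM

llm = LLM(model="BAAI/bge-base-en-v1.5", task="embed")  # placeholder embedding model
outputs = llm.score("What is the capital of France?",
                    ["Paris is the capital of France.",
                     "The Nile flows through Egypt."])
for output in outputs:
    print(output.outputs.score)  # cosine similarity per pair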

View File

@@ -73,8 +73,7 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.serving_transcription import (
@@ -320,12 +319,12 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
return request.app.state.openai_serving_embedding
def score(request: Request) -> Optional[OpenAIServingScores]:
def score(request: Request) -> Optional[ServingScores]:
return request.app.state.openai_serving_scores
def rerank(request: Request) -> Optional[JinaAIServingRerank]:
return request.app.state.jinaai_serving_reranking
def rerank(request: Request) -> Optional[ServingScores]:
return request.app.state.openai_serving_scores
def tokenization(request: Request) -> OpenAIServingTokenization:
@@ -866,13 +865,13 @@ async def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embed" else None
state.openai_serving_scores = OpenAIServingScores(
state.openai_serving_scores = ServingScores(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger
) if model_config.task == "score" else None
state.jinaai_serving_reranking = JinaAIServingRerank(
request_logger=request_logger) if model_config.task in (
"score", "embed", "pooling") else None
state.jinaai_serving_reranking = ServingScores(
engine_client,
model_config,
state.openai_serving_models,

View File

@@ -26,7 +26,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
@@ -342,7 +342,7 @@ async def main(args):
chat_template=None,
chat_template_content_format="auto",
) if model_config.task == "embed" else None
openai_serving_scores = (OpenAIServingScores(
openai_serving_scores = (ServingScores(
engine,
model_config,
openai_serving_models,
@@ -364,9 +364,9 @@ async def main(args):
# Determine the type of request and run it.
if request.url == "/v1/chat/completions":
handler_fn = (None if openai_serving_chat is None else
openai_serving_chat.create_chat_completion)
if handler_fn is None:
chat_handler_fn = (None if openai_serving_chat is None else
openai_serving_chat.create_chat_completion)
if chat_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
@@ -375,12 +375,13 @@ async def main(args):
))
continue
response_futures.append(run_request(handler_fn, request, tracker))
response_futures.append(
run_request(chat_handler_fn, request, tracker))
tracker.submitted()
elif request.url == "/v1/embeddings":
handler_fn = (None if openai_serving_embedding is None else
openai_serving_embedding.create_embedding)
if handler_fn is None:
embed_handler_fn = (None if openai_serving_embedding is None else
openai_serving_embedding.create_embedding)
if embed_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
@@ -388,12 +389,13 @@ async def main(args):
))
continue
response_futures.append(run_request(handler_fn, request, tracker))
response_futures.append(
run_request(embed_handler_fn, request, tracker))
tracker.submitted()
elif request.url == "/v1/score":
handler_fn = (None if openai_serving_scores is None else
openai_serving_scores.create_score)
if handler_fn is None:
score_handler_fn = (None if openai_serving_scores is None else
openai_serving_scores.create_score)
if score_handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
@@ -401,7 +403,8 @@ async def main(args):
))
continue
response_futures.append(run_request(handler_fn, request, tracker))
response_futures.append(
run_request(score_handler_fn, request, tracker))
tracker.submitted()
else:
response_futures.append(
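The batch runner above now dispatches /v1/score requests through a dedicated handler name. A sketch of one line such a batch input file might contain, in the OpenAI-style batch format run_batch consumes (custom_id and model are placeholder assumptions):

import json

line = json.dumps({
    "custom_id": "score-request-1",
    "method": "POST",
    "url": "/v1/score",
    "body": {
        "model": "BAAI/bge-base-en-v1.5",
        "text_1": "What is the capital of France?",
        "text_2": "Paris is the capital of France.",
    },
})
print(line)  # one request line for the batch input file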

View File

@@ -52,8 +52,8 @@ from vllm.utils import is_list_of, make_async, random_uuid
logger = init_logger(__name__)
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
EmbeddingCompletionRequest, ScoreRequest,
TokenizeCompletionRequest]
EmbeddingCompletionRequest, RerankRequest,
ScoreRequest, TokenizeCompletionRequest]
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest]

View File

@@ -1,208 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
RerankRequest, RerankResponse,
RerankResult, RerankUsage)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import make_async, merge_async_iterators
logger = init_logger(__name__)
class JinaAIServingRerank(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger)
async def do_rerank(
self,
request: RerankRequest,
raw_request: Optional[Request] = None
) -> Union[RerankResponse, ErrorResponse]:
"""
Rerank API based on JinaAI's rerank API; implements the same
API interface. Designed for compatibility with off-the-shelf
tooling, since this is a common standard for reranking APIs
See example client implementations at
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
numerous clients use this standard.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"rerank-{self._base_request_id(raw_request)}"
truncate_prompt_tokens = request.truncate_prompt_tokens
query = request.query
documents = request.documents
request_prompts = []
engine_prompts = []
top_n = request.top_n if request.top_n > 0 else len(documents)
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for scoring models")
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")
if not self.model_config.is_cross_encoder:
raise ValueError("Model is not cross encoder.")
if truncate_prompt_tokens is not None and \
truncate_prompt_tokens > self.max_model_len:
raise ValueError(
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
f"is greater than max_model_len ({self.max_model_len})."
f" Please, select a smaller truncation size.")
for doc in documents:
request_prompt = f"{query}{tokenizer.sep_token}{doc}"
tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(text=query,
text_pair=doc,
**tokenization_kwargs)
input_ids = prompt_inputs["input_ids"]
text_token_prompt = \
self._validate_input(request, input_ids, request_prompt)
engine_prompt = TokensPrompt(
prompt_token_ids=text_token_prompt["prompt_token_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
request_prompts.append(request_prompt)
engine_prompts.append(engine_prompt)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators)
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)
response = self.request_output_to_rerank_response(
final_res_batch_checked, request_id, model_name, documents,
top_n)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def request_output_to_rerank_response(
self, final_res_batch: List[PoolingRequestOutput], request_id: str,
model_name: str, documents: List[str],
top_n: int) -> RerankResponse:
"""
Convert the output of do_rank to a RerankResponse
"""
results: List[RerankResult] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
result = RerankResult(
index=idx,
document=RerankDocument(text=documents[idx]),
relevance_score=classify_res.outputs.score,
)
results.append(result)
prompt_token_ids = final_res.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
# sort by relevance, then return the top n if set
results.sort(key=lambda x: x.relevance_score, reverse=True)
if top_n < len(documents):
results = results[:top_n]
return RerankResponse(
id=request_id,
model=model_name,
results=results,
usage=RerankUsage(total_tokens=num_prompt_tokens))
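The JinaAIServingRerank class deleted above is folded into ServingScores (next file), so the same Jina-style rerank API keeps working. A request sketch against that API, with a placeholder model name and endpoint path:

import requests

# Assumes a reranking-capable model served locally; model name and route are assumptions.
response = requests.post(
    "http://localhost:8000/rerank",
    json={
        "model": "BAAI/bge-reranker-base",
        "query": "What is the capital of France?",
        "documents": ["Paris is the capital of France.",
                      "The Nile flows through Egypt."],
        "top_n": 1,
    },
)
print(response.json())  # results sorted by relevance_score, truncated to top_n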

View File

@@ -1,53 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Union
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
ScoreResponse, ScoreResponseData,
UsageInfo)
from vllm.entrypoints.openai.protocol import (ErrorResponse, RerankDocument,
RerankRequest, RerankResponse,
RerankResult, RerankUsage,
ScoreRequest, ScoreResponse,
ScoreResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens)
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast)
from vllm.utils import make_async, merge_async_iterators
logger = init_logger(__name__)
def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str],
str]) -> List:
if isinstance(text_1, (str, dict)):
# Convert a single prompt to a list.
text_1 = [text_1]
text_1 = [t for t in text_1]
if isinstance(text_2, (str, dict)):
# Convert a single prompt to a list.
text_2 = [text_2]
text_2 = [t for t in text_2]
if len(text_1) > 1 and len(text_1) != len(text_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len(text_1) == 0:
raise ValueError("At least one text element must be given")
if len(text_2) == 0:
raise ValueError("At least one text_pair element must be given")
if len(text_1) == 1:
text_1 = text_1 * len(text_2)
return [(t1, t2) for t1, t2 in zip(text_1, text_2)]
class OpenAIServingScores(OpenAIServing):
class ServingScores(OpenAIServing):
def __init__(
self,
@@ -62,6 +45,248 @@ class OpenAIServingScores(OpenAIServing):
models=models,
request_logger=request_logger)
async def _embedding_score(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
texts_1: List[str],
texts_2: List[str],
request: Union[RerankRequest, ScoreRequest],
request_id=str,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[Union[LoRARequest, None]] = None,
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
None]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> List[PoolingRequestOutput]:
input_texts = texts_1 + texts_2
engine_prompts: List[TokensPrompt] = []
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
tokenization_kwargs = tokenization_kwargs or {}
tokenized_prompts = await asyncio.gather(
*(tokenize_async(t, **tokenization_kwargs) for t in input_texts))
for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = \
self._validate_input(
request,
tok_result["input_ids"],
input_text)
engine_prompts.append(
TokensPrompt(
prompt_token_ids=text_token_prompt["prompt_token_ids"]))
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
input_texts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
generators.append(
self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
))
result_generator = merge_async_iterators(*generators)
# Non-streaming response
final_res_batch: List[PoolingRequestOutput] = []
embeddings: List[Optional[PoolingRequestOutput]] =\
[None] * len(engine_prompts)
async for i, res in result_generator:
embeddings[i] = res
emb_texts_1: List[PoolingRequestOutput] = []
emb_texts_2: List[PoolingRequestOutput] = []
for i in range(0, len(texts_1)):
assert (emb := embeddings[i]) is not None
emb_texts_1.append(emb)
for i in range(len(texts_1), len(embeddings)):
assert (emb := embeddings[i]) is not None
emb_texts_2.append(emb)
if len(emb_texts_1) == 1:
emb_texts_1 = emb_texts_1 * len(emb_texts_2)
final_res_batch = _cosine_similarity(tokenizer=tokenizer,
embed_1=emb_texts_1,
embed_2=emb_texts_2)
return final_res_batch
async def _cross_encoding_score(
self,
tokenizer: Union[AnyTokenizer],
texts_1: List[str],
texts_2: List[str],
request: Union[RerankRequest, ScoreRequest],
request_id=str,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[Union[LoRARequest, None]] = None,
prompt_adapter_request: Optional[Union[PromptAdapterRequest,
None]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> List[PoolingRequestOutput]:
request_prompts: List[str] = []
engine_prompts: List[TokensPrompt] = []
if len(texts_1) == 1:
texts_1 = texts_1 * len(texts_2)
input_pairs = [(t1, t2) for t1, t2 in zip(texts_1, texts_2)]
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
tokenization_kwargs = tokenization_kwargs or {}
tokenized_prompts = await asyncio.gather(
*(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs)
for t1, t2 in input_pairs))
for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs):
request_prompt = f"{t1}{tokenizer.sep_token}{t2}"
input_ids = prompt_inputs["input_ids"]
text_token_prompt = \
self._validate_input(request, input_ids, request_prompt)
engine_prompt = TokensPrompt(
prompt_token_ids=text_token_prompt["prompt_token_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
request_prompts.append(request_prompt)
engine_prompts.append(engine_prompt)
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
result_generator = merge_async_iterators(*generators)
# Non-streaming response
final_res_batch: List[
Optional[PoolingRequestOutput]] = [None] * len(engine_prompts)
async for i, res in result_generator:
final_res_batch[i] = res
return [out for out in final_res_batch if out is not None]
async def _run_scoring(
self,
texts_1: Union[str, list[str]],
texts_2: Union[str, list[str]],
request: Union[ScoreRequest, RerankRequest],
request_id: str,
raw_request: Optional[Request] = None,
truncate_prompt_tokens: Optional[int] = None,
) -> List[PoolingRequestOutput]:
tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for scoring models")
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if truncate_prompt_tokens is not None and \
truncate_prompt_tokens > self.max_model_len:
raise ValueError(
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
f"is greater than max_model_len ({self.max_model_len})."
f" Please, select a smaller truncation size.")
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
if isinstance(texts_1, str):
texts_1 = [texts_1]
if isinstance(texts_2, str):
texts_2 = [texts_2]
_validate_score_input_lens(texts_1, texts_2)
if self.model_config.is_cross_encoder:
return await self._cross_encoding_score(
tokenizer=tokenizer,
texts_1=texts_1,
texts_2=texts_2,
request=request,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers)
else:
return await self._embedding_score(
tokenizer=tokenizer,
texts_1=texts_1,
texts_2=texts_2,
request=request,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers)
async def create_score(
self,
request: ScoreRequest,
@@ -76,123 +301,24 @@ class OpenAIServingScores(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"score-{self._base_request_id(raw_request)}"
created_time = int(time.time())
truncate_prompt_tokens = request.truncate_prompt_tokens
request_prompts = []
engine_prompts = []
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
final_res_batch = await self._run_scoring(
request.text_1,
request.text_2,
request,
request_id,
raw_request,
request.truncate_prompt_tokens,
)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for scoring models")
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")
if not self.model_config.is_cross_encoder:
raise ValueError("Model is not cross encoder.")
if truncate_prompt_tokens is not None and \
truncate_prompt_tokens > self.max_model_len:
raise ValueError(
f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
f"is greater than max_model_len ({self.max_model_len})."
f" Please, select a smaller truncation size.")
input_pairs = make_pairs(request.text_1, request.text_2)
for q, t in input_pairs:
request_prompt = f"{q}{tokenizer.sep_token}{t}"
tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(q,
text_pair=t,
**tokenization_kwargs)
input_ids = prompt_inputs["input_ids"]
text_token_prompt = \
self._validate_input(request, input_ids, request_prompt)
engine_prompt = TokensPrompt(
prompt_token_ids=text_token_prompt["prompt_token_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
request_prompts.append(request_prompt)
engine_prompts.append(engine_prompt)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(*generators)
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: List[Optional[PoolingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(List[PoolingRequestOutput],
final_res_batch)
response = self.request_output_to_score_response(
final_res_batch_checked,
return self.request_output_to_score_response(
final_res_batch,
request_id,
created_time,
model_name,
request.model,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
@@ -200,7 +326,44 @@ class OpenAIServingScores(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
async def do_rerank(
self,
request: RerankRequest,
raw_request: Optional[Request] = None
) -> Union[RerankResponse, ErrorResponse]:
"""
Rerank API based on JinaAI's rerank API; implements the same
API interface. Designed for compatibility with off-the-shelf
tooling, since this is a common standard for reranking APIs
See example client implementations at
https://github.com/infiniflow/ragflow/blob/main/rag/llm/rerank_model.py
numerous clients use this standard.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
request_id = f"rerank-{self._base_request_id(raw_request)}"
documents = request.documents
top_n = request.top_n if request.top_n > 0 else len(documents)
try:
final_res_batch = await self._run_scoring(
request.query,
documents,
request,
request_id,
raw_request,
request.truncate_prompt_tokens,
)
return self.request_output_to_rerank_response(
final_res_batch, request_id, request.model, documents, top_n)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
def request_output_to_score_response(
self,
@@ -236,3 +399,35 @@ class OpenAIServingScores(OpenAIServing):
data=items,
usage=usage,
)
def request_output_to_rerank_response(
self, final_res_batch: List[PoolingRequestOutput], request_id: str,
model_name: str, documents: List[str],
top_n: int) -> RerankResponse:
"""
Convert the output of do_rank to a RerankResponse
"""
results: List[RerankResult] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
classify_res = ScoringRequestOutput.from_base(final_res)
result = RerankResult(
index=idx,
document=RerankDocument(text=documents[idx]),
relevance_score=classify_res.outputs.score,
)
results.append(result)
prompt_token_ids = final_res.prompt_token_ids
num_prompt_tokens += len(prompt_token_ids)
# sort by relevance, then return the top n if set
results.sort(key=lambda x: x.relevance_score, reverse=True)
if top_n < len(documents):
results = results[:top_n]
return RerankResponse(
id=request_id,
model=model_name,
results=results,
usage=RerankUsage(total_tokens=num_prompt_tokens))
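In the rewritten ServingScores above, _run_scoring picks between two paths: cross-encoders tokenize each (text_1, text_2) pair jointly, with the two texts separated by the tokenizer's sep_token, while embedding models encode each text on its own and score pairs by cosine similarity. A minimal sketch, outside vLLM, of the joint tokenization the cross-encoder path relies on (the model name is a placeholder assumption):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
enc = tokenizer(text="What is the capital of France?",
                text_pair="Paris is the capital of France.",
                truncation=True, max_length=128)
print(tokenizer.decode(enc["input_ids"]))  # both texts in one sequence, separated by [SEP]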

View File

@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Union
from torch.nn import CosineSimilarity
from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer,
PreTrainedTokenizerFast)
def _cosine_similarity(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
embed_1: List[PoolingRequestOutput],
embed_2: List[PoolingRequestOutput],
) -> List[PoolingRequestOutput]:
scorer = CosineSimilarity(0)
scores: Union[List[PoolingRequestOutput]] = []
for emb_1, emb_2 in zip(embed_1, embed_2):
pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)
padding = []
if (pad_token_id := getattr(tokenizer, "pad_token_id",
None)) is not None:
padding = [pad_token_id]
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
scores.append(
PoolingRequestOutput(
request_id=f"{emb_1.request_id}_{emb_2.request_id}",
outputs=pair_score,
prompt_token_ids=tokens,
finished=True))
return scores
def _validate_score_input_lens(
texts_1: Union[List[str], List[dict]],
texts_2: Union[List[str], List[dict]],
):
if len(texts_1) > 1 and len(texts_1) != len(texts_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len(texts_1) == 0:
raise ValueError("At least one text element must be given")
if len(texts_2) == 0:
raise ValueError("At least one text_pair element must be given")
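A standalone sketch of the score the new _cosine_similarity helper produces for one pair of pooled embeddings (the vectors below are made-up placeholders):

import torch
from torch.nn import CosineSimilarity

emb_1 = torch.tensor([0.2, 0.4, 0.4])  # pooled embedding of text_1
emb_2 = torch.tensor([0.1, 0.5, 0.3])  # pooled embedding of text_2
score = CosineSimilarity(dim=0)(emb_1, emb_2)
print(float(score))  # cosine similarity in [-1, 1]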