Support Cross encoder models (#10400)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
committed by GitHub
parent 49628fe13e
commit 214efc2c3c
@@ -20,7 +20,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_mistral_chat_template,
                                          parse_chat_messages,
                                          resolve_chat_template_content_format)
-from vllm.inputs import PromptType, TextPrompt, TokensPrompt
+from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -817,6 +817,128 @@ class LLM:
        return self.engine_class.validate_outputs(outputs,
                                                   EmbeddingRequestOutput)

    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates similarity scores for all pairs <text, text_pair>.

        The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 -> N case,
        the text_1 sentence is replicated N times to pair with the text_2
        sentences. The input pairs are used to build a list of prompts for
        the cross encoder model. This class automatically batches the
        prompts, considering the memory constraint. For the best
        performance, put all of your texts into a single list and pass it
        to this method.

        Args:
            text_1: Can be a single prompt or a list of prompts, in which
                case it has to have the same length as the text_2 list.
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.
        """
        task = self.llm_engine.model_config.task
        if task != "embedding":
            messages = ["LLM.score() is only supported for embedding models."]

            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "embedding" in supported_tasks:
                messages.append(
                    "Your model supports the 'embedding' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task embedding`.")

            raise ValueError(" ".join(messages))

        if not self.llm_engine.model_config.is_cross_encoder:
            raise ValueError("Your model does not support cross encoding")

        tokenizer = self.llm_engine.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        def ensure_str(prompt: SingletonPrompt):
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for cross encoding")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            assert type(prompt) is str
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        text_1 = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        text_2 = [ensure_str(t) for t in text_2]

        if len(text_1) > 1 and len(text_1) != len(text_2):
            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
        if len(text_1) == 0:
            raise ValueError("At least one text element must be given")
        if len(text_2) == 0:
            raise ValueError("At least one text_pair element must be given")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
        pooling_params = PoolingParams()

        tokenization_kwargs: Dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  EmbeddingRequestOutput)

    def start_profile(self) -> None:
        self.llm_engine.start_profile()
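For illustration only (not part of the commit): a minimal sketch of how the new LLM.score() method could be called, assuming a cross-encoder checkpoint such as the one named in the comment above and an LLM constructed for the embedding task.

# Illustrative sketch only; the model name and constructor arguments are
# assumptions, not part of this commit.
from vllm import LLM

llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="embedding")

# 1 -> N: the single query is paired with every candidate passage.
outputs = llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.", "The Eiffel Tower is in Paris."],
)
for out in outputs:
    # Each EmbeddingRequestOutput carries the score in outputs.embedding.
    print(out.outputs.embedding)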
@@ -45,6 +45,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              ScoreRequest, ScoreResponse,
                                              TokenizeRequest,
                                              TokenizeResponse,
                                              UnloadLoraAdapterRequest)
@@ -53,6 +54,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -280,6 +282,10 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
    return request.app.state.openai_serving_embedding


def score(request: Request) -> Optional[OpenAIServingScores]:
    return request.app.state.openai_serving_scores


def tokenization(request: Request) -> OpenAIServingTokenization:
    return request.app.state.openai_serving_tokenization
@@ -391,6 +397,23 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    assert_never(generator)


@router.post("/v1/score")
async def create_score(request: ScoreRequest, raw_request: Request):
    handler = score(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Score API")

    generator = await handler.create_score(request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    elif isinstance(generator, ScoreResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
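For illustration only (not part of the commit): a minimal sketch of calling the new /v1/score route with the requests library. The host, port, and model name are placeholders, and the server is assumed to have been launched with a cross-encoder model and `--task embedding` so that openai_serving_scores is initialized (see the init_app_state hunk below).

# Illustrative sketch only; assumes a server started along the lines of
#   vllm serve cross-encoder/ms-marco-MiniLM-L-6-v2 --task embedding
# Host, port, and model name are placeholders.
import requests

payload = {
    "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "text_1": "What is the capital of France?",
    "text_2": ["Paris is the capital of France.",
               "The Eiffel Tower is in Paris."],
}
resp = requests.post("http://localhost:8000/v1/score", json=payload)
resp.raise_for_status()
for item in resp.json()["data"]:
    print(item["index"], item["score"])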
@@ -466,8 +489,9 @@ def build_app(args: Namespace) -> FastAPI:

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
-        chat = app.state.openai_serving_chat
-        err = chat.create_error_response(message=str(exc))
+        err = ErrorResponse(message=str(exc),
+                            type="BadRequestError",
+                            code=HTTPStatus.BAD_REQUEST)
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)
@@ -565,6 +589,13 @@ def init_app_state(
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
    ) if model_config.task == "embedding" else None
    state.openai_serving_scores = OpenAIServingScores(
        engine_client,
        model_config,
        base_model_paths,
        request_logger=request_logger
    ) if (model_config.task == "embedding"
          and model_config.is_cross_encoder) else None
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        model_config,
@@ -806,6 +806,27 @@ class EmbeddingChatRequest(OpenAIBaseModel):
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]


class ScoreRequest(OpenAIBaseModel):
    model: str
    text_1: Union[List[str], str]
    text_2: Union[List[str], str]
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-chat-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-chat-embedding-pooling-params

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."))

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)


class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
@@ -876,6 +897,21 @@ class EmbeddingResponse(OpenAIBaseModel):
    usage: UsageInfo


class ScoreResponseData(OpenAIBaseModel):
    index: int
    object: str = "score"
    score: Union[List[float], str]


class ScoreResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[ScoreResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str
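For illustration only (not part of the commit): given the Pydantic models above, a serialized ScoreResponse would look roughly like the dict below; the id, timestamp, model name, and score values are invented.

# Rough shape of a ScoreResponse payload; every concrete value is invented.
example_score_response = {
    "id": "embd-0123456789abcdef",   # default: f"embd-{random_uuid()}"
    "object": "list",
    "created": 1732000000,
    "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "data": [
        {"index": 0, "object": "score", "score": [0.93]},
        {"index": 1, "object": "score", "score": [0.12]},
    ],
    # request_output_to_score_response() below currently reports zero token usage.
    "usage": {"prompt_tokens": 0, "total_tokens": 0},
}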
vllm/entrypoints/openai/serving_score.py (new file, 215 lines)
@@ -0,0 +1,215 @@
import asyncio
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast

from fastapi import Request

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
                                              ScoreResponse, ScoreResponseData,
                                              UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import merge_async_iterators, random_uuid

logger = init_logger(__name__)


def request_output_to_score_response(
        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
        created_time: int, model_name: str) -> ScoreResponse:
    data: List[ScoreResponseData] = []
    score = None
    num_prompt_tokens = 0
    for idx, final_res in enumerate(final_res_batch):
        if final_res is not None:
            score = final_res.outputs.embedding
            score_data = ScoreResponseData(index=idx, score=score)
            data.append(score_data)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        total_tokens=num_prompt_tokens,
    )

    return ScoreResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        data=data,
        usage=usage,
    )


def make_pairs(text_1: Union[List[str], str],
               text_2: Union[List[str], str]) -> List:
    if isinstance(text_1, (str, dict)):
        # Convert a single prompt to a list.
        text_1 = [text_1]
    text_1 = [t for t in text_1]

    if isinstance(text_2, (str, dict)):
        # Convert a single prompt to a list.
        text_2 = [text_2]
    text_2 = [t for t in text_2]
    if len(text_1) > 1 and len(text_1) != len(text_2):
        raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
    if len(text_1) == 0:
        raise ValueError("At least one text element must be given")
    if len(text_2) == 0:
        raise ValueError("At least one text_pair element must be given")

    if len(text_1) == 1:
        text_1 = text_1 * len(text_2)

    return [(t1, t2) for t1, t2 in zip(text_1, text_2)]


class OpenAIServingScores(OpenAIServing):

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        base_model_paths: List[BaseModelPath],
        *,
        request_logger: Optional[RequestLogger],
    ) -> None:
        super().__init__(engine_client=engine_client,
                         model_config=model_config,
                         base_model_paths=base_model_paths,
                         lora_modules=None,
                         prompt_adapters=None,
                         request_logger=request_logger)

    async def create_score(
        self,
        request: ScoreRequest,
        raw_request: Optional[Request] = None,
    ) -> Union[ScoreResponse, ErrorResponse]:
        """
        Score API similar to Sentence Transformers cross encoder

        See https://sbert.net/docs/package_reference/cross_encoder
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        model_name = request.model
        request_id = f"score-{random_uuid()}"
        created_time = int(time.monotonic())
        truncate_prompt_tokens = request.truncate_prompt_tokens

        request_prompts = []
        engine_prompts = []

        try:
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)

            tokenizer = await self.engine_client.get_tokenizer(lora_request)

            if prompt_adapter_request is not None:
                raise NotImplementedError("Prompt adapter is not supported "
                                          "for embedding models")

            if isinstance(tokenizer, MistralTokenizer):
                raise ValueError(
                    "MistralTokenizer not supported for cross-encoding")

            if not self.model_config.is_cross_encoder:
                raise ValueError("Model is not cross encoder.")

        except ValueError as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

        # Schedule the request and get the result generator.
        generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []

        input_pairs = make_pairs(request.text_1, request.text_2)

        for q, t in input_pairs:
            request_prompt = f"{q}{tokenizer.sep_token}{t}"

            tokenization_kwargs: Dict[str, Any] = {}
            if truncate_prompt_tokens is not None:
                tokenization_kwargs["truncation"] = True
                tokenization_kwargs["max_length"] = truncate_prompt_tokens

            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))

            request_prompts.append(request_prompt)
            engine_prompts.append(engine_prompt)

        try:
            pooling_params = request.to_pooling_params()

            for i, engine_prompt in enumerate(engine_prompts):
                request_id_item = f"{request_id}-{i}"

                self._log_inputs(request_id_item,
                                 request_prompts[i],
                                 params=pooling_params,
                                 lora_request=lora_request,
                                 prompt_adapter_request=prompt_adapter_request)

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=lora_request,
                    trace_headers=trace_headers,
                    priority=request.priority,
                )

                generators.append(generator)
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        result_generator = merge_async_iterators(
            *generators,
            is_cancelled=raw_request.is_disconnected if raw_request else None,
        )

        num_prompts = len(engine_prompts)

        # Non-streaming response
        final_res_batch: List[Optional[EmbeddingRequestOutput]]
        final_res_batch = [None] * num_prompts

        try:
            async for i, res in result_generator:
                final_res_batch[i] = res

            assert all(final_res is not None for final_res in final_res_batch)

            final_res_batch_checked = cast(List[EmbeddingRequestOutput],
                                           final_res_batch)

            response = request_output_to_score_response(
                final_res_batch_checked, request_id, created_time, model_name)
        except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

        return response
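For illustration only (not part of the commit): how the make_pairs helper above expands the supported input shapes; all strings are placeholders.

# Illustrative behavior of make_pairs(); all strings are placeholders.
make_pairs("q", "p")                    # [("q", "p")]                 1:1
make_pairs("q", ["p1", "p2"])           # [("q", "p1"), ("q", "p2")]   1:N (query replicated)
make_pairs(["q1", "q2"], ["p1", "p2"])  # [("q1", "p1"), ("q2", "p2")] N:N (zipped)
make_pairs(["q1", "q2"], ["p1"])        # raises ValueError: lengths must be 1:1, 1:N or N:N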