Support Cross encoder models (#10400)

Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
Maximilien de Bayser
2024-11-24 23:56:20 -03:00
committed by GitHub
parent 49628fe13e
commit 214efc2c3c
28 changed files with 1370 additions and 62 deletions

View File

@@ -20,7 +20,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_mistral_chat_template,
parse_chat_messages,
resolve_chat_template_content_format)
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -817,6 +817,128 @@ class LLM:
return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput)
def score(
self,
text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
/,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
"""Generates similarity scores for all pairs <text,text_pair>.
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 -> N case
the text_1 sentence will be replicated N times to pair with the text_2
sentences. The input pairs are used to build a list of prompts for the
cross encoder model. This class automatically batches the prompts,
considering memory constraints. For best performance, put all
of your texts into a single list and pass it to this method.
Args:
text_1: Can be a single prompt or a list of prompts; if a list, it
must have the same length as the text_2 list.
text_2: The texts to pair with the query to form the input
to the LLM. See :class:`~vllm.inputs.PromptType` for
more details about the format of each prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of ``EmbeddingRequestOutput`` objects containing the
generated scores in the same order as the input prompts.
"""
task = self.llm_engine.model_config.task
if task != "embedding":
messages = ["LLM.score() is only supported for embedding models."]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "embedding" in supported_tasks:
messages.append(
"Your model supports the 'embedding' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task embedding`.")
raise ValueError(" ".join(messages))
if not self.llm_engine.model_config.is_cross_encoder:
raise ValueError("Your model does not support the cross encoding")
tokenizer = self.llm_engine.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")
# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs
def ensure_str(prompt: SingletonPrompt):
if isinstance(prompt, dict):
if "multi_modal_data" in prompt:
raise ValueError("Multi-modal prompt is not "
"supported for cross encoding")
elif "prompt_token_ids" in prompt:
prompt = tokenizer.decode(
cast(TokensPrompt, prompt)["prompt_token_ids"])
elif "prompt" in prompt:
prompt = cast(TextPrompt, prompt)["prompt"]
assert type(prompt) is str
return prompt
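# For illustration only (hypothetical token ids): a TokensPrompt such as
#   {"prompt_token_ids": [101, 7592, 102]}
# is decoded back into plain text here, since tokenizers of this kind only
# accept strings for the `text`/`text_pair` kwargs.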
if isinstance(text_1, (str, dict)):
# Convert a single prompt to a list.
text_1 = [text_1]
text_1 = [ensure_str(t) for t in text_1]
if isinstance(text_2, (str, dict)):
# Convert a single prompt to a list.
text_2 = [text_2]
text_2 = [ensure_str(t) for t in text_2]
if len(text_1) > 1 and len(text_1) != len(text_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len(text_1) == 0:
raise ValueError("At least one text element must be given")
if len(text_2) == 0:
raise ValueError("At least one text_pair element must be given")
if len(text_1) == 1:
text_1 = text_1 * len(text_2)
input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
pooling_params = PoolingParams()
tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
parsed_prompts = []
for q, t in input_pairs:
prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
parsed_prompts.append(engine_prompt)
self._validate_and_add_requests(
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput)
def start_profile(self) -> None:
self.llm_engine.start_profile()
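For context, a minimal usage sketch of the LLM.score() method added above. The model name, the example texts, and the task="embedding" argument are illustrative assumptions, not part of this diff.

from vllm import LLM

llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="embedding")

# 1 -> N: the single query is replicated to pair with every passage.
outputs = llm.score(
    "What is the capital of France?",
    ["Paris is the capital of France.",
     "The Eiffel Tower is in Paris."],
)
for out in outputs:
    # Each EmbeddingRequestOutput carries the cross-encoder score
    # in outputs.embedding.
    print(out.outputs.embedding)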

View File

@@ -45,6 +45,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
LoadLoraAdapterRequest,
ScoreRequest, ScoreResponse,
TokenizeRequest,
TokenizeResponse,
UnloadLoraAdapterRequest)
@@ -53,6 +54,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -280,6 +282,10 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
return request.app.state.openai_serving_embedding
def score(request: Request) -> Optional[OpenAIServingScores]:
return request.app.state.openai_serving_scores
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
@@ -391,6 +397,23 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
assert_never(generator)
@router.post("/v1/score")
async def create_score(request: ScoreRequest, raw_request: Request):
handler = score(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Score API")
generator = await handler.create_score(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, ScoreResponse):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning(
"Torch Profiler is enabled in the API server. This should ONLY be "
@@ -466,8 +489,9 @@ def build_app(args: Namespace) -> FastAPI:
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
chat = app.state.openai_serving_chat
err = chat.create_error_response(message=str(exc))
err = ErrorResponse(message=str(exc),
type="BadRequestError",
code=HTTPStatus.BAD_REQUEST)
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
@@ -565,6 +589,13 @@ def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embedding" else None
state.openai_serving_scores = OpenAIServingScores(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger
) if (model_config.task == "embedding" \
and model_config.is_cross_encoder) else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,

View File

@@ -806,6 +806,27 @@ class EmbeddingChatRequest(OpenAIBaseModel):
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
class ScoreRequest(OpenAIBaseModel):
model: str
text_1: Union[List[str], str]
text_2: Union[List[str], str]
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: begin-chat-embedding-pooling-params
additional_data: Optional[Any] = None
# doc: end-chat-embedding-pooling-params
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
class CompletionLogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
@@ -876,6 +897,21 @@ class EmbeddingResponse(OpenAIBaseModel):
usage: UsageInfo
class ScoreResponseData(OpenAIBaseModel):
index: int
object: str = "score"
score: Union[List[float], str]
class ScoreResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
data: List[ScoreResponseData]
usage: UsageInfo
class FunctionCall(OpenAIBaseModel):
name: str
arguments: str
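For reference, a sketch of the response shape implied by ScoreResponseData and ScoreResponse above; the id, timestamp, model name and score values are made up.

example_score_response = {
    "id": "score-<request uuid>",
    "object": "list",
    "created": 1732500000,
    "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "data": [
        {"index": 0, "object": "score", "score": [0.93]},
        {"index": 1, "object": "score", "score": [0.12]},
    ],
    "usage": {"prompt_tokens": 0, "total_tokens": 0},
}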

View File

@@ -0,0 +1,215 @@
import asyncio
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
ScoreResponse, ScoreResponseData,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
def request_output_to_score_response(
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
created_time: int, model_name: str) -> ScoreResponse:
data: List[ScoreResponseData] = []
score = None
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
if final_res is not None:
score = final_res.outputs.embedding
score_data = ScoreResponseData(index=idx, score=score)
data.append(score_data)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
total_tokens=num_prompt_tokens,
)
return ScoreResponse(
id=request_id,
created=created_time,
model=model_name,
data=data,
usage=usage,
)
def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str],
str]) -> List:
if isinstance(text_1, (str, dict)):
# Convert a single prompt to a list.
text_1 = [text_1]
text_1 = [t for t in text_1]
if isinstance(text_2, (str, dict)):
# Convert a single prompt to a list.
text_2 = [text_2]
text_2 = [t for t in text_2]
if len(text_1) > 1 and len(text_1) != len(text_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len(text_1) == 0:
raise ValueError("At least one text element must be given")
if len(text_2) == 0:
raise ValueError("At least one text_pair element must be given")
if len(text_1) == 1:
text_1 = text_1 * len(text_2)
return [(t1, t2) for t1, t2 in zip(text_1, text_2)]
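# A hedged illustration of make_pairs() with made-up values: a single
# text_1 entry is replicated to pair with every element of text_2.
#
#   make_pairs("query", ["doc a", "doc b"])
#   -> [("query", "doc a"), ("query", "doc b")]
#
# Any other length mismatch raises the ValueError above.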
class OpenAIServingScores(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
*,
request_logger: Optional[RequestLogger],
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
async def create_score(
self,
request: ScoreRequest,
raw_request: Optional[Request] = None,
) -> Union[ScoreResponse, ErrorResponse]:
"""
Score API similar to Sentence Transformers cross encoder
See https://sbert.net/docs/package_reference/cross_encoder
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
model_name = request.model
request_id = f"score-{random_uuid()}"
created_time = int(time.time())
truncate_prompt_tokens = request.truncate_prompt_tokens
request_prompts = []
engine_prompts = []
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for embedding models")
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")
if not self.model_config.is_cross_encoder:
raise ValueError("Model is not cross encoder.")
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
input_pairs = make_pairs(request.text_1, request.text_2)
for q, t in input_pairs:
request_prompt = f"{q}{tokenizer.sep_token}{t}"
tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
request_prompts.append(request_prompt)
engine_prompts.append(engine_prompt)
try:
pooling_params = request.to_pooling_params()
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
generator = self.engine_client.encode(
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
result_generator = merge_async_iterators(
*generators,
is_cancelled=raw_request.is_disconnected if raw_request else None,
)
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
assert all(final_res is not None for final_res in final_res_batch)
final_res_batch_checked = cast(List[EmbeddingRequestOutput],
final_res_batch)
response = request_output_to_score_response(
final_res_batch_checked, request_id, created_time, model_name)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
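A hedged sketch of the pair tokenization step used in create_score(), assuming a HuggingFace tokenizer for a BERT-style cross encoder (the model name is illustrative): passing text and text_pair produces a single joint sequence plus token_type_ids, which is exactly what TokensPrompt forwards to the engine.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")

enc = tok(
    text="What is the capital of France?",
    text_pair="Paris is the capital of France.",
    truncation=True,   # mirrors the truncate_prompt_tokens handling above
    max_length=512,
)
# input_ids is one [CLS] query [SEP] passage [SEP] sequence and
# token_type_ids marks which segment each token belongs to; both are
# passed on via TokensPrompt(prompt_token_ids=..., token_type_ids=...).
print(enc["input_ids"])
print(enc.get("token_type_ids"))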