Support Cross encoder models (#10400)

Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
This commit is contained in:
Maximilien de Bayser
2024-11-24 23:56:20 -03:00
committed by GitHub
parent 49628fe13e
commit 214efc2c3c
28 changed files with 1370 additions and 62 deletions

View File

@@ -20,7 +20,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_mistral_chat_template,
parse_chat_messages,
resolve_chat_template_content_format)
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -817,6 +817,128 @@ class LLM:
return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput)
def score(
self,
text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
/,
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> List[EmbeddingRequestOutput]:
"""Generates similarity scores for all pairs <text,text_pair>.
The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
the text_1 sentence will be replicated N times to pair with the text_2
sentences. The input pairs are used to build a list of prompts for the
cross encoder model. This class automatically batches the prompts,
considering the memory constraint. For the best performance, put all
of your texts into a single list and pass it to this method.
Args:
text_1: can be a single prompt or a list of prompts, in which
case it has to have the same length as the text_2 list
text_2: The texts to pair with the query to form the input
to the LLM. See :class:`~vllm.inputs.PromptType` for
more details about the format of each prompts.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
prompt_adapter_request: Prompt Adapter request to use for
generation, if any.
Returns:
A list of ``EmbeddingRequestOutput`` objects containing the
generated scores in the same order as the input prompts.
"""
task = self.llm_engine.model_config.task
if task != "embedding":
messages = ["LLM.score() is only supported for embedding models."]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "embedding" in supported_tasks:
messages.append(
"Your model supports the 'embedding' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task embedding`.")
raise ValueError(" ".join(messages))
if not self.llm_engine.model_config.is_cross_encoder:
raise ValueError("Your model does not support the cross encoding")
tokenizer = self.llm_engine.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer):
raise ValueError(
"MistralTokenizer not supported for cross-encoding")
# the tokenizer for models such as
# "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
# lists of tokens to the `text` and `text_pair` kwargs
def ensure_str(prompt: SingletonPrompt):
if isinstance(prompt, dict):
if "multi_modal_data" in prompt:
raise ValueError("Multi-modal prompt is not "
"supported for cross encoding")
elif "prompt_token_ids" in prompt:
prompt = tokenizer.decode(
cast(TokensPrompt, prompt)["prompt_token_ids"])
elif "prompt" in prompt:
prompt = cast(TextPrompt, prompt)["prompt"]
assert type(prompt) is str
return prompt
if isinstance(text_1, (str, dict)):
# Convert a single prompt to a list.
text_1 = [text_1]
text_1 = [ensure_str(t) for t in text_1]
if isinstance(text_2, (str, dict)):
# Convert a single prompt to a list.
text_2 = [text_2]
text_2 = [ensure_str(t) for t in text_2]
if len(text_1) > 1 and len(text_1) != len(text_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len(text_1) == 0:
raise ValueError("At least one text element must be given")
if len(text_2) == 0:
raise ValueError("At least one text_pair element must be given")
if len(text_1) == 1:
text_1 = text_1 * len(text_2)
input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
pooling_params = PoolingParams()
tokenization_kwargs: Dict[str, Any] = {}
if truncate_prompt_tokens is not None:
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
parsed_prompts = []
for q, t in input_pairs:
prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))
parsed_prompts.append(engine_prompt)
self._validate_and_add_requests(
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
EmbeddingRequestOutput)
def start_profile(self) -> None:
self.llm_engine.start_profile()