Support Cross encoder models (#10400)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Co-authored-by: Flavia Beo <flavia.beo@ibm.com>
parent 49628fe13e
commit 214efc2c3c
@@ -20,7 +20,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_mistral_chat_template,
                                          parse_chat_messages,
                                          resolve_chat_template_content_format)
-from vllm.inputs import PromptType, TextPrompt, TokensPrompt
+from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -817,6 +817,128 @@ class LLM:
         return self.engine_class.validate_outputs(outputs,
                                                   EmbeddingRequestOutput)
 
+    def score(
+        self,
+        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
+        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
+        /,
+        truncate_prompt_tokens: Optional[int] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+    ) -> List[EmbeddingRequestOutput]:
+        """Generates similarity scores for all pairs ``<text, text_pair>``.
+
+        The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 -> N case
+        the text_1 sentence is replicated N times to pair with the text_2
+        sentences. The input pairs are used to build a list of prompts for
+        the cross-encoder model. This class automatically batches the
+        prompts, considering the memory constraint. For the best
+        performance, put all of your texts into a single list and pass it
+        to this method.
+
+        Args:
+            text_1: Can be a single prompt or a list of prompts, in which
+                case it has to have the same length as the text_2 list.
+            text_2: The texts to pair with the query to form the input
+                to the LLM. See :class:`~vllm.inputs.PromptType` for
+                more details about the format of each prompt.
+            truncate_prompt_tokens: If set, truncate each tokenized pair
+                to at most this many tokens.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+            lora_request: LoRA request to use for generation, if any.
+            prompt_adapter_request: Prompt Adapter request to use for
+                generation, if any.
+
+        Returns:
+            A list of ``EmbeddingRequestOutput`` objects containing the
+            generated scores in the same order as the input prompts.
+        """
+        task = self.llm_engine.model_config.task
+        if task != "embedding":
+            messages = ["LLM.score() is only supported for embedding models."]
+
+            supported_tasks = self.llm_engine.model_config.supported_tasks
+            if "embedding" in supported_tasks:
+                messages.append(
+                    "Your model supports the 'embedding' task, but is "
+                    f"currently initialized for the '{task}' task. Please "
+                    "initialize the model using `--task embedding`.")
+
+            raise ValueError(" ".join(messages))
+
+        if not self.llm_engine.model_config.is_cross_encoder:
+            raise ValueError("Your model does not support cross encoding")
+
+        tokenizer = self.llm_engine.get_tokenizer()
+
+        if isinstance(tokenizer, MistralTokenizer):
+            raise ValueError(
+                "MistralTokenizer not supported for cross-encoding")
+
+        # The tokenizer for models such as
+        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
+        # lists of tokens to the `text` and `text_pair` kwargs, so decode
+        # token-ID prompts back into plain strings first.
+        def ensure_str(prompt: SingletonPrompt):
+            if isinstance(prompt, dict):
+                if "multi_modal_data" in prompt:
+                    raise ValueError("Multi-modal prompt is not "
+                                     "supported for cross encoding")
+                elif "prompt_token_ids" in prompt:
+                    prompt = tokenizer.decode(
+                        cast(TokensPrompt, prompt)["prompt_token_ids"])
+                elif "prompt" in prompt:
+                    prompt = cast(TextPrompt, prompt)["prompt"]
+            assert type(prompt) is str
+            return prompt
+
+        if isinstance(text_1, (str, dict)):
+            # Convert a single prompt to a list.
+            text_1 = [text_1]
+        text_1 = [ensure_str(t) for t in text_1]
+
+        if isinstance(text_2, (str, dict)):
+            # Convert a single prompt to a list.
+            text_2 = [text_2]
+        text_2 = [ensure_str(t) for t in text_2]
+
+        if len(text_1) > 1 and len(text_1) != len(text_2):
+            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
+        if len(text_1) == 0:
+            raise ValueError("At least one text element must be given")
+        if len(text_2) == 0:
+            raise ValueError("At least one text_pair element must be given")
+
+        if len(text_1) == 1:
+            # 1 -> N: replicate the single text_1 prompt across all of text_2.
+            text_1 = text_1 * len(text_2)
+
+        input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)]
+        pooling_params = PoolingParams()
+
+        tokenization_kwargs: Dict[str, Any] = {}
+        if truncate_prompt_tokens is not None:
+            tokenization_kwargs["truncation"] = True
+            tokenization_kwargs["max_length"] = truncate_prompt_tokens
+
+        parsed_prompts = []
+
+        for q, t in input_pairs:
+            # Encode each pair jointly so the cross encoder sees both
+            # segments in a single sequence, with matching token_type_ids.
+            prompt_inputs = tokenizer(text=q,
+                                      text_pair=t,
+                                      **tokenization_kwargs)
+            engine_prompt = TokensPrompt(
+                prompt_token_ids=prompt_inputs["input_ids"],
+                token_type_ids=prompt_inputs.get("token_type_ids"))
+            parsed_prompts.append(engine_prompt)
+
+        self._validate_and_add_requests(
+            prompts=parsed_prompts,
+            params=pooling_params,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request,
+        )
+
+        outputs = self._run_engine(use_tqdm=use_tqdm)
+        return self.engine_class.validate_outputs(outputs,
+                                                  EmbeddingRequestOutput)
+
     def start_profile(self) -> None:
         self.llm_engine.start_profile()
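A minimal usage sketch of the new API, for context. It assumes a BERT-style cross-encoder checkpoint such as cross-encoder/ms-marco-MiniLM-L-6-v2 (the model named in the comments above) and that the Python constructor accepts the same task selector as the `--task embedding` CLI flag mentioned in the error message; treat it as an illustration, not part of the commit.

```python
from vllm import LLM

# Assumed setup: a cross-encoder checkpoint initialized for the
# "embedding" task, as required by the validation logic in score().
llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="embedding")

query = "What is the capital of France?"
passages = [
    "Paris is the capital and largest city of France.",
    "The Eiffel Tower was completed in 1889.",
]

# 1 -> N: the single query is replicated to pair with each passage.
outputs = llm.score(query, passages)
for passage, output in zip(passages, outputs):
    # Each EmbeddingRequestOutput holds the score for one <query, passage>
    # pair, returned in the same order as the inputs.
    print(passage, output)
```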
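The reason score() decodes token-ID prompts back to strings and forwards token_type_ids is visible in how Hugging Face tokenizers encode sentence pairs. A standalone sketch, assuming a BERT-style tokenizer (transformers is already a vLLM dependency):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "cross-encoder/ms-marco-MiniLM-L-6-v2")

# Joint pair encoding produces "[CLS] text [SEP] text_pair [SEP]";
# token_type_ids mark which segment each token belongs to (0 or 1),
# which is why score() forwards them to the engine via TokensPrompt.
enc = tokenizer(text="What is the capital of France?",
                text_pair="Paris is the capital of France.",
                truncation=True,
                max_length=32)  # mirrors the truncate_prompt_tokens mapping
print(tokenizer.convert_ids_to_tokens(enc["input_ids"]))
print(enc["token_type_ids"])
```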