[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (#12518)
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
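Context for the RFC: the change set prepares vLLM's scoring and serving paths to work with tokenizers that are not Hugging Face or Mistral classes, so that a user-provided tokenizer can be registered by name. The snippet below is only a sketch of the registry pattern the RFC describes; the class and method names are illustrative assumptions, not the API introduced by this commit.

```python
# Generic sketch of a tokenizer-registry pattern; the class and method
# names here are illustrative assumptions, not the actual vLLM API.
from typing import Callable, Dict


class TokenizerRegistry:
    """Maps a tokenizer name to the class that implements it."""

    _registry: Dict[str, Callable[..., object]] = {}

    @classmethod
    def register(cls, name: str):
        def wrap(tokenizer_cls):
            cls._registry[name] = tokenizer_cls
            return tokenizer_cls
        return wrap

    @classmethod
    def get(cls, name: str, **kwargs):
        return cls._registry[name](**kwargs)


@TokenizerRegistry.register("my-tokenizer")
class MyTokenizer:
    """Hypothetical custom tokenizer following the HF call convention."""

    def __call__(self, text, text_pair=None, **kwargs):
        # Return something shaped like a BatchEncoding so downstream code
        # can read input_ids as usual.
        return {"input_ids": [0], "token_type_ids": [0]}


tok = TokenizerRegistry.get("my-tokenizer")
print(tok("a query", text_pair="a document"))
```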
@@ -1051,9 +1051,9 @@ class LLM:
     def _cross_encoding_score(
         self,
-        tokenizer: Union[AnyTokenizer],
-        text_1: List[Union[str, TextPrompt, TokensPrompt]],
-        text_2: List[Union[str, TextPrompt, TokensPrompt]],
+        tokenizer: AnyTokenizer,
+        text_1: List[str],
+        text_2: List[str],
         truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
@@ -1176,29 +1176,36 @@ class LLM:
         if isinstance(text_1, (str, dict)):
             # Convert a single prompt to a list.
             text_1 = [text_1]
-        text_1 = [ensure_str(t) for t in text_1]
+        input_text_1: List[str] = [ensure_str(t) for t in text_1]

         if isinstance(text_2, (str, dict)):
             # Convert a single prompt to a list.
             text_2 = [text_2]
-        text_2 = [ensure_str(t) for t in text_2]
+        input_text_2: List[str] = [ensure_str(t) for t in text_2]

-        if len(text_1) > 1 and len(text_1) != len(text_2):
+        if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
             raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
-        if len(text_1) == 0:
+        if len(input_text_1) == 0:
             raise ValueError("At least one text element must be given")
-        if len(text_2) == 0:
+        if len(input_text_2) == 0:
             raise ValueError("At least one text_pair element must be given")

         if self.llm_engine.model_config.is_cross_encoder:
-            return self._cross_encoding_score(tokenizer, text_1, text_2,
+            return self._cross_encoding_score(tokenizer, input_text_1,
+                                              input_text_2,
                                               truncate_prompt_tokens, use_tqdm,
                                               lora_request,
                                               prompt_adapter_request)
-        else:
-            return self._embedding_score(tokenizer, text_1, text_2,
-                                         truncate_prompt_tokens, use_tqdm,
-                                         lora_request, prompt_adapter_request)
+
+        return self._embedding_score(
+            tokenizer,
+            input_text_1,  # type: ignore[arg-type]
+            input_text_2,  # type: ignore[arg-type]
+            truncate_prompt_tokens,
+            use_tqdm,
+            lora_request,
+            prompt_adapter_request)

     def start_profile(self) -> None:
         self.llm_engine.start_profile()
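The renamed input_text_1 / input_text_2 locals keep the original text_1 / text_2 parameters untouched while enforcing the pairing rule stated in the error message: one query against N passages, or two lists of equal length. A rough usage sketch of LLM.score follows; the model name and task flag are placeholders taken from the usual vLLM scoring example, not mandated by this diff.

```python
from vllm import LLM

# Placeholder cross-encoder model; swap in whatever scoring model you serve.
llm = LLM(model="cross-encoder/ms-marco-MiniLM-L-6-v2", task="score")

# 1:N — one query scored against several candidate passages.
outputs = llm.score("What is the capital of France?",
                    ["Paris is the capital of France.",
                     "The Eiffel Tower is in Paris."])

# N:N — element-wise pairs; mismatched lengths raise the ValueError above.
outputs = llm.score(["query a", "query b"], ["passage a", "passage b"])
```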
@@ -400,8 +400,7 @@ class OpenAIServing:
         _chat_template_kwargs.update(chat_template_kwargs or {})

         request_prompt: Union[str, List[int]]
-        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
-        if is_mistral_tokenizer:
+        if isinstance(tokenizer, MistralTokenizer):
             request_prompt = apply_mistral_chat_template(
                 tokenizer,
                 messages=messages,
@@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing):
         tokenize_async = make_async(tokenizer.__call__,
                                     executor=self._tokenizer_executor)
-        prompt_inputs = await tokenize_async(text=q,
+        prompt_inputs = await tokenize_async(q,
                                              text_pair=t,
                                              **tokenization_kwargs)
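The text= keyword was presumably dropped because only the Hugging Face tokenizers guarantee that the first positional parameter of __call__ is named text; a tokenizer plugged in through the new registry may name it differently. A minimal sketch of that situation (the class below is hypothetical):

```python
# Hypothetical custom tokenizer whose first __call__ parameter is not
# named "text"; passing the query positionally works either way.
class CustomTokenizer:
    def __call__(self, prompt, text_pair=None, **kwargs):
        # Return a dict shaped like HF's BatchEncoding.
        return {"input_ids": [0], "token_type_ids": [0]}


tok = CustomTokenizer()
tok("query", text_pair="document")        # fine for HF and custom tokenizers
# tok(text="query", text_pair="document") # TypeError here: no "text" kwarg
```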