[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (#12518)

Signed-off-by: Keyun Tong <tongkeyun@gmail.com>

Author: Keyun Tong
Date: 2025-02-11 20:25:58 -08:00
Committed by: GitHub
Parent: 72c2b68dc9
Commit: 3ee696a63d

11 changed files with 343 additions and 41 deletions
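The registry named in the title follows a common register-and-look-up pattern: a custom tokenizer class is registered under a string key and later resolved by name when the engine starts. The sketch below illustrates that pattern only; the class and method names are hypothetical, not necessarily the API this commit adds.

    # Illustrative registry pattern; names here are hypothetical, not
    # the commit's actual API.
    from typing import Callable, Dict, Type


    class TokenizerRegistry:
        _registry: Dict[str, Type] = {}

        @classmethod
        def register(cls, name: str) -> Callable[[Type], Type]:
            def wrap(tokenizer_cls: Type) -> Type:
                cls._registry[name] = tokenizer_cls
                return tokenizer_cls
            return wrap

        @classmethod
        def get_tokenizer(cls, name: str, *args, **kwargs):
            if name not in cls._registry:
                raise ValueError(f"Unknown tokenizer: {name!r}")
            return cls._registry[name](*args, **kwargs)


    @TokenizerRegistry.register("my-tokenizer")
    class MyTokenizer:
        def __call__(self, text: str) -> Dict[str, list]:
            # Toy tokenization: one id per character.
            return {"input_ids": [ord(c) for c in text]}


    tok = TokenizerRegistry.get_tokenizer("my-tokenizer")
    print(tok("hi"))  # {'input_ids': [104, 105]}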


@@ -1051,9 +1051,9 @@ class LLM:
     def _cross_encoding_score(
         self,
-        tokenizer: Union[AnyTokenizer],
-        text_1: List[Union[str, TextPrompt, TokensPrompt]],
-        text_2: List[Union[str, TextPrompt, TokensPrompt]],
+        tokenizer: AnyTokenizer,
+        text_1: List[str],
+        text_2: List[str],
         truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
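The first change is a pure annotation cleanup: Union[AnyTokenizer] is a one-member Union, which the typing module collapses to the member type anyway, and text_1/text_2 narrow to List[str] because the caller (next hunk) now normalizes every prompt with ensure_str before this method runs. A quick check of the Union collapse:

    from typing import Union

    # A one-member Union is just the member type; Union[AnyTokenizer]
    # was a verbose spelling of AnyTokenizer.
    assert Union[int] is int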
@@ -1176,29 +1176,36 @@ class LLM:
         if isinstance(text_1, (str, dict)):
             # Convert a single prompt to a list.
             text_1 = [text_1]
-        text_1 = [ensure_str(t) for t in text_1]
+        input_text_1: List[str] = [ensure_str(t) for t in text_1]

         if isinstance(text_2, (str, dict)):
             # Convert a single prompt to a list.
             text_2 = [text_2]
-        text_2 = [ensure_str(t) for t in text_2]
+        input_text_2: List[str] = [ensure_str(t) for t in text_2]

-        if len(text_1) > 1 and len(text_1) != len(text_2):
+        if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
             raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
-        if len(text_1) == 0:
+        if len(input_text_1) == 0:
             raise ValueError("At least one text element must be given")
-        if len(text_2) == 0:
+        if len(input_text_2) == 0:
             raise ValueError("At least one text_pair element must be given")

         if self.llm_engine.model_config.is_cross_encoder:
-            return self._cross_encoding_score(tokenizer, text_1, text_2,
+            return self._cross_encoding_score(tokenizer, input_text_1,
+                                              input_text_2,
                                               truncate_prompt_tokens, use_tqdm,
                                               lora_request,
                                               prompt_adapter_request)
         else:
-            return self._embedding_score(tokenizer, text_1, text_2,
-                                         truncate_prompt_tokens, use_tqdm,
-                                         lora_request, prompt_adapter_request)
+            return self._embedding_score(
+                tokenizer,
+                input_text_1,  # type: ignore[arg-type]
+                input_text_2,  # type: ignore[arg-type]
+                truncate_prompt_tokens,
+                use_tqdm,
+                lora_request,
+                prompt_adapter_request)

     def start_profile(self) -> None:
         self.llm_engine.start_profile()
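The validation above enforces that the two input lists pair up 1:1, 1:N (one query broadcast across N candidates), or N:N. A hypothetical standalone helper, not part of the commit, showing the same rule end to end:

    from typing import List, Tuple


    def pair_texts(text_1: List[str],
                   text_2: List[str]) -> List[Tuple[str, str]]:
        # Hypothetical helper illustrating the 1:1 / 1:N / N:N pairing
        # rule that the validation above enforces.
        if not text_1 or not text_2:
            raise ValueError("Both inputs need at least one element")
        if len(text_1) > 1 and len(text_1) != len(text_2):
            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
        if len(text_1) == 1:
            # 1:N — broadcast the single query against every candidate.
            text_1 = text_1 * len(text_2)
        return list(zip(text_1, text_2))


    # Example: one query scored against three candidates (1:N).
    print(pair_texts(["q"], ["a", "b", "c"]))
    # [('q', 'a'), ('q', 'b'), ('q', 'c')]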


@@ -400,8 +400,7 @@ class OpenAIServing:
         _chat_template_kwargs.update(chat_template_kwargs or {})

         request_prompt: Union[str, List[int]]
-        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
-        if is_mistral_tokenizer:
+        if isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
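Inlining the isinstance call is more than style: most static type checkers (mypy in particular) only narrow a variable's type when the isinstance test appears directly in the condition, so routing it through the intermediate is_mistral_tokenizer boolean discarded the narrowing to MistralTokenizer inside the branch. A minimal illustration with stand-in classes:

    from typing import Union


    class MistralTokenizer:
        def apply_chat(self) -> str:
            return "mistral"


    class HFTokenizer:
        pass


    def route(tokenizer: Union[MistralTokenizer, HFTokenizer]) -> str:
        flag = isinstance(tokenizer, MistralTokenizer)
        if flag:
            # mypy still sees the full Union here; calling
            # tokenizer.apply_chat() would be rejected.
            pass
        if isinstance(tokenizer, MistralTokenizer):
            # Inline check: tokenizer is narrowed to MistralTokenizer.
            return tokenizer.apply_chat()
        return "hf"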


@@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing):
         tokenize_async = make_async(tokenizer.__call__,
                                     executor=self._tokenizer_executor)
-        prompt_inputs = await tokenize_async(text=q,
+        prompt_inputs = await tokenize_async(q,
                                              text_pair=t,
                                              **tokenization_kwargs)
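The last hunk passes the query positionally instead of as text=q: a registry-supplied tokenizer implements __call__ but need not name its first parameter text, so the keyword form could break for custom tokenizers. For context, make_async wraps the blocking tokenizer call so it runs in an executor off the event loop; vLLM ships its own helper, and the version below is only a rough sketch of the idea:

    import asyncio
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial
    from typing import Any, Callable, Optional


    def make_async(fn: Callable[..., Any],
                   executor: Optional[ThreadPoolExecutor] = None):
        # Return an awaitable wrapper that runs `fn` in an executor so
        # the event loop is never blocked by tokenization.
        async def _wrapper(*args: Any, **kwargs: Any) -> Any:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                executor, partial(fn, *args, **kwargs))
        return _wrapper


    async def main() -> None:
        # Stand-in for tokenizer.__call__; str.split is a toy tokenizer.
        tokenize_async = make_async(str.split)
        print(await tokenize_async("hello world"))  # ['hello', 'world']


    asyncio.run(main())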