[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (#12518)

Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
This commit is contained in:
Keyun Tong
2025-02-11 20:25:58 -08:00
committed by GitHub
parent 72c2b68dc9
commit 3ee696a63d
11 changed files with 343 additions and 41 deletions

View File

@@ -10,6 +10,7 @@ import huggingface_hub
from huggingface_hub import HfApi, hf_hub_download
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_base import TokenizerBase
from vllm.utils import is_list_of
if TYPE_CHECKING:
@@ -140,7 +141,7 @@ def make_mistral_chat_completion_request(
tools=tools) # type: ignore[type-var]
class MistralTokenizer:
class MistralTokenizer(TokenizerBase):
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
self.mistral = tokenizer
@@ -251,6 +252,14 @@ class MistralTokenizer:
def eos_token_id(self) -> int:
return self.tokenizer.eos_id
@property
def sep_token(self) -> str:
raise NotImplementedError()
@property
def pad_token(self) -> str:
raise NotImplementedError()
@property
def is_fast(self) -> bool:
return True
@@ -268,25 +277,26 @@ class MistralTokenizer:
def __call__(
self,
prompt: Union[str, List[str], List[int]],
text: Union[str, List[str], List[int]],
text_pair: Optional[str] = None,
add_special_tokens: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
):
input_ids: Union[List[int], List[List[int]]]
# For List[str], original prompt text
if is_list_of(prompt, str):
if is_list_of(text, str):
input_ids_: List[List[int]] = []
for p in prompt:
for p in text:
each_input_ids = self.encode_one(p, truncation, max_length)
input_ids_.append(each_input_ids)
input_ids = input_ids_
# For List[int], apply chat template output, already tokens.
elif is_list_of(prompt, int):
input_ids = prompt
elif is_list_of(text, int):
input_ids = text
# For str, single prompt text
else:
input_ids = self.encode_one(prompt, truncation, max_length)
input_ids = self.encode_one(text, truncation, max_length)
return Encoding(input_ids=input_ids)
def get_vocab(self) -> Dict[str, int]:
@@ -300,22 +310,29 @@ class MistralTokenizer:
def encode_one(
self,
prompt: str,
text: str,
truncation: bool = False,
max_length: Optional[int] = None,
) -> List[int]:
# Mistral Tokenizers should not add special tokens
input_ids = self.encode(prompt)
input_ids = self.encode(text)
if truncation:
input_ids = input_ids[:max_length]
return input_ids
def encode(self, prompt: str) -> List[int]:
def encode(self,
text: str,
add_special_tokens: Optional[bool] = None) -> List[int]:
# `encode` should only be used for prompt completion
# it should never be used for chat_completion.
# For chat completion use `apply_chat_template`
return self.tokenizer.encode(prompt, bos=True, eos=False)
if add_special_tokens is not None:
return self.tokenizer.encode(text,
bos=add_special_tokens,
eos=add_special_tokens)
else:
return self.tokenizer.encode(text, bos=True, eos=False)
def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"],