[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM (#12518)
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
This commit is contained in:
@@ -10,6 +10,7 @@ import huggingface_hub
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer_base import TokenizerBase
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -140,7 +141,7 @@ def make_mistral_chat_completion_request(
|
||||
tools=tools) # type: ignore[type-var]
|
||||
|
||||
|
||||
class MistralTokenizer:
|
||||
class MistralTokenizer(TokenizerBase):
|
||||
|
||||
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
|
||||
self.mistral = tokenizer
|
||||
@@ -251,6 +252,14 @@ class MistralTokenizer:
|
||||
def eos_token_id(self) -> int:
|
||||
return self.tokenizer.eos_id
|
||||
|
||||
@property
|
||||
def sep_token(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def is_fast(self) -> bool:
|
||||
return True
|
||||
@@ -268,25 +277,26 @@ class MistralTokenizer:
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str], List[int]],
|
||||
text: Union[str, List[str], List[int]],
|
||||
text_pair: Optional[str] = None,
|
||||
add_special_tokens: bool = False,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
):
|
||||
input_ids: Union[List[int], List[List[int]]]
|
||||
# For List[str], original prompt text
|
||||
if is_list_of(prompt, str):
|
||||
if is_list_of(text, str):
|
||||
input_ids_: List[List[int]] = []
|
||||
for p in prompt:
|
||||
for p in text:
|
||||
each_input_ids = self.encode_one(p, truncation, max_length)
|
||||
input_ids_.append(each_input_ids)
|
||||
input_ids = input_ids_
|
||||
# For List[int], apply chat template output, already tokens.
|
||||
elif is_list_of(prompt, int):
|
||||
input_ids = prompt
|
||||
elif is_list_of(text, int):
|
||||
input_ids = text
|
||||
# For str, single prompt text
|
||||
else:
|
||||
input_ids = self.encode_one(prompt, truncation, max_length)
|
||||
input_ids = self.encode_one(text, truncation, max_length)
|
||||
return Encoding(input_ids=input_ids)
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
@@ -300,22 +310,29 @@ class MistralTokenizer:
|
||||
|
||||
def encode_one(
|
||||
self,
|
||||
prompt: str,
|
||||
text: str,
|
||||
truncation: bool = False,
|
||||
max_length: Optional[int] = None,
|
||||
) -> List[int]:
|
||||
# Mistral Tokenizers should not add special tokens
|
||||
input_ids = self.encode(prompt)
|
||||
input_ids = self.encode(text)
|
||||
|
||||
if truncation:
|
||||
input_ids = input_ids[:max_length]
|
||||
return input_ids
|
||||
|
||||
def encode(self, prompt: str) -> List[int]:
|
||||
def encode(self,
|
||||
text: str,
|
||||
add_special_tokens: Optional[bool] = None) -> List[int]:
|
||||
# `encode` should only be used for prompt completion
|
||||
# it should never be used for chat_completion.
|
||||
# For chat completion use `apply_chat_template`
|
||||
return self.tokenizer.encode(prompt, bos=True, eos=False)
|
||||
if add_special_tokens is not None:
|
||||
return self.tokenizer.encode(text,
|
||||
bos=add_special_tokens,
|
||||
eos=add_special_tokens)
|
||||
else:
|
||||
return self.tokenizer.encode(text, bos=True, eos=False)
|
||||
|
||||
def apply_chat_template(self,
|
||||
messages: List["ChatCompletionMessageParam"],
|
||||
|
||||
Reference in New Issue
Block a user