Performance fix MistralTokenizer: cache special ids and tokens (#27925)
Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
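A note on the approach (not part of the commit itself): this is a standard precompute-and-cache pattern. The special-token ids and their decoded string forms are built once in __init__ and kept both as sorted lists (for the list-returning properties) and as sets (for O(1) membership tests on the decode hot path). A minimal self-contained sketch of the pattern, with hypothetical names rather than the actual vLLM/mistral_common API:

    # Sketch of the caching pattern; the class name and the `decode`
    # callable are illustrative, not the vLLM/mistral_common API.
    class CachedSpecialTokens:
        def __init__(self, special_ids: list[int], decode) -> None:
            # Computed once here instead of on every property access.
            self._special_token_ids = sorted(special_ids)
            self._special_token_ids_set = set(self._special_token_ids)
            self._special_tokens = [decode(i) for i in self._special_token_ids]
            self._special_tokens_set = set(self._special_tokens)

        @property
        def all_special_ids(self) -> list[int]:
            return self._special_token_ids  # cached list, no recompute

        def _is_special_token_id(self, token_id: int) -> bool:
            # Set lookup is O(1); recomputing or scanning a list is not.
            return token_id in self._special_token_ids_set

    tok = CachedSpecialTokens([2, 0, 1], decode=lambda i: f"<sp_{i}>")
    assert tok._is_special_token_id(1) and not tok._is_special_token_id(42)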
@@ -191,6 +191,12 @@ class MistralTokenizer(TokenizerBase):
         # Sort the dict for convenience
         self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
 
+        # Cache special tokens for faster access.
+        self._special_token_ids = self._get_special_token_ids()
+        self._special_token_ids_set = set(self._special_token_ids)
+        self._special_tokens = self._get_special_tokens(self._special_token_ids)
+        self._special_tokens_set = set(self._special_tokens)
+
         # Vocab sorted by token id.
         self._vocab = self.tokenizer._vocab
         self._max_token_id = self.vocab_size - 1
@@ -210,23 +216,7 @@ class MistralTokenizer(TokenizerBase):
             )
         )
 
-    # the following attributes are set to fit vLLM's design and are used
-    # by the structured output backends.
-    @property
-    def all_special_tokens_extended(self) -> list[str]:
-        return self.all_special_tokens
-
-    @property
-    def all_special_tokens(self) -> list[str]:
-        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
-
-        return [
-            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
-            for i in self.all_special_ids
-        ]
-
-    @property
-    def all_special_ids(self) -> list[int]:
+    def _get_special_token_ids(self) -> list[int]:
         from mistral_common.tokens.tokenizers.sentencepiece import (
             SentencePieceTokenizer,
         )
@@ -244,6 +234,28 @@ class MistralTokenizer(TokenizerBase):
             raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
+        return sorted(special_ids)
 
+    def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
+        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
+
+        return [
+            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
+            for i in all_special_ids
+        ]
+
+    # the following attributes are set to fit vLLM's design and are used
+    # by the structured output backends.
+    @property
+    def all_special_tokens_extended(self) -> list[str]:
+        return self.all_special_tokens
+
+    @property
+    def all_special_tokens(self) -> list[str]:
+        return self._special_tokens
+
+    @property
+    def all_special_ids(self) -> list[int]:
+        return self._special_token_ids
+
     @property
     def bos_token_id(self) -> int:
         return self.tokenizer.bos_id
@@ -277,21 +289,7 @@ class MistralTokenizer(TokenizerBase):
         raise NotImplementedError()
 
     def _is_special_token_id(self, token_id: int) -> bool:
-        from mistral_common.tokens.tokenizers.sentencepiece import (
-            SentencePieceTokenizer,
-        )
-        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
-
-        if self.is_spm:
-            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
-                self.tokenizer
-            )
-            return token_id in self.tokenizer._control_tokens
-        if self.is_tekken:
-            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
-            return token_id < self.tokenizer.num_special_tokens
-        else:
-            raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
+        return token_id in self._special_token_ids_set
 
     def __len__(self) -> int:
        return self.vocab_size
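The two decode-path hunks below make the same swap: `t not in self.all_special_tokens` invoked the old property (re-decoding every special id) and then scanned the resulting list per token, while `t not in self._special_tokens_set` is a single hash lookup. An illustrative micro-benchmark of the membership cost alone (numbers are machine-dependent and not from the PR):

    import timeit

    # Membership in a list is O(n); in a set it is O(1) on average.
    ids_list = list(range(1_000))
    ids_set = set(ids_list)

    t_list = timeit.timeit(lambda: 999 in ids_list, number=100_000)
    t_set = timeit.timeit(lambda: 999 in ids_set, number=100_000)
    print(f"list: {t_list:.3f}s  set: {t_set:.3f}s")  # set is far faster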
@@ -405,7 +403,7 @@ class MistralTokenizer(TokenizerBase):
         tokens = [
             t
             for t in tokens
-            if (t in to_decode_special_tokens or t not in self.all_special_tokens)
+            if (t in to_decode_special_tokens or t not in self._special_tokens_set)
         ]
 
         if any(isinstance(t, bytes) for t in tokens):
@@ -489,7 +487,7 @@ class MistralTokenizer(TokenizerBase):
         # We filtered unwanted special tokens so we can decode the rest.
         tokens = [
             self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
-            if token_id not in self.all_special_ids
+            if token_id not in self._special_token_ids_set
             else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
             for token_id in ids_kept
         ]