Performance fix MistralTokenizer: cache special ids and tokens (#27925)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Julien Denize
2025-11-02 09:48:33 +01:00
committed by GitHub
parent 853a8eb53b
commit 73444b7b56

View File

@@ -191,6 +191,12 @@ class MistralTokenizer(TokenizerBase):
# Sort the dict for convenience
self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
# Cache special tokens for faster access.
self._special_token_ids = self._get_special_token_ids()
self._special_token_ids_set = set(self._special_token_ids)
self._special_tokens = self._get_special_tokens(self._special_token_ids)
self._special_tokens_set = set(self._special_tokens)
# Vocab sorted by token id.
self._vocab = self.tokenizer._vocab
self._max_token_id = self.vocab_size - 1
@@ -210,23 +216,7 @@ class MistralTokenizer(TokenizerBase):
)
)
# The following attributes are set to fit vLLM's design and are used
# by the structured output backends.
@property
def all_special_tokens_extended(self) -> list[str]:
    """Return the same list of strings as ``all_special_tokens``."""
    special_tokens = self.all_special_tokens
    return special_tokens
@property
def all_special_tokens(self) -> list[str]:
    """Decode every id in ``all_special_ids`` into its string form.

    Each id is decoded on its own with ``SpecialTokenPolicy.KEEP``, and the
    list is rebuilt on every access.
    """
    from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

    keep = SpecialTokenPolicy.KEEP
    decoded: list[str] = []
    for special_id in self.all_special_ids:
        decoded.append(
            self.tokenizer.decode([special_id], special_token_policy=keep)
        )
    return decoded
@property
def all_special_ids(self) -> list[int]:
def _get_special_token_ids(self) -> list[int]:
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
@@ -244,6 +234,28 @@ class MistralTokenizer(TokenizerBase):
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
return sorted(special_ids)
def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]:
    """Decode each id in ``all_special_ids`` into its string form.

    Args:
        all_special_ids: Token ids to decode, one at a time.

    Returns:
        The decoded strings, in the same order as ``all_special_ids``.
    """
    from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

    keep_policy = SpecialTokenPolicy.KEEP
    # Each id is decoded in its own single-element call so that every id
    # maps to exactly one output string.
    return [
        self.tokenizer.decode([token_id], special_token_policy=keep_policy)
        for token_id in all_special_ids
    ]
# The following attributes are set to fit vLLM's design and are used
# by the structured output backends.
@property
def all_special_tokens_extended(self) -> list[str]:
    """Return the same list of strings as ``all_special_tokens``."""
    special_tokens = self.all_special_tokens
    return special_tokens
@property
def all_special_tokens(self) -> list[str]:
    """Return the special-token strings stored in ``self._special_tokens``.

    The stored list object itself is returned, not a copy.
    """
    cached_tokens = self._special_tokens
    return cached_tokens
@property
def all_special_ids(self) -> list[int]:
    """Return the special-token ids stored in ``self._special_token_ids``.

    The stored list object itself is returned, not a copy.
    """
    cached_ids = self._special_token_ids
    return cached_ids
@property
def bos_token_id(self) -> int:
    """Return the ``bos_id`` attribute of the wrapped ``self.tokenizer``."""
    wrapped = self.tokenizer
    return wrapped.bos_id
@@ -277,21 +289,7 @@ class MistralTokenizer(TokenizerBase):
raise NotImplementedError()
def _is_special_token_id(self, token_id: int) -> bool:
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
if self.is_spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
return token_id in self.tokenizer._control_tokens
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
return token_id < self.tokenizer.num_special_tokens
else:
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
return token_id in self._special_token_ids_set
def __len__(self) -> int:
return self.vocab_size
@@ -405,7 +403,7 @@ class MistralTokenizer(TokenizerBase):
tokens = [
t
for t in tokens
if (t in to_decode_special_tokens or t not in self.all_special_tokens)
if (t in to_decode_special_tokens or t not in self._special_tokens_set)
]
if any(isinstance(t, bytes) for t in tokens):
@@ -489,7 +487,7 @@ class MistralTokenizer(TokenizerBase):
# We filtered unwanted special tokens so we can decode the rest.
tokens = [
self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
if token_id not in self.all_special_ids
if token_id not in self._special_token_ids_set
else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
for token_id in ids_kept
]