From 73444b7b5623f5bc569277c8c7dc809843312d11 Mon Sep 17 00:00:00 2001 From: Julien Denize <40604584+juliendenize@users.noreply.github.com> Date: Sun, 2 Nov 2025 09:48:33 +0100 Subject: [PATCH] Performance fix MistralTokenizer: cache special ids and tokens (#27925) Signed-off-by: Julien Denize Co-authored-by: Patrick von Platen --- vllm/transformers_utils/tokenizers/mistral.py | 66 +++++++++---------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 6f710bf23..703352322 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -191,6 +191,12 @@ class MistralTokenizer(TokenizerBase): # Sort the dict for convenience self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1])) + # Cache special tokens for faster access. + self._special_token_ids = self._get_special_token_ids() + self._special_token_ids_set = set(self._special_token_ids) + self._special_tokens = self._get_special_tokens(self._special_token_ids) + self._special_tokens_set = set(self._special_tokens) + # Vocab sorted by token id. self._vocab = self.tokenizer._vocab self._max_token_id = self.vocab_size - 1 @@ -210,23 +216,7 @@ class MistralTokenizer(TokenizerBase): ) ) - # the following attributes are set to fit vLLM's design and are used - # by the structured output backends. - @property - def all_special_tokens_extended(self) -> list[str]: - return self.all_special_tokens - - @property - def all_special_tokens(self) -> list[str]: - from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy - - return [ - self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP) - for i in self.all_special_ids - ] - - @property - def all_special_ids(self) -> list[int]: + def _get_special_token_ids(self) -> list[int]: from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer, ) @@ -244,6 +234,28 @@ class MistralTokenizer(TokenizerBase): raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}") return sorted(special_ids) + def _get_special_tokens(self, all_special_ids: list[int]) -> list[str]: + from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy + + return [ + self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP) + for i in all_special_ids + ] + + # the following attributes are set to fit vLLM's design and are used + # by the structured output backends. + @property + def all_special_tokens_extended(self) -> list[str]: + return self.all_special_tokens + + @property + def all_special_tokens(self) -> list[str]: + return self._special_tokens + + @property + def all_special_ids(self) -> list[int]: + return self._special_token_ids + @property def bos_token_id(self) -> int: return self.tokenizer.bos_id @@ -277,21 +289,7 @@ class MistralTokenizer(TokenizerBase): raise NotImplementedError() def _is_special_token_id(self, token_id: int) -> bool: - from mistral_common.tokens.tokenizers.sentencepiece import ( - SentencePieceTokenizer, - ) - from mistral_common.tokens.tokenizers.tekken import Tekkenizer - - if self.is_spm: - assert isinstance(self.tokenizer, SentencePieceTokenizer), type( - self.tokenizer - ) - return token_id in self.tokenizer._control_tokens - if self.is_tekken: - assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer) - return token_id < self.tokenizer.num_special_tokens - else: - raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}") + return token_id in self._special_token_ids_set def __len__(self) -> int: return self.vocab_size @@ -405,7 +403,7 @@ class MistralTokenizer(TokenizerBase): tokens = [ t for t in tokens - if (t in to_decode_special_tokens or t not in self.all_special_tokens) + if (t in to_decode_special_tokens or t not in self._special_tokens_set) ] if any(isinstance(t, bytes) for t in tokens): @@ -489,7 +487,7 @@ class MistralTokenizer(TokenizerBase): # We filtered unwanted special tokens so we can decode the rest. tokens = [ self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP) - if token_id not in self.all_special_ids + if token_id not in self._special_token_ids_set else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP) for token_id in ids_kept ]