[Misc] Fix warnings for mistral model (#23552)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Jiangyun Zhu
2025-08-29 14:58:48 +08:00
committed by GitHub
parent 2d0afcc9dc
commit 885ca6d31d
3 changed files with 31 additions and 23 deletions

View File

@@ -204,18 +204,16 @@ class MistralTokenizer(TokenizerBase):
self.version: int = int(_mistral_version_str.split("v")[-1])
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
from mistral_common.tokens.tokenizers.tekken import (
SpecialTokenPolicy, Tekkenizer)
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
self.is_tekken = isinstance(tokenizer_, Tekkenizer)
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer)
self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
if self.is_tekken:
# Make sure special tokens will not raise
tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
elif self.is_spm:
pass
else:
self._special_token_policy = (SpecialTokenPolicy.IGNORE
if self.is_tekken else None)
if not (self.is_tekken or self.is_spm):
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
self._vocab = tokenizer_.vocab()
@@ -430,7 +428,8 @@ class MistralTokenizer(TokenizerBase):
return self.tokenizer.unk_id
ids = [_token_to_id(t) for t in tokens]
decoded = self.tokenizer.decode(ids)
decoded = self.tokenizer.decode(ids,
self._special_token_policy)
else:
decoded = "".join(tokens)
else:
@@ -444,7 +443,8 @@ class MistralTokenizer(TokenizerBase):
if token in special_tokens:
if regular_tokens:
decoded_list.append(
self.tokenizer.decode(regular_tokens))
self.tokenizer.decode(regular_tokens,
self._special_token_policy))
regular_tokens = []
decoded_list.append(token)
else:
@@ -452,7 +452,8 @@ class MistralTokenizer(TokenizerBase):
if regular_tokens:
decoded_list.append(
self.tokenizer.decode(regular_tokens)) # type: ignore
self.tokenizer.decode(regular_tokens,
self._special_token_policy))
decoded = ''.join(decoded_list)
@@ -470,7 +471,7 @@ class MistralTokenizer(TokenizerBase):
if isinstance(ids, int):
ids = [ids]
return self.tokenizer.decode(ids)
return self.tokenizer.decode(ids, self._special_token_policy)
def convert_ids_to_tokens(
self,
@@ -511,6 +512,9 @@ class MistralTokenizer(TokenizerBase):
# See: https://github.com/vllm-project/vllm/pull/8640
# https://github.com/vllm-project/vllm/pull/9625
# if underlying tokenizer is sentencepiece, we just add "<22>"
tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
tokens = [
self.tokenizer.id_to_byte_piece(id, self._special_token_policy)
for id in ids
]
return tokens