[Misc] Fix warnings for mistral model (#23552)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com> Signed-off-by: Jiangyun Zhu <riverclouds.zhu@qq.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -204,18 +204,16 @@ class MistralTokenizer(TokenizerBase):
|
||||
self.version: int = int(_mistral_version_str.split("v")[-1])
|
||||
|
||||
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
|
||||
from mistral_common.tokens.tokenizers.tekken import (
|
||||
SpecialTokenPolicy, Tekkenizer)
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
|
||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
||||
|
||||
self.is_tekken = isinstance(tokenizer_, Tekkenizer)
|
||||
from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
SentencePieceTokenizer)
|
||||
self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
|
||||
if self.is_tekken:
|
||||
# Make sure special tokens will not raise
|
||||
tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
|
||||
elif self.is_spm:
|
||||
pass
|
||||
else:
|
||||
self._special_token_policy = (SpecialTokenPolicy.IGNORE
|
||||
if self.is_tekken else None)
|
||||
if not (self.is_tekken or self.is_spm):
|
||||
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
|
||||
|
||||
self._vocab = tokenizer_.vocab()
|
||||
@@ -430,7 +428,8 @@ class MistralTokenizer(TokenizerBase):
|
||||
return self.tokenizer.unk_id
|
||||
|
||||
ids = [_token_to_id(t) for t in tokens]
|
||||
decoded = self.tokenizer.decode(ids)
|
||||
decoded = self.tokenizer.decode(ids,
|
||||
self._special_token_policy)
|
||||
else:
|
||||
decoded = "".join(tokens)
|
||||
else:
|
||||
@@ -444,7 +443,8 @@ class MistralTokenizer(TokenizerBase):
|
||||
if token in special_tokens:
|
||||
if regular_tokens:
|
||||
decoded_list.append(
|
||||
self.tokenizer.decode(regular_tokens))
|
||||
self.tokenizer.decode(regular_tokens,
|
||||
self._special_token_policy))
|
||||
regular_tokens = []
|
||||
decoded_list.append(token)
|
||||
else:
|
||||
@@ -452,7 +452,8 @@ class MistralTokenizer(TokenizerBase):
|
||||
|
||||
if regular_tokens:
|
||||
decoded_list.append(
|
||||
self.tokenizer.decode(regular_tokens)) # type: ignore
|
||||
self.tokenizer.decode(regular_tokens,
|
||||
self._special_token_policy))
|
||||
|
||||
decoded = ''.join(decoded_list)
|
||||
|
||||
@@ -470,7 +471,7 @@ class MistralTokenizer(TokenizerBase):
|
||||
|
||||
if isinstance(ids, int):
|
||||
ids = [ids]
|
||||
return self.tokenizer.decode(ids)
|
||||
return self.tokenizer.decode(ids, self._special_token_policy)
|
||||
|
||||
def convert_ids_to_tokens(
|
||||
self,
|
||||
@@ -511,6 +512,9 @@ class MistralTokenizer(TokenizerBase):
|
||||
# See: https://github.com/vllm-project/vllm/pull/8640
|
||||
# https://github.com/vllm-project/vllm/pull/9625
|
||||
# if underlying tokenizeir is sentencepiece, we just add "<22>"
|
||||
tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
|
||||
tokens = [
|
||||
self.tokenizer.id_to_byte_piece(id, self._special_token_policy)
|
||||
for id in ids
|
||||
]
|
||||
|
||||
return tokens
|
||||
|
||||
Reference in New Issue
Block a user