[Bugfix] Fix edge cases for MistralTokenizer (#9625)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com> Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com> Co-authored-by: Prashant Gupta <prashantgupta@us.ibm.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -16,9 +16,13 @@ from mistral_common.tokens.tokenizers.sentencepiece import (
|
||||
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
|
||||
Tekkenizer)
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Encoding:
|
||||
@@ -72,20 +76,21 @@ class MistralTokenizer:
|
||||
# Make sure special tokens will not raise
|
||||
tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
|
||||
|
||||
self._vocab = {
|
||||
token: idx
|
||||
for idx, token in enumerate(tokenizer_.vocab())
|
||||
}
|
||||
elif isinstance(tokenizer_, SentencePieceTokenizer):
|
||||
self._vocab = {
|
||||
token: idx
|
||||
for idx, token in enumerate(tokenizer_.vocab())
|
||||
}
|
||||
pass
|
||||
else:
|
||||
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
|
||||
|
||||
self._vocab = tokenizer_.vocab()
|
||||
# Convert to a Dict[str, int] to match protocol, but this is a lossy
|
||||
# conversion. There may be multiple token ids that decode to the same
|
||||
# string due to partial UTF-8 byte sequences being converted to <20>
|
||||
self._vocab_dict = {
|
||||
token: idx
|
||||
for idx, token in enumerate(self._vocab)
|
||||
}
|
||||
self.tokenizer = tokenizer_
|
||||
self._max_token_id = max(self._vocab.values())
|
||||
self._max_token_id = self.vocab_size - 1
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls,
|
||||
@@ -182,7 +187,9 @@ class MistralTokenizer:
|
||||
return Encoding(input_ids=input_ids)
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
return self._vocab
|
||||
# NB: the dictionary form of the vocabulary collapses token ids that map
|
||||
# to the same string but have different bytes
|
||||
return self._vocab_dict
|
||||
|
||||
def get_added_vocab(self) -> Dict[str, int]:
|
||||
# Mistral tokenizers have no added vocabulary
|
||||
@@ -220,14 +227,20 @@ class MistralTokenizer:
|
||||
if any(isinstance(t, bytes) for t in tokens):
|
||||
# we need to encode and decode all tokens again
|
||||
shift = self.tokenizer.num_special_tokens
|
||||
byte_tokens = [
|
||||
t.encode("utf-8") if not isinstance(t, bytes) else t
|
||||
for t in tokens
|
||||
]
|
||||
ids = [
|
||||
self.tokenizer._tekken_token2id_nospecial[t] + shift
|
||||
for t in byte_tokens
|
||||
]
|
||||
|
||||
def _token_to_id(t: str):
|
||||
t_bytes = t.encode("utf-8") \
|
||||
if not isinstance(t, bytes) else t
|
||||
try:
|
||||
return shift + \
|
||||
self.tokenizer._tekken_token2id_nospecial[t_bytes]
|
||||
except KeyError:
|
||||
logger.warning(
|
||||
"Failed to convert token %s to id,"
|
||||
" replacing with <unk>", t_bytes)
|
||||
return self.tokenizer.unk_id
|
||||
|
||||
ids = [_token_to_id(t) for t in tokens]
|
||||
decoded = self.tokenizer.decode(ids)
|
||||
else:
|
||||
decoded = "".join(tokens)
|
||||
@@ -236,7 +249,13 @@ class MistralTokenizer:
|
||||
|
||||
return decoded
|
||||
|
||||
def decode(self, ids: Union[List[int], int]) -> str:
|
||||
def decode(self,
|
||||
ids: Union[List[int], int],
|
||||
skip_special_tokens: bool = True) -> str:
|
||||
assert (
|
||||
skip_special_tokens
|
||||
), "Skipping special tokens is not supported for Mistral tokenizers."
|
||||
|
||||
if isinstance(ids, int):
|
||||
ids = [ids]
|
||||
return self.tokenizer.decode(ids)
|
||||
@@ -257,10 +276,11 @@ class MistralTokenizer:
|
||||
|
||||
tokens = [self.tokenizer.id_to_piece(id) for id in ids]
|
||||
|
||||
if any(t.strip() == "<EFBFBD>" for t in tokens):
|
||||
# if any stripped decoded token is undefined
|
||||
# because it's invalid unicode then pass bytes
|
||||
if any("<EFBFBD>" in t for t in tokens):
|
||||
# if a decoded token contains the replacement character, then the
|
||||
# token has an incomplete UTF-8 character so we must use bytes
|
||||
# See: https://github.com/vllm-project/vllm/pull/8640
|
||||
# https://github.com/vllm-project/vllm/pull/9625
|
||||
tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
|
||||
|
||||
return tokens
|
||||
|
||||
Reference in New Issue
Block a user