[Bugfix][Frontend] Guard against bad token ids (#9634)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde
2024-10-29 16:13:20 -05:00
committed by GitHub
parent 0ad216f575
commit 67bdf8e523
7 changed files with 89 additions and 17 deletions

View File

@@ -35,6 +35,7 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
tokenizer.all_special_tokens_extended)
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
tokenizer_len = len(tokenizer)
max_token_id = max(tokenizer.get_vocab().values())
class CachedTokenizer(tokenizer.__class__): # type: ignore
@@ -50,6 +51,10 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
def all_special_tokens_extended(self):
    """Return the precomputed set of extended special tokens.

    Served from the closure variable captured when the cached
    tokenizer was built, so repeated accesses avoid recomputation.
    """
    return tokenizer_all_special_tokens_extended
@property
def max_token_id(self):
    """Largest token id in the vocabulary.

    Precomputed once (outside this class) from the wrapped
    tokenizer's vocab and served from the closure, so callers can
    cheaply validate incoming token ids against the upper bound.
    """
    return max_token_id
def __len__(self):
    """Return the cached tokenizer length captured at wrap time."""
    return tokenizer_len

View File

@@ -85,6 +85,7 @@ class MistralTokenizer:
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
self.tokenizer = tokenizer_
self._max_token_id = max(self._vocab.values())
@classmethod
def from_pretrained(cls,
@@ -158,6 +159,10 @@ class MistralTokenizer:
def vocab_size(self) -> int:
    """Return the number of entries in this tokenizer's vocabulary."""
    vocab = self._vocab
    return len(vocab)
@property
def max_token_id(self) -> int:
    """Largest token id present in the vocabulary.

    Precomputed in ``__init__`` as ``max(self._vocab.values())``, so
    this is a constant-time lookup used to guard against out-of-range
    token ids from callers.
    """
    return self._max_token_id
def __len__(self) -> int:
    """Make ``len(tokenizer)`` mirror the vocabulary size."""
    return self.vocab_size