[Core][Bugfix] cache len of tokenizer (#3741)

This commit is contained in:
youkaichao
2024-03-29 18:46:39 -07:00
committed by GitHub
parent 991143cfcd
commit 203d4f82ac

View File

@@ -26,6 +26,7 @@ def get_cached_tokenizer(
tokenizer_all_special_tokens_extended = (
tokenizer.all_special_tokens_extended)
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
tokenizer_len = len(tokenizer)
class CachedTokenizer(tokenizer.__class__):
@@ -41,6 +42,9 @@ def get_cached_tokenizer(
def all_special_tokens_extended(self):
return tokenizer_all_special_tokens_extended
def __len__(self):
return tokenizer_len
CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
tokenizer.__class__ = CachedTokenizer