[Misc] Make cached tokenizer pickle-compatible (#17048)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-04-27 13:05:00 +08:00
Committed by: GitHub
Parent: 8e4b351a0c
Commit: 93a126fbc7
5 changed files with 81 additions and 57 deletions


@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import contextlib
+import copy
 import os
 import warnings
 from functools import lru_cache
@@ -70,18 +71,17 @@ def encode_tokens(
 def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
-    """Get tokenizer with cached properties.
-
-    This will patch the tokenizer object in place.
-
+    """
     By default, transformers will recompute multiple tokenizer properties
-    each time they are called, leading to a significant slowdown. This
-    function caches these properties for faster access."""
+    each time they are called, leading to a significant slowdown.
+    This proxy caches these properties for faster access.
+    """
+    cached_tokenizer = copy.copy(tokenizer)

-    tokenizer_all_special_ids = set(tokenizer.all_special_ids)
+    tokenizer_all_special_ids = tokenizer.all_special_ids
+    tokenizer_all_special_tokens = tokenizer.all_special_tokens
     tokenizer_all_special_tokens_extended = (
         tokenizer.all_special_tokens_extended)
-    tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
     tokenizer_vocab = tokenizer.get_vocab()
     tokenizer_len = len(tokenizer)
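
The key behavioral change in this hunk is that the function no longer patches the caller's tokenizer in place: it shallow-copies the tokenizer first and (in the next hunk) swaps the class on the copy only. A minimal sketch of this copy-then-swap pattern, using a hypothetical DummyTokenizer stand-in rather than a real transformers tokenizer:

import copy

class DummyTokenizer:
    """Hypothetical stand-in for a transformers tokenizer."""
    def get_vocab(self):
        return {"<s>": 0, "</s>": 1}

def make_cached(tokenizer):
    # Shallow-copy before swapping the class, so the caller's original
    # object keeps its original class and behavior.
    cached = copy.copy(tokenizer)

    vocab = tokenizer.get_vocab()  # computed once, captured by closure

    class Cached(tokenizer.__class__):
        def get_vocab(self):
            return vocab  # cached value, no recomputation

    Cached.__name__ = f"Cached{tokenizer.__class__.__name__}"
    cached.__class__ = Cached
    return cached

t = DummyTokenizer()
ct = make_cached(t)
assert type(t) is DummyTokenizer  # original is untouched
assert type(ct).__name__ == "CachedDummyTokenizer"
assert ct.get_vocab() == t.get_vocab()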
@@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     class CachedTokenizer(tokenizer.__class__):  # type: ignore

         @property
-        def all_special_ids(self):
+        def all_special_ids(self) -> list[int]:
             return tokenizer_all_special_ids

         @property
-        def all_special_tokens(self):
+        def all_special_tokens(self) -> list[str]:
             return tokenizer_all_special_tokens

         @property
-        def all_special_tokens_extended(self):
+        def all_special_tokens_extended(self) -> list[str]:
             return tokenizer_all_special_tokens_extended

         @property
-        def max_token_id(self):
+        def max_token_id(self) -> int:
             return max_token_id

-        def get_vocab(self):
+        def get_vocab(self) -> dict[str, int]:
             return tokenizer_vocab

-        def __len__(self):
+        def __len__(self) -> int:
             return tokenizer_len

+        def __reduce__(self):
+            return get_cached_tokenizer, (tokenizer, )
+
     CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

-    tokenizer.__class__ = CachedTokenizer
-    return tokenizer
+    cached_tokenizer.__class__ = CachedTokenizer
+    return cached_tokenizer


 def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
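
The added __reduce__ method is what makes the cached tokenizer pickle-compatible, per the commit title. CachedTokenizer is created dynamically inside get_cached_tokenizer, so pickle cannot look the class up by module-qualified name (it lives at get_cached_tokenizer.<locals>.CachedTokenizer); __reduce__ sidesteps that by telling pickle to serialize the original tokenizer, whose class is importable, and re-wrap it on load. A minimal sketch of the mechanism, reusing the hypothetical DummyTokenizer stand-in from above:

import copy
import pickle

class DummyTokenizer:
    """Hypothetical stand-in for a transformers tokenizer."""
    def __init__(self, name):
        self.name = name

def make_cached(tokenizer):
    cached = copy.copy(tokenizer)

    class Cached(tokenizer.__class__):
        def __reduce__(self):
            # Pickle cannot import this locally defined class by name.
            # Instead, pickle the original tokenizer and rebuild the
            # wrapper on load by calling make_cached(tokenizer) again.
            return make_cached, (tokenizer, )

    Cached.__name__ = f"Cached{tokenizer.__class__.__name__}"
    cached.__class__ = Cached
    return cached

ct = make_cached(DummyTokenizer("llama"))
rt = pickle.loads(pickle.dumps(ct))  # succeeds thanks to __reduce__
assert type(rt).__name__ == "CachedDummyTokenizer"
assert rt.name == "llama"

Without __reduce__, pickle.dumps(ct) would raise a pickling error, since the dynamically created class cannot be found by attribute lookup on its module; round-tripping through make_cached also re-derives the cached values instead of serializing them.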