[Misc] Make cached tokenizer pickle-compatible (#17048)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import os
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
@@ -70,18 +71,17 @@ def encode_tokens(
|
||||
|
||||
|
||||
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
|
||||
"""Get tokenizer with cached properties.
|
||||
|
||||
This will patch the tokenizer object in place.
|
||||
|
||||
"""
|
||||
By default, transformers will recompute multiple tokenizer properties
|
||||
each time they are called, leading to a significant slowdown. This
|
||||
function caches these properties for faster access."""
|
||||
each time they are called, leading to a significant slowdown.
|
||||
This proxy caches these properties for faster access.
|
||||
"""
|
||||
cached_tokenizer = copy.copy(tokenizer)
|
||||
|
||||
tokenizer_all_special_ids = set(tokenizer.all_special_ids)
|
||||
tokenizer_all_special_ids = tokenizer.all_special_ids
|
||||
tokenizer_all_special_tokens = tokenizer.all_special_tokens
|
||||
tokenizer_all_special_tokens_extended = (
|
||||
tokenizer.all_special_tokens_extended)
|
||||
tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
|
||||
tokenizer_vocab = tokenizer.get_vocab()
|
||||
tokenizer_len = len(tokenizer)
|
||||
|
||||
@@ -97,31 +97,34 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
|
||||
class CachedTokenizer(tokenizer.__class__): # type: ignore
|
||||
|
||||
@property
|
||||
def all_special_ids(self):
|
||||
def all_special_ids(self) -> list[int]:
|
||||
return tokenizer_all_special_ids
|
||||
|
||||
@property
|
||||
def all_special_tokens(self):
|
||||
def all_special_tokens(self) -> list[str]:
|
||||
return tokenizer_all_special_tokens
|
||||
|
||||
@property
|
||||
def all_special_tokens_extended(self):
|
||||
def all_special_tokens_extended(self) -> list[str]:
|
||||
return tokenizer_all_special_tokens_extended
|
||||
|
||||
@property
|
||||
def max_token_id(self):
|
||||
def max_token_id(self) -> int:
|
||||
return max_token_id
|
||||
|
||||
def get_vocab(self):
|
||||
def get_vocab(self) -> dict[str, int]:
|
||||
return tokenizer_vocab
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return tokenizer_len
|
||||
|
||||
def __reduce__(self):
|
||||
return get_cached_tokenizer, (tokenizer, )
|
||||
|
||||
CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"
|
||||
|
||||
tokenizer.__class__ = CachedTokenizer
|
||||
return tokenizer
|
||||
cached_tokenizer.__class__ = CachedTokenizer
|
||||
return cached_tokenizer
|
||||
|
||||
|
||||
def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
|
||||
|
||||
Reference in New Issue
Block a user