[BugFix] Fix test breakages from transformers 4.45 upgrade (#8829)

This commit is contained in:
Nick Hill
2024-09-27 00:46:43 +01:00
committed by GitHub
parent 71d21c73ab
commit 4b377d6feb
13 changed files with 62 additions and 49 deletions

View File

@@ -1,6 +1,7 @@
import os
import warnings
from pathlib import Path
from types import MethodType
from typing import Optional, Union
import huggingface_hub
@@ -152,6 +153,29 @@ def get_tokenizer(
else:
raise e
# NOTE: We can remove this workaround once https://github.com/THUDM/ChatGLM3/issues/1324 is resolved
if type(tokenizer).__name__ in ("ChatGLMTokenizer",
"ChatGLM4Tokenizer"):
assert isinstance(tokenizer, PreTrainedTokenizer)
orig_pad = tokenizer._pad
# Patch _pad method to accept `padding_side`
def _pad(
    self: PreTrainedTokenizer,
    *args,
    padding_side: Optional[str] = None,
    **kwargs,
):
    """Wrap the tokenizer's original ``_pad`` so it tolerates `padding_side`.

    The keyword-only `padding_side` argument is consumed here and never
    forwarded; every other positional and keyword argument is passed
    through unchanged to the captured original ``_pad``.

    A warning is emitted only when a non-None `padding_side` is given
    that differs from ``self.padding_side``, since that request cannot
    be honoured.
    """
    # Short-circuit on None first so `self.padding_side` is only read
    # when the caller actually supplied a value.
    side_conflicts = (padding_side is not None
                      and padding_side != self.padding_side)
    if side_conflicts:
        warnings.warn(
            "`padding_side` argument is not supported by "
            "ChatGLMTokenizer and will be ignored.",
            stacklevel=2,
        )

    # `orig_pad` is already bound to the tokenizer, so `self` is not
    # passed again.
    return orig_pad(*args, **kwargs)
tokenizer._pad = MethodType(_pad, tokenizer)
if not isinstance(tokenizer, PreTrainedTokenizerFast):
logger.warning(
"Using a slow tokenizer. This might cause a significant "
@@ -167,7 +191,7 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
return None
try:
tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
except OSError as e:
except Exception as e:
# No tokenizer was found in the LoRA folder,
# use base model tokenizer
logger.warning(