[Model] Add GLM-4v support and meet vllm==0.6.2 (#9242)
This commit is contained in:
@@ -59,6 +59,26 @@ def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
|
||||
return tokenizer
|
||||
|
||||
|
||||
def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None:
    """Make ``tokenizer._pad`` accept a ``padding_side`` keyword argument.

    Intended for older tokenizers whose ``_pad`` does not take
    ``padding_side``. The tokenizer's current ``_pad`` is saved and replaced,
    on this instance only, by a wrapper that:

    * strips the ``padding_side`` keyword before delegating, so the saved
      ``_pad`` never receives it;
    * emits a warning when a ``padding_side`` is given that differs from
      ``tokenizer.padding_side``, since that request is ignored;
    * forwards every other positional and keyword argument unchanged and
      returns whatever the saved ``_pad`` returns.

    Args:
        tokenizer: The tokenizer instance to patch in place.
    """
    # Capture the current (already bound) method before it is overwritten.
    unpatched_pad = tokenizer._pad

    def _pad(
        self: PreTrainedTokenizer,
        *args,
        padding_side: Optional[str] = None,
        **kwargs,
    ):
        # Only an explicit request for a *different* side is worth a warning;
        # None or a matching side is silently accepted.
        side_conflicts = (padding_side is not None
                          and padding_side != self.padding_side)
        if side_conflicts:
            warnings.warn(
                "`padding_side` argument is not supported by "
                f"{type(tokenizer).__name__} and will be ignored.",
                stacklevel=2,
            )

        # `unpatched_pad` is already bound, so `self` is not passed again.
        return unpatched_pad(*args, **kwargs)

    # Bind the wrapper to this instance so only this tokenizer is affected.
    tokenizer._pad = MethodType(_pad, tokenizer)
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
tokenizer_name: Union[str, Path],
|
||||
*args,
|
||||
@@ -143,24 +163,7 @@ def get_tokenizer(
|
||||
if type(tokenizer).__name__ in ("ChatGLMTokenizer",
|
||||
"ChatGLM4Tokenizer"):
|
||||
assert isinstance(tokenizer, PreTrainedTokenizer)
|
||||
orig_pad = tokenizer._pad
|
||||
|
||||
# Patch _pad method to accept `padding_side`
|
||||
def _pad(
|
||||
self: PreTrainedTokenizer,
|
||||
*args,
|
||||
padding_side: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if (padding_side is not None
|
||||
and padding_side != self.padding_side):
|
||||
msg = ("`padding_side` argument is not supported by "
|
||||
"ChatGLMTokenizer and will be ignored.")
|
||||
warnings.warn(msg, stacklevel=2)
|
||||
|
||||
return orig_pad(*args, **kwargs)
|
||||
|
||||
tokenizer._pad = MethodType(_pad, tokenizer)
|
||||
patch_padding_side(tokenizer)
|
||||
|
||||
if not isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
logger.warning(
|
||||
|
||||
Reference in New Issue
Block a user