[CI/Build][Bugfix] Ensure compatibility with transformers 4.52 (#18678)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -10,11 +10,12 @@ from typing import Optional, Union
 import numpy as np
 import numpy.typing as npt
 import pytest
 import regex as re
 import torch
 from PIL.Image import Image
 from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
-                          GenerationConfig)
+                          GenerationConfig, GenerationMixin)

 from vllm.sequence import SampleLogprobs
 from vllm.transformers_utils.tokenizer import patch_padding_side

@@ -324,6 +325,16 @@ def gemma3_patch_hf_runner(hf_model: HfRunner) -> HfRunner:

     hf_model.processor = processor

+    orig_generate = hf_model.model.generate
+
+    def _generate(self, *args, **kwargs):
+        # FIXME: https://github.com/huggingface/transformers/issues/38333
+        kwargs["disable_compile"] = True
+
+        return orig_generate(*args, **kwargs)
+
+    hf_model.model.generate = types.MethodType(_generate, hf_model.model)
+
     return hf_model

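For reference, the gemma3 hunk above works around a torch.compile regression in transformers 4.52 (huggingface/transformers#38333) by rebinding generate so that every call is forced to pass disable_compile=True. A minimal standalone sketch of that rebinding pattern, with a hypothetical DummyModel in place of the HF model (only types.MethodType and the kwarg injection come from the diff):

import types

class DummyModel:
    # Hypothetical stand-in for an HF model exposing generate()
    def generate(self, *args, **kwargs):
        return (args, kwargs)

model = DummyModel()
orig_generate = model.generate  # bound method, so the original self is captured

def _generate(self, *args, **kwargs):
    # Force-disable compilation on every call, as the diff does for gemma3
    kwargs["disable_compile"] = True
    return orig_generate(*args, **kwargs)

# Rebind so model.generate(...) transparently routes through the wrapper
model.generate = types.MethodType(_generate, model)

assert model.generate("prompt")[1] == {"disable_compile": True}
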
@@ -610,6 +621,11 @@ def _internvl_generate(
     if getattr(self, "use_visual_token_mask", False):
         visual_token_mask = selected.reshape(B, N, 1).to(input_embeds.dtype)
         forward_kwargs["visual_token_mask"] = visual_token_mask
+
+    # e.g. InternVL2-2B
+    if not isinstance(self.language_model, GenerationMixin):
+        pytest.skip("HF impl is not compatible with current transformers")
+
     outputs = self.language_model.generate(
         **forward_kwargs,
         **generate_kwargs,
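The internvl hunk adds this guard because newer transformers releases no longer mix GenerationMixin into PreTrainedModel, so older remote-code models such as InternVL2-2B can ship a language model without a usable generate(). A sketch of the same guard in isolation (run_generate is a hypothetical helper, not part of the diff):

import pytest
from transformers import GenerationMixin

def run_generate(language_model, **generate_kwargs):
    # Skip rather than fail: it is the HF reference implementation, not vLLM,
    # that lost generate() support under current transformers.
    if not isinstance(language_model, GenerationMixin):
        pytest.skip("HF impl is not compatible with current transformers")
    return language_model.generate(**generate_kwargs)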