diff --git a/tests/models/multimodal/generation/vlm_utils/model_utils.py b/tests/models/multimodal/generation/vlm_utils/model_utils.py index 9bdedb3c5..0a692387c 100644 --- a/tests/models/multimodal/generation/vlm_utils/model_utils.py +++ b/tests/models/multimodal/generation/vlm_utils/model_utils.py @@ -24,6 +24,7 @@ from transformers import ( GenerationConfig, GenerationMixin, ) +from transformers.masking_utils import create_causal_mask from transformers.video_utils import VideoMetadata from vllm.logprobs import SampleLogprobs @@ -680,10 +681,14 @@ def isaac_patch_hf_runner(hf_model: HfRunner) -> HfRunner: sin = sin.to(inputs_embeds.dtype) # Prepare attention mask - if attention_mask is not None: - attention_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, False - ) + attention_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + position_ids=position_ids, + cache_position=cache_position, + ) # Initialize and collect hidden states hidden_states = inputs_embeds