diff --git a/tests/models/multimodal/pooling/test_clip.py b/tests/models/multimodal/pooling/test_clip.py index 95c678558..14ede6c1d 100644 --- a/tests/models/multimodal/pooling/test_clip.py +++ b/tests/models/multimodal/pooling/test_clip.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pytest +import torch from transformers import CLIPModel from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner @@ -50,13 +51,16 @@ def _run_test( if "pixel_values" in inputs: pooled_output = hf_model.model.get_image_features( pixel_values=inputs.pixel_values, - ).squeeze(0) + ) else: pooled_output = hf_model.model.get_text_features( input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, - ).squeeze(0) + ) + if not isinstance(pooled_output, torch.Tensor): + pooled_output = pooled_output.pooler_output + pooled_output = pooled_output.squeeze(0) all_outputs.append(pooled_output.tolist()) hf_outputs = all_outputs diff --git a/tests/models/multimodal/pooling/test_siglip.py b/tests/models/multimodal/pooling/test_siglip.py index 0b8cd33cc..4617250e3 100644 --- a/tests/models/multimodal/pooling/test_siglip.py +++ b/tests/models/multimodal/pooling/test_siglip.py @@ -4,6 +4,7 @@ from typing import Any import pytest +import torch from transformers import SiglipModel from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner @@ -68,12 +69,15 @@ def _run_test( if "pixel_values" in inputs: pooled_output = hf_model.model.get_image_features( pixel_values=inputs.pixel_values, - ).squeeze(0) + ) else: pooled_output = hf_model.model.get_text_features( input_ids=inputs.input_ids, - ).squeeze(0) + ) + if not isinstance(pooled_output, torch.Tensor): + pooled_output = pooled_output.pooler_output + pooled_output = pooled_output.squeeze(0) all_outputs.append(pooled_output.tolist()) hf_outputs = all_outputs