[Bugfix][VLM] Fix failing Phi-4-MM multi-images tests and add vision-speech test (#16424)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2025-04-11 12:57:16 +08:00
committed by GitHub
parent ed37599544
commit 93195146ea
5 changed files with 118 additions and 45 deletions

View File

@@ -2,18 +2,22 @@
import os
import re
from collections.abc import Sequence
from typing import Optional
import librosa
import pytest
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from vllm.multimodal.image import rescale_image_size
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
PromptImageInput, VllmRunner)
from ....utils import large_gpu_test
from ...utils import check_logprobs_close
@@ -29,6 +33,8 @@ model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
# Since the vision-lora and speech-lora co-exist with the base model,
# we have to manually specify the path of the lora weights.
vision_lora_path = os.path.join(model_path, "vision-lora")
speech_question = os.path.join(model_path, "examples",
"what_is_shown_in_this_image.wav")
models = [model_path]
@@ -64,7 +70,8 @@ if current_platform.is_rocm():
def run_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
inputs: list[tuple[list[str], PromptImageInput]],
inputs: Sequence[tuple[list[str], PromptImageInput,
Optional[PromptAudioInput]]],
model: str,
*,
max_model_len: int,
@@ -104,28 +111,49 @@ def run_test(
enforce_eager=True,
) as vllm_model:
lora_request = LoRARequest("vision", 1, vision_lora_path)
vllm_model.model.llm_engine.add_lora(lora_request=lora_request)
vllm_outputs_per_case = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs
images=images,
audios=audios,
lora_request=lora_request)
for prompts, images, audios in inputs
]
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
hf_model_kwargs = {"_attn_implementation": "eager"}
hf_model_kwargs = {"_attn_implementation": "sdpa"}
with hf_runner(model, dtype=dtype,
model_kwargs=hf_model_kwargs) as hf_model:
eos_token_id = hf_model.processor.tokenizer.eos_token_id
hf_processor = hf_model.processor
eos_token_id = hf_processor.tokenizer.eos_token_id
def patch_hf_processor(*args,
text="",
images=None,
audio=None,
sampling_rate=None,
**kwargs):
audios = None
if audio is not None and sampling_rate is not None:
audios = [(audio, sampling_rate)]
return hf_processor(*args,
text=text,
images=images,
audios=audios,
**kwargs)
hf_model.processor = patch_hf_processor
hf_outputs_per_case = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images,
audios=audios,
eos_token_id=eos_token_id,
num_logits_to_keep=0)
for prompts, images in inputs
for prompts, images, audios in inputs
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
@@ -138,8 +166,6 @@ def run_test(
)
# Since we use _attn_implementation="eager" for hf_runner, there is more
# significant numerical difference. The basic `logprobs=5` fails to pass.
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
@@ -151,7 +177,7 @@ def run_test(
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.7, 0.75, 1.0],
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@@ -166,6 +192,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
None,
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
run_test(
@@ -201,17 +228,18 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
@pytest.mark.parametrize("max_model_len", [10000])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.xfail(
reason="Phi-4-MM multi-image inference is divergent with hf model.")
def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
size_factors, dtype: str, max_model_len: int,
max_tokens: int, num_logprobs: int) -> None:
images = [asset.pil_image for asset in image_assets]
inputs_per_case = [
([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
[[rescale_image_size(image, factor) for image in images]
for factor in size_factors])
(
[HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
[[rescale_image_size(image, factor) for image in images]
for factor in size_factors],
None,
),
]
run_test(
@@ -226,3 +254,38 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
mm_limit=2,
tensor_parallel_size=1,
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_model_len", [10000])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
max_model_len: int, max_tokens: int,
num_logprobs: int) -> None:
# use the example speech question so that the model outputs are reasonable
audio = librosa.load(speech_question, sr=None)
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
inputs_vision_speech = [
(
["<|user|><|image_1|><|audio_1|><|end|><|assistant|>"],
[image],
[audio],
),
]
run_test(
hf_runner,
vllm_runner,
inputs_vision_speech,
model,
dtype=dtype,
max_model_len=max_model_len,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
mm_limit=1,
tensor_parallel_size=1,
)