[VLM] Merged multi-modal processor and V1 support for Qwen-VL (#12504)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -16,7 +16,6 @@ from ...registry import HF_EXAMPLE_MODELS
|
||||
|
||||
def _test_processing_correctness(
|
||||
model_id: str,
|
||||
modalities: dict[str, bool],
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
@@ -25,11 +24,6 @@ def _test_processing_correctness(
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
limit_mm_per_prompt = {
|
||||
modality: 3 if supports_multi else 1
|
||||
for modality, supports_multi in modalities.items()
|
||||
}
|
||||
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
task="auto",
|
||||
@@ -40,18 +34,29 @@ def _test_processing_correctness(
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
hf_overrides=model_info.hf_overrides,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
)
|
||||
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||
ctx = InputProcessingContext(
|
||||
model_config,
|
||||
tokenizer=cached_get_tokenizer(model_config.tokenizer),
|
||||
tokenizer=cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_info.trust_remote_code,
|
||||
),
|
||||
)
|
||||
# Ensure that it can fit all of the data
|
||||
cache = ProcessingCache(capacity=1 << 30)
|
||||
|
||||
processing_info = factories.info(ctx)
|
||||
supported_mm_limits = processing_info.get_supported_mm_limits()
|
||||
limit_mm_per_prompt = {
|
||||
modality: 3 if limit is None else limit
|
||||
for modality, limit in supported_mm_limits.items()
|
||||
}
|
||||
|
||||
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
|
||||
|
||||
baseline_processor = factories.build_processor(ctx, cache=None)
|
||||
cached_processor = factories.build_processor(ctx, cache=cache)
|
||||
dummy_inputs = baseline_processor.dummy_inputs
|
||||
@@ -82,8 +87,8 @@ def _test_processing_correctness(
|
||||
mm_data = {
|
||||
k:
|
||||
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
|
||||
for _ in range(rng.randint(limit_mm_per_prompt[k]))]
|
||||
for k in modalities
|
||||
for _ in range(rng.randint(limit))]
|
||||
for k, limit in limit_mm_per_prompt.items()
|
||||
}
|
||||
|
||||
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
|
||||
@@ -135,21 +140,22 @@ def _test_processing_correctness(
|
||||
|
||||
# yapf: disable
|
||||
# True if the model supports multiple data items of the modality per request
|
||||
@pytest.mark.parametrize(("model_id", "modalities"), [
|
||||
("rhymes-ai/Aria", {"image": True}),
|
||||
("Salesforce/blip2-opt-2.7b", {"image": False}),
|
||||
("facebook/chameleon-7b", {"image": False}),
|
||||
("deepseek-ai/deepseek-vl2-tiny", {"image": True}),
|
||||
("adept/fuyu-8b", {"image": False}),
|
||||
("llava-hf/llava-1.5-7b-hf", {"image": True}),
|
||||
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
|
||||
("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
|
||||
("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501
|
||||
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
|
||||
("mistral-community/pixtral-12b", {"image": True}),
|
||||
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
|
||||
("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}),
|
||||
("fixie-ai/ultravox-v0_3", {"audio": True}),
|
||||
@pytest.mark.parametrize("model_id", [
|
||||
"rhymes-ai/Aria",
|
||||
"Salesforce/blip2-opt-2.7b",
|
||||
"facebook/chameleon-7b",
|
||||
"deepseek-ai/deepseek-vl2-tiny",
|
||||
"adept/fuyu-8b",
|
||||
"llava-hf/llava-1.5-7b-hf",
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
"llava-hf/LLaVA-NeXT-Video-7B-hf",
|
||||
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
|
||||
"TIGER-Lab/Mantis-8B-siglip-llama3",
|
||||
"mistral-community/pixtral-12b",
|
||||
"Qwen/Qwen-VL-Chat",
|
||||
"Qwen/Qwen2-VL-2B-Instruct",
|
||||
"Qwen/Qwen2-Audio-7B-Instruct",
|
||||
"fixie-ai/ultravox-v0_3",
|
||||
])
|
||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
@@ -157,14 +163,12 @@ def _test_processing_correctness(
|
||||
# yapf: enable
|
||||
def test_processing_correctness(
|
||||
model_id: str,
|
||||
modalities: dict[str, bool],
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
):
|
||||
_test_processing_correctness(
|
||||
model_id,
|
||||
modalities,
|
||||
hit_rate=hit_rate,
|
||||
num_batches=num_batches,
|
||||
simplify_rate=simplify_rate,
|
||||
@@ -172,16 +176,13 @@ def test_processing_correctness(
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(("model_id", "modalities"), [
|
||||
("microsoft/Phi-3-vision-128k-instruct", {"image": True}),
|
||||
])
|
||||
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3-vision-128k-instruct"])
|
||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
@pytest.mark.parametrize("simplify_rate", [1.0])
|
||||
# yapf: enable
|
||||
def test_processing_correctness_phi3v(
|
||||
model_id: str,
|
||||
modalities: dict[str, bool],
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
@@ -195,7 +196,6 @@ def test_processing_correctness_phi3v(
|
||||
|
||||
_test_processing_correctness(
|
||||
model_id,
|
||||
modalities,
|
||||
hit_rate=hit_rate,
|
||||
num_batches=num_batches,
|
||||
simplify_rate=simplify_rate,
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
"""Tests for Qwen's multimodal preprocessing kwargs."""
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
|
||||
from vllm.inputs import InputContext, token_inputs
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
|
||||
from ....conftest import IMAGE_ASSETS
|
||||
from ...utils import build_model_context
|
||||
|
||||
### Multimodal preprocessing tests
|
||||
SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image
|
||||
# These values are specific to Qwen-VL/Chat; we can get these from the model
|
||||
# config also, but they are hardcoded here to keep the parameterize/fixtures
|
||||
# easy to read.
|
||||
IMG_START_ID = 151857
|
||||
IMG_END_ID = 151858
|
||||
IMG_PAD_ID = 151859
|
||||
TOKS_PER_IMG = 256
|
||||
VIS_ENC_DIM = 4096
|
||||
IMG_SIZE = 448
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def input_mapper_for_qwen():
|
||||
# Lazy import to avoid initializing CUDA during test collection
|
||||
from vllm.model_executor.models.qwen import input_mapper_for_qwen
|
||||
return input_mapper_for_qwen
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def input_processor_for_qwen():
|
||||
# Lazy import to avoid initializing CUDA during test collection
|
||||
from vllm.model_executor.models.qwen import input_processor_for_qwen
|
||||
return input_processor_for_qwen
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def qwen_vl_context() -> InputContext:
|
||||
"""Get an InputContext for Qwen-VL."""
|
||||
return build_model_context(model_name="Qwen/Qwen-VL",
|
||||
trust_remote_code=True)
|
||||
|
||||
|
||||
# Happy path tests for single/multi-image scenarios for the multimodal
|
||||
# input processor and mapper, respectively
|
||||
@pytest.mark.parametrize("num_images", [1, 2])
|
||||
def test_input_processor_valid_mm_data(input_processor_for_qwen,
|
||||
qwen_vl_context: InputContext,
|
||||
num_images: int):
|
||||
"""Happy cases for image inputs to Qwen's multimodal input processor."""
|
||||
prompt = "".join(
|
||||
[f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
|
||||
inputs = token_inputs(
|
||||
prompt=prompt,
|
||||
# When processing multimodal data for a multimodal model, the qwen
|
||||
# input processor will overwrite the provided prompt_token_ids with
|
||||
# the image prompts
|
||||
prompt_token_ids=[],
|
||||
multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
|
||||
)
|
||||
proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
|
||||
assert isinstance(proc_inputs, dict)
|
||||
|
||||
# Each image should have one start / stop and a fixed context of 256
|
||||
proc_tokens = proc_inputs["prompt_token_ids"]
|
||||
assert proc_tokens.count(IMG_START_ID) == num_images
|
||||
assert proc_tokens.count(IMG_END_ID) == num_images
|
||||
assert proc_tokens.count(IMG_PAD_ID) == num_images * TOKS_PER_IMG
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"img_data,expected_shape",
|
||||
[
|
||||
# single / multi-image
|
||||
(SAMPLE_IMAGE, (1, 3, IMG_SIZE, IMG_SIZE)),
|
||||
(2 * [SAMPLE_IMAGE], (2, 3, IMG_SIZE, IMG_SIZE)),
|
||||
# single / multi-image embeddings
|
||||
(torch.rand(
|
||||
(TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
|
||||
(torch.rand(
|
||||
(1, TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
|
||||
(torch.rand(
|
||||
(2, TOKS_PER_IMG, VIS_ENC_DIM)), (2, TOKS_PER_IMG, VIS_ENC_DIM)),
|
||||
])
|
||||
def test_input_mapper_valid_mm_data(input_mapper_for_qwen,
|
||||
qwen_vl_context: InputContext,
|
||||
img_data: Union[torch.Tensor, List[Image],
|
||||
Image],
|
||||
expected_shape: List[int]):
|
||||
"""Happy cases for image inputs to Qwen's multimodal input mapper."""
|
||||
mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data)
|
||||
# Ensure that we get the appropriately shaped pixel_values
|
||||
# for images and image embeddings, respectively.
|
||||
assert isinstance(mapped_img_data, MultiModalKwargs)
|
||||
assert "pixel_values" in mapped_img_data
|
||||
assert mapped_img_data["pixel_values"].shape == expected_shape
|
||||
|
||||
|
||||
# Sad path tests for the multimodal input processor and mapper, respectively
|
||||
@pytest.mark.parametrize("mm_data", [
|
||||
{
|
||||
"image": torch.rand(5)
|
||||
},
|
||||
{
|
||||
"image": torch.rand((5, 5, 5, 5, 5))
|
||||
},
|
||||
])
|
||||
def test_input_processor_invalid_mm_data(input_processor_for_qwen,
|
||||
qwen_vl_context: InputContext,
|
||||
mm_data: Dict[str, torch.Tensor]):
|
||||
"""Test sad cases validated in Qwen's multimodal input processor."""
|
||||
tokenizer = cached_get_tokenizer(qwen_vl_context.model_config.tokenizer,
|
||||
trust_remote_code=True)
|
||||
prompt = "Picture 1: <img></img>\n"
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
inputs = token_inputs(prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
multi_modal_data=mm_data)
|
||||
# Should fail since we have too many or too few dimensions for embeddings
|
||||
with pytest.raises(ValueError):
|
||||
input_processor_for_qwen(qwen_vl_context, inputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"img_data",
|
||||
[
|
||||
# Wrong context length
|
||||
torch.rand((1, TOKS_PER_IMG + 10, VIS_ENC_DIM)),
|
||||
# Wrong visual encoder output size
|
||||
torch.rand((1, TOKS_PER_IMG, VIS_ENC_DIM + 10)),
|
||||
])
|
||||
def test_input_mapper_invalid_mm_data(
|
||||
input_mapper_for_qwen,
|
||||
qwen_vl_context: InputContext,
|
||||
img_data: Union[torch.Tensor, List[Image], Image],
|
||||
):
|
||||
"""Sad cases validated in Qwen VL's multimodal input mapper."""
|
||||
with pytest.raises(ValueError):
|
||||
input_mapper_for_qwen(qwen_vl_context, img_data)
|
||||
Reference in New Issue
Block a user