[VLM] Merged multi-modal processor for LLaVA-NeXT (#11682)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-03 00:39:27 +08:00
committed by GitHub
parent b6087a6bee
commit 8c38ee7007
14 changed files with 609 additions and 555 deletions

View File

@@ -1,5 +1,7 @@
from contextlib import nullcontext
from functools import partial
from typing import cast
from unittest.mock import MagicMock
import numpy as np
import pytest
@@ -526,6 +528,100 @@ def _rand_audio(
return rng.rand(audio_len), sr
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
("limit", "num_supported", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
(2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
limit_mm_per_prompt = {"image": limit}
model_config = ModelConfig(
model=model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
processor = processor_factory(ctx, cache=None)
mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
processor.get_supported_mm_limits = mock_supported_mm_limits
if is_valid:
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match="this model only supports")
with exc_ctx:
processor._get_and_validate_dummy_mm_counts()
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
("num_images", "limit", "is_valid"),
[(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
(2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
limit_mm_per_prompt = {"image": limit}
model_config = ModelConfig(
model=model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="half",
revision=None,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
processor_factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
)
processor = processor_factory(ctx, cache=None)
rng = np.random.RandomState(0)
image = _rand_img(rng, min_wh=128, max_wh=256)
if num_images == 0:
mm_data = {}
elif num_images == 1:
mm_data = {"image": image}
else:
mm_data = {"image": [image] * num_images}
if is_valid:
exc_ctx = nullcontext()
else:
exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image")
with exc_ctx:
processor.apply(
"<image>" * num_images,
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
def _test_processing_cache_correctness(
model_id: str,
modalities: dict[str, bool],
@@ -631,6 +727,7 @@ def _test_processing_cache_correctness(
("facebook/chameleon-7b", {"image": False}),
("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}),
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
("mistral-community/pixtral-12b", {"image": True}),
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),