[VLM] Merged multi-modal processor for LLaVA-NeXT (#11682)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -526,6 +528,100 @@ def _rand_audio(
|
||||
return rng.rand(audio_len), sr
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("limit", "num_supported", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
    """Dummy-data validation should fail when ``limit_mm_per_prompt`` asks for
    more items of a modality than the (mocked) model claims to support."""
    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt={"image": limit},
    )

    # Build the multi-modal processor the same way the registry would.
    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )
    processor = factory(ctx, cache=None)

    # Pretend the model supports exactly `num_supported` images.
    processor.get_supported_mm_limits = MagicMock(
        return_value={"image": num_supported})

    exc_ctx = (nullcontext() if is_valid else
               pytest.raises(ValueError, match="this model only supports"))

    with exc_ctx:
        processor._get_and_validate_dummy_mm_counts()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
    """``processor.apply`` should reject prompts that carry more images than
    ``limit_mm_per_prompt`` allows."""
    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt={"image": limit},
    )

    # Build the multi-modal processor the same way the registry would.
    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    factory = MULTIMODAL_REGISTRY._processor_factories[model_cls]
    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )
    processor = factory(ctx, cache=None)

    # A single image may be passed bare; multiple images go in as a list.
    rng = np.random.RandomState(0)
    image = _rand_img(rng, min_wh=128, max_wh=256)
    if num_images == 0:
        mm_data = {}
    elif num_images == 1:
        mm_data = {"image": image}
    else:
        mm_data = {"image": [image] * num_images}

    exc_ctx = (nullcontext() if is_valid else
               pytest.raises(ValueError, match=f"passed {num_images} image"))

    with exc_ctx:
        processor.apply(
            "<image>" * num_images,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )
|
||||
|
||||
|
||||
def _test_processing_cache_correctness(
|
||||
model_id: str,
|
||||
modalities: dict[str, bool],
|
||||
@@ -631,6 +727,7 @@ def _test_processing_cache_correctness(
|
||||
("facebook/chameleon-7b", {"image": False}),
|
||||
("adept/fuyu-8b", {"image": False}),
|
||||
("llava-hf/llava-1.5-7b-hf", {"image": True}),
|
||||
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
|
||||
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
|
||||
("mistral-community/pixtral-12b", {"image": True}),
|
||||
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
|
||||
|
||||
Reference in New Issue
Block a user