[VLM] Merged multi-modal processor and V1 support for Qwen-VL (#12504)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-29 00:25:05 +08:00
committed by GitHub
parent 2079e43bee
commit 8f58a51358
4 changed files with 381 additions and 471 deletions

View File

@@ -16,7 +16,6 @@ from ...registry import HF_EXAMPLE_MODELS
def _test_processing_correctness(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
@@ -25,11 +24,6 @@ def _test_processing_correctness(
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
limit_mm_per_prompt = {
modality: 3 if supports_multi else 1
for modality, supports_multi in modalities.items()
}
model_config = ModelConfig(
model_id,
task="auto",
@@ -40,18 +34,29 @@ def _test_processing_correctness(
dtype="float16",
revision=None,
hf_overrides=model_info.hf_overrides,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
tokenizer=cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_info.trust_remote_code,
),
)
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity=1 << 30)
processing_info = factories.info(ctx)
supported_mm_limits = processing_info.get_supported_mm_limits()
limit_mm_per_prompt = {
modality: 3 if limit is None else limit
for modality, limit in supported_mm_limits.items()
}
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
baseline_processor = factories.build_processor(ctx, cache=None)
cached_processor = factories.build_processor(ctx, cache=cache)
dummy_inputs = baseline_processor.dummy_inputs
@@ -82,8 +87,8 @@ def _test_processing_correctness(
mm_data = {
k:
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
for _ in range(rng.randint(limit_mm_per_prompt[k]))]
for k in modalities
for _ in range(rng.randint(limit))]
for k, limit in limit_mm_per_prompt.items()
}
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
@@ -135,21 +140,22 @@ def _test_processing_correctness(
# yapf: disable
# True if the model supports multiple data items of the modality per request
@pytest.mark.parametrize(("model_id", "modalities"), [
("rhymes-ai/Aria", {"image": True}),
("Salesforce/blip2-opt-2.7b", {"image": False}),
("facebook/chameleon-7b", {"image": False}),
("deepseek-ai/deepseek-vl2-tiny", {"image": True}),
("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}),
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
("mistral-community/pixtral-12b", {"image": True}),
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}),
("fixie-ai/ultravox-v0_3", {"audio": True}),
@pytest.mark.parametrize("model_id", [
"rhymes-ai/Aria",
"Salesforce/blip2-opt-2.7b",
"facebook/chameleon-7b",
"deepseek-ai/deepseek-vl2-tiny",
"adept/fuyu-8b",
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistral-community/pixtral-12b",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_3",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@@ -157,14 +163,12 @@ def _test_processing_correctness(
# yapf: enable
def test_processing_correctness(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
_test_processing_correctness(
model_id,
modalities,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,
@@ -172,16 +176,13 @@ def test_processing_correctness(
# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
("microsoft/Phi-3-vision-128k-instruct", {"image": True}),
])
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3-vision-128k-instruct"])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_correctness_phi3v(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
@@ -195,7 +196,6 @@ def test_processing_correctness_phi3v(
_test_processing_correctness(
model_id,
modalities,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,