[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-04 19:40:53 +08:00
committed by GitHub
parent 300acb8347
commit eed11ebee9
31 changed files with 1104 additions and 973 deletions

View File

@@ -11,8 +11,8 @@ from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
_PlaceholderInfo, find_text_matches,
find_token_matches, iter_placeholders,
_PlaceholderInfo, find_mm_placeholders,
find_text_matches, find_token_matches,
iter_token_matches,
replace_text_matches,
replace_token_matches)
@@ -314,21 +314,27 @@ def test_find_replace_text(
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
mm_prompt_repls = {
key: [
PromptReplacement(key, target,
repl_by_key[key]).bind(mock_tokenizer)
]
for key, target in target_by_key.items()
]
matches = find_text_matches(prompt, prompt_repls)
}
mm_matches = {
key: find_text_matches(prompt, prompt_repls)
for key, prompt_repls in mm_prompt_repls.items()
}
result = replace_text_matches(
prompt,
matches,
mm_matches,
{key: mm_count
for key in repl_by_key},
)
# Only displayed on error
print("matches:", matches)
print("mm_matches:", mm_matches)
print("result:", result)
# Manually constructed results
@@ -380,21 +386,27 @@ def test_find_replace_tokens(
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
mm_prompt_repls = {
key: [
PromptReplacement(key, target,
repl_by_key[key]).bind(mock_tokenizer)
]
for key, target in target_by_key.items()
]
matches = find_token_matches(prompt, prompt_repls)
}
mm_matches = {
key: find_token_matches(prompt, prompt_repls)
for key, prompt_repls in mm_prompt_repls.items()
}
result = replace_token_matches(
prompt,
matches,
mm_matches,
{key: mm_count
for key in repl_by_key},
)
# Only displayed on error
print("matches:", matches)
print("mm_matches:", mm_matches)
print("result:", result)
# Manually constructed results
@@ -417,58 +429,76 @@ def test_find_replace_tokens(
[
(
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=6,
replacement=[32000, 32000],
),
],
{
"pattern_1": [
_PlaceholderInfo(
modality="pattern_1",
item_idx=0,
start_idx=6,
replacement=[32000, 32000],
),
],
}
),
(
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_1",
start_idx=5,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_3",
start_idx=7,
replacement=[1550, 918, 1550],
),
],
{
"pattern_1": [
_PlaceholderInfo(
modality="pattern_1",
item_idx=0,
start_idx=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_1",
item_idx=1,
start_idx=5,
replacement=[32000, 32000],
),
],
"pattern_3": [
_PlaceholderInfo(
modality="pattern_3",
item_idx=0,
start_idx=7,
replacement=[1550, 918, 1550],
),
],
}
),
(
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_1",
start_idx=3,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_3",
start_idx=6,
replacement=[1550, 918, 1550],
),
],
{
"pattern_1": [
_PlaceholderInfo(
modality="pattern_1",
item_idx=0,
start_idx=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_1",
item_idx=1,
start_idx=3,
replacement=[32000, 32000],
),
],
"pattern_3": [
_PlaceholderInfo(
modality="pattern_3",
item_idx=0,
start_idx=6,
replacement=[1550, 918, 1550],
),
],
}
),
]
)
# yapf: enable
def test_iter_placeholders(
def test_find_mm_placeholders(
repl_by_key,
prompt,
expected,
@@ -476,19 +506,18 @@ def test_iter_placeholders(
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(key, [], repl).bind(mock_tokenizer)
mm_prompt_repls = {
key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
for key, repl in repl_by_key.items()
]
}
result = list(
iter_placeholders(
prompt_repls,
prompt,
# Effectively match all occurrences in the prompt
{key: 3
for key in repl_by_key},
))
result = find_mm_placeholders(
mm_prompt_repls,
prompt,
# Effectively match all occurrences in the prompt
{key: 3
for key in repl_by_key},
)
# Only displayed on error
print("result:", result)
@@ -694,7 +723,10 @@ def _test_processing_cache_correctness(
}
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text
prompt = baseline_processor._get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt_text
# Drop unnecessary keys and test single -> multi conversion
if rng.rand() < simplify_rate:
@@ -728,6 +760,8 @@ def _test_processing_cache_correctness(
("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}),
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
("mistral-community/pixtral-12b", {"image": True}),
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),