[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -11,8 +11,8 @@ from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
|
||||
_PlaceholderInfo, find_text_matches,
|
||||
find_token_matches, iter_placeholders,
|
||||
_PlaceholderInfo, find_mm_placeholders,
|
||||
find_text_matches, find_token_matches,
|
||||
iter_token_matches,
|
||||
replace_text_matches,
|
||||
replace_token_matches)
|
||||
@@ -314,21 +314,27 @@ def test_find_replace_text(
|
||||
# Should not be used since there is nothing to convert to text
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
prompt_repls = [
|
||||
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
|
||||
mm_prompt_repls = {
|
||||
key: [
|
||||
PromptReplacement(key, target,
|
||||
repl_by_key[key]).bind(mock_tokenizer)
|
||||
]
|
||||
for key, target in target_by_key.items()
|
||||
]
|
||||
matches = find_text_matches(prompt, prompt_repls)
|
||||
}
|
||||
mm_matches = {
|
||||
key: find_text_matches(prompt, prompt_repls)
|
||||
for key, prompt_repls in mm_prompt_repls.items()
|
||||
}
|
||||
|
||||
result = replace_text_matches(
|
||||
prompt,
|
||||
matches,
|
||||
mm_matches,
|
||||
{key: mm_count
|
||||
for key in repl_by_key},
|
||||
)
|
||||
|
||||
# Only displayed on error
|
||||
print("matches:", matches)
|
||||
print("mm_matches:", mm_matches)
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
@@ -380,21 +386,27 @@ def test_find_replace_tokens(
|
||||
# Should not be used since there is nothing to convert to tokens
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
prompt_repls = [
|
||||
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
|
||||
mm_prompt_repls = {
|
||||
key: [
|
||||
PromptReplacement(key, target,
|
||||
repl_by_key[key]).bind(mock_tokenizer)
|
||||
]
|
||||
for key, target in target_by_key.items()
|
||||
]
|
||||
matches = find_token_matches(prompt, prompt_repls)
|
||||
}
|
||||
mm_matches = {
|
||||
key: find_token_matches(prompt, prompt_repls)
|
||||
for key, prompt_repls in mm_prompt_repls.items()
|
||||
}
|
||||
|
||||
result = replace_token_matches(
|
||||
prompt,
|
||||
matches,
|
||||
mm_matches,
|
||||
{key: mm_count
|
||||
for key in repl_by_key},
|
||||
)
|
||||
|
||||
# Only displayed on error
|
||||
print("matches:", matches)
|
||||
print("mm_matches:", mm_matches)
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
@@ -417,58 +429,76 @@ def test_find_replace_tokens(
|
||||
[
|
||||
(
|
||||
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
[
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=6,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
],
|
||||
{
|
||||
"pattern_1": [
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=0,
|
||||
start_idx=6,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
),
|
||||
(
|
||||
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
|
||||
[
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=1,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=5,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_3",
|
||||
start_idx=7,
|
||||
replacement=[1550, 918, 1550],
|
||||
),
|
||||
],
|
||||
{
|
||||
"pattern_1": [
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=0,
|
||||
start_idx=1,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=1,
|
||||
start_idx=5,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
],
|
||||
"pattern_3": [
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_3",
|
||||
item_idx=0,
|
||||
start_idx=7,
|
||||
replacement=[1550, 918, 1550],
|
||||
),
|
||||
],
|
||||
}
|
||||
),
|
||||
(
|
||||
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
|
||||
[
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=1,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
start_idx=3,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_3",
|
||||
start_idx=6,
|
||||
replacement=[1550, 918, 1550],
|
||||
),
|
||||
],
|
||||
{
|
||||
"pattern_1": [
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=0,
|
||||
start_idx=1,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_1",
|
||||
item_idx=1,
|
||||
start_idx=3,
|
||||
replacement=[32000, 32000],
|
||||
),
|
||||
],
|
||||
"pattern_3": [
|
||||
_PlaceholderInfo(
|
||||
modality="pattern_3",
|
||||
item_idx=0,
|
||||
start_idx=6,
|
||||
replacement=[1550, 918, 1550],
|
||||
),
|
||||
],
|
||||
}
|
||||
),
|
||||
]
|
||||
)
|
||||
# yapf: enable
|
||||
def test_iter_placeholders(
|
||||
def test_find_mm_placeholders(
|
||||
repl_by_key,
|
||||
prompt,
|
||||
expected,
|
||||
@@ -476,19 +506,18 @@ def test_iter_placeholders(
|
||||
# Should not be used since there is nothing to convert to tokens
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
prompt_repls = [
|
||||
PromptReplacement(key, [], repl).bind(mock_tokenizer)
|
||||
mm_prompt_repls = {
|
||||
key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
|
||||
for key, repl in repl_by_key.items()
|
||||
]
|
||||
}
|
||||
|
||||
result = list(
|
||||
iter_placeholders(
|
||||
prompt_repls,
|
||||
prompt,
|
||||
# Effectively match all occurrences in the prompt
|
||||
{key: 3
|
||||
for key in repl_by_key},
|
||||
))
|
||||
result = find_mm_placeholders(
|
||||
mm_prompt_repls,
|
||||
prompt,
|
||||
# Effectively match all occurrences in the prompt
|
||||
{key: 3
|
||||
for key in repl_by_key},
|
||||
)
|
||||
|
||||
# Only displayed on error
|
||||
print("result:", result)
|
||||
@@ -694,7 +723,10 @@ def _test_processing_cache_correctness(
|
||||
}
|
||||
|
||||
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
|
||||
prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text
|
||||
prompt = baseline_processor._get_dummy_processor_inputs(
|
||||
model_config.max_model_len,
|
||||
mm_counts,
|
||||
).prompt_text
|
||||
|
||||
# Drop unnecessary keys and test single -> multi conversion
|
||||
if rng.rand() < simplify_rate:
|
||||
@@ -728,6 +760,8 @@ def _test_processing_cache_correctness(
|
||||
("adept/fuyu-8b", {"image": False}),
|
||||
("llava-hf/llava-1.5-7b-hf", {"image": True}),
|
||||
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
|
||||
("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
|
||||
("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501
|
||||
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
|
||||
("mistral-community/pixtral-12b", {"image": True}),
|
||||
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
|
||||
|
||||
Reference in New Issue
Block a user