[VLM] Generalized prompt updates for multi-modal processor (#13964)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-02-28 01:44:25 +08:00
committed by GitHub
parent 7864875879
commit f1579b229d
29 changed files with 629 additions and 486 deletions

View File

@@ -14,12 +14,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
PromptReplacement,
PromptInsertion, PromptReplacement,
apply_text_matches,
apply_token_matches,
find_mm_placeholders,
find_text_matches, find_token_matches,
iter_token_matches,
replace_text_matches,
replace_token_matches)
iter_token_matches)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
@@ -102,7 +102,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
{
"pattern_1": [],
"pattern_2": [],
}
},
),
(
[32000, 32000, 32000, 32000],
@@ -147,16 +147,22 @@ def test_iter_token_matches(token_ids, match_ids, expected):
),
],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_token_matches(prompt, target_by_key, expected_by_key):
def test_find_token_matches(
prompt,
target_by_key,
expected_by_key,
update_type,
):
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(key, target, []).bind(mock_tokenizer)
prompt_updates = [
update_type(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_token_matches(prompt, prompt_repls)
result = find_token_matches(prompt, prompt_updates)
# Only displayed on error
print("result:", result)
@@ -254,16 +260,22 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
),
],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_text_matches(prompt, target_by_key, expected_by_key):
def test_find_text_matches(
prompt,
target_by_key,
expected_by_key,
update_type,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(key, target, []).bind(mock_tokenizer)
prompt_updates = [
update_type(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_text_matches(prompt, prompt_repls)
result = find_text_matches(prompt, prompt_updates)
# Only displayed on error
print("result:", result)
@@ -281,7 +293,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key"),
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[
(
"Image:<image>Image:<image><image>!",
@@ -300,58 +312,66 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": "?!?",
},
{
PromptInsertion: {
0: "Image:<image>Image:<image><image>!",
1: "Image:<image><image><image>Image:<image><image>!?!?",
2: "Image:<image><image><image><image><image>Image:<image><image>!?!??!?", # noqa: E501
},
PromptReplacement: {
0: "Image:<image>Image:<image><image>!",
1: "<image><image>Image:<image><image>?!?",
2: "<image><image><image><image><image>?!?",
},
},
),
]
)
@pytest.mark.parametrize(
("mm_count", "expected"),
[
(0, "Image:<image>Image:<image><image>!"),
(1, "<image><image>Image:<image><image>?!?"),
(2, "<image><image><image><image><image>?!?"),
]
)
# yapf: enable
def test_find_replace_text(
def test_find_update_text(
prompt,
target_by_key,
repl_by_key,
mm_count,
expected,
expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_repls = {
key: [
PromptReplacement(key, target,
repl_by_key[key]).bind(mock_tokenizer)
]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_text_matches(prompt, prompt_repls)
for key, prompt_repls in mm_prompt_repls.items()
}
for (
update_type,
expected_by_mm_count,
) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_text_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
result = replace_text_matches(
prompt,
mm_matches,
{key: mm_count
for key in repl_by_key},
)
for mm_count, expected in expected_by_mm_count.items():
result = apply_text_matches(
prompt,
mm_matches,
{key: mm_count
for key in repl_by_key},
)
# Only displayed on error
print("mm_matches:", mm_matches)
print("result:", result)
# Only displayed on error
print("update_type:", update_type)
print("mm_count:", mm_count)
print("mm_matches:", mm_matches)
print("result:", result)
# Manually constructed results
assert result == expected
# Manually constructed results
assert result == expected
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key"),
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[
# Tokenized test cases of `test_find_update_text`
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
@@ -372,53 +392,61 @@ def test_find_replace_text(
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": [1550, 918, 1550],
},
{
PromptInsertion: {
0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
1: [1, 9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550], # noqa: E501
2: [1, 9833, 28747, 32000, 32000, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550, 1550, 918, 1550], # noqa: E501
},
PromptReplacement: {
0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
1: [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], # noqa: E501
2: [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
},
},
),
]
)
@pytest.mark.parametrize(
("mm_count", "expected"),
[
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
]
)
# yapf: enable
def test_find_replace_tokens(
def test_find_update_tokens(
prompt,
target_by_key,
repl_by_key,
mm_count,
expected,
expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_repls = {
key: [
PromptReplacement(key, target,
repl_by_key[key]).bind(mock_tokenizer)
]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_token_matches(prompt, prompt_repls)
for key, prompt_repls in mm_prompt_repls.items()
}
for (
update_type,
expected_by_mm_count,
) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_token_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
result = replace_token_matches(
prompt,
mm_matches,
{key: mm_count
for key in repl_by_key},
)
for mm_count, expected in expected_by_mm_count.items():
result = apply_token_matches(
prompt,
mm_matches,
{key: mm_count
for key in repl_by_key},
)
# Only displayed on error
print("mm_matches:", mm_matches)
print("result:", result)
# Only displayed on error
print("update_type:", update_type)
print("mm_count:", mm_count)
print("mm_matches:", mm_matches)
print("result:", result)
# Manually constructed results
assert result == expected
# Manually constructed results
assert result == expected
# yapf: disable
@@ -524,22 +552,24 @@ def test_find_replace_tokens(
),
]
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_mm_placeholders(
repl_by_key,
prompt,
expected,
update_type,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_repls = {
key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
mm_prompt_updates = {
key: [update_type(key, [], repl).bind(mock_tokenizer)]
for key, repl in repl_by_key.items()
}
result = find_mm_placeholders(
mm_prompt_repls,
mm_prompt_updates,
prompt,
# Effectively match all occurrences in the prompt
{key: 3