[VLM] Generalized prompt updates for multi-modal processor (#13964)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -14,12 +14,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
|
||||
PromptReplacement,
|
||||
PromptInsertion, PromptReplacement,
|
||||
apply_text_matches,
|
||||
apply_token_matches,
|
||||
find_mm_placeholders,
|
||||
find_text_matches, find_token_matches,
|
||||
iter_token_matches,
|
||||
replace_text_matches,
|
||||
replace_token_matches)
|
||||
iter_token_matches)
|
||||
# yapf: enable
|
||||
from vllm.multimodal.profiling import MultiModalProfiler
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
|
||||
@@ -102,7 +102,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
|
||||
{
|
||||
"pattern_1": [],
|
||||
"pattern_2": [],
|
||||
}
|
||||
},
|
||||
),
|
||||
(
|
||||
[32000, 32000, 32000, 32000],
|
||||
@@ -147,16 +147,22 @@ def test_iter_token_matches(token_ids, match_ids, expected):
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
|
||||
# yapf: enable
|
||||
def test_find_token_matches(prompt, target_by_key, expected_by_key):
|
||||
def test_find_token_matches(
|
||||
prompt,
|
||||
target_by_key,
|
||||
expected_by_key,
|
||||
update_type,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to token IDs
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
prompt_repls = [
|
||||
PromptReplacement(key, target, []).bind(mock_tokenizer)
|
||||
prompt_updates = [
|
||||
update_type(key, target, []).bind(mock_tokenizer)
|
||||
for key, target in target_by_key.items()
|
||||
]
|
||||
result = find_token_matches(prompt, prompt_repls)
|
||||
result = find_token_matches(prompt, prompt_updates)
|
||||
|
||||
# Only displayed on error
|
||||
print("result:", result)
|
||||
@@ -254,16 +260,22 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
|
||||
# yapf: enable
|
||||
def test_find_text_matches(prompt, target_by_key, expected_by_key):
|
||||
def test_find_text_matches(
|
||||
prompt,
|
||||
target_by_key,
|
||||
expected_by_key,
|
||||
update_type,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to text
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
prompt_repls = [
|
||||
PromptReplacement(key, target, []).bind(mock_tokenizer)
|
||||
prompt_updates = [
|
||||
update_type(key, target, []).bind(mock_tokenizer)
|
||||
for key, target in target_by_key.items()
|
||||
]
|
||||
result = find_text_matches(prompt, prompt_repls)
|
||||
result = find_text_matches(prompt, prompt_updates)
|
||||
|
||||
# Only displayed on error
|
||||
print("result:", result)
|
||||
@@ -281,7 +293,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("prompt", "target_by_key", "repl_by_key"),
|
||||
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
|
||||
[
|
||||
(
|
||||
"Image:<image>Image:<image><image>!",
|
||||
@@ -300,58 +312,66 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
|
||||
# Test dynamic replacement (beyond the form of `unit * count`)
|
||||
"pattern_3": "?!?",
|
||||
},
|
||||
{
|
||||
PromptInsertion: {
|
||||
0: "Image:<image>Image:<image><image>!",
|
||||
1: "Image:<image><image><image>Image:<image><image>!?!?",
|
||||
2: "Image:<image><image><image><image><image>Image:<image><image>!?!??!?", # noqa: E501
|
||||
},
|
||||
PromptReplacement: {
|
||||
0: "Image:<image>Image:<image><image>!",
|
||||
1: "<image><image>Image:<image><image>?!?",
|
||||
2: "<image><image><image><image><image>?!?",
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("mm_count", "expected"),
|
||||
[
|
||||
(0, "Image:<image>Image:<image><image>!"),
|
||||
(1, "<image><image>Image:<image><image>?!?"),
|
||||
(2, "<image><image><image><image><image>?!?"),
|
||||
]
|
||||
)
|
||||
# yapf: enable
|
||||
def test_find_replace_text(
|
||||
def test_find_update_text(
|
||||
prompt,
|
||||
target_by_key,
|
||||
repl_by_key,
|
||||
mm_count,
|
||||
expected,
|
||||
expected_by_update_type_mm_count,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to text
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
mm_prompt_repls = {
|
||||
key: [
|
||||
PromptReplacement(key, target,
|
||||
repl_by_key[key]).bind(mock_tokenizer)
|
||||
]
|
||||
for key, target in target_by_key.items()
|
||||
}
|
||||
mm_matches = {
|
||||
key: find_text_matches(prompt, prompt_repls)
|
||||
for key, prompt_repls in mm_prompt_repls.items()
|
||||
}
|
||||
for (
|
||||
update_type,
|
||||
expected_by_mm_count,
|
||||
) in expected_by_update_type_mm_count.items():
|
||||
mm_prompt_updates = {
|
||||
key:
|
||||
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
|
||||
for key, target in target_by_key.items()
|
||||
}
|
||||
mm_matches = {
|
||||
key: find_text_matches(prompt, updates)
|
||||
for key, updates in mm_prompt_updates.items()
|
||||
}
|
||||
|
||||
result = replace_text_matches(
|
||||
prompt,
|
||||
mm_matches,
|
||||
{key: mm_count
|
||||
for key in repl_by_key},
|
||||
)
|
||||
for mm_count, expected in expected_by_mm_count.items():
|
||||
result = apply_text_matches(
|
||||
prompt,
|
||||
mm_matches,
|
||||
{key: mm_count
|
||||
for key in repl_by_key},
|
||||
)
|
||||
|
||||
# Only displayed on error
|
||||
print("mm_matches:", mm_matches)
|
||||
print("result:", result)
|
||||
# Only displayed on error
|
||||
print("update_type:", update_type)
|
||||
print("mm_count:", mm_count)
|
||||
print("mm_matches:", mm_matches)
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
assert result == expected
|
||||
# Manually constructed results
|
||||
assert result == expected
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("prompt", "target_by_key", "repl_by_key"),
|
||||
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
|
||||
[
|
||||
# Tokenized test cases of `test_find_update_text`
|
||||
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
|
||||
@@ -372,53 +392,61 @@ def test_find_replace_text(
|
||||
# Test dynamic replacement (beyond the form of `unit * count`)
|
||||
"pattern_3": [1550, 918, 1550],
|
||||
},
|
||||
{
|
||||
PromptInsertion: {
|
||||
0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
1: [1, 9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550], # noqa: E501
|
||||
2: [1, 9833, 28747, 32000, 32000, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550, 1550, 918, 1550], # noqa: E501
|
||||
},
|
||||
PromptReplacement: {
|
||||
0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
1: [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], # noqa: E501
|
||||
2: [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
("mm_count", "expected"),
|
||||
[
|
||||
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
|
||||
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
|
||||
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
|
||||
]
|
||||
)
|
||||
# yapf: enable
|
||||
def test_find_replace_tokens(
|
||||
def test_find_update_tokens(
|
||||
prompt,
|
||||
target_by_key,
|
||||
repl_by_key,
|
||||
mm_count,
|
||||
expected,
|
||||
expected_by_update_type_mm_count,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to tokens
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
mm_prompt_repls = {
|
||||
key: [
|
||||
PromptReplacement(key, target,
|
||||
repl_by_key[key]).bind(mock_tokenizer)
|
||||
]
|
||||
for key, target in target_by_key.items()
|
||||
}
|
||||
mm_matches = {
|
||||
key: find_token_matches(prompt, prompt_repls)
|
||||
for key, prompt_repls in mm_prompt_repls.items()
|
||||
}
|
||||
for (
|
||||
update_type,
|
||||
expected_by_mm_count,
|
||||
) in expected_by_update_type_mm_count.items():
|
||||
mm_prompt_updates = {
|
||||
key:
|
||||
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
|
||||
for key, target in target_by_key.items()
|
||||
}
|
||||
mm_matches = {
|
||||
key: find_token_matches(prompt, updates)
|
||||
for key, updates in mm_prompt_updates.items()
|
||||
}
|
||||
|
||||
result = replace_token_matches(
|
||||
prompt,
|
||||
mm_matches,
|
||||
{key: mm_count
|
||||
for key in repl_by_key},
|
||||
)
|
||||
for mm_count, expected in expected_by_mm_count.items():
|
||||
result = apply_token_matches(
|
||||
prompt,
|
||||
mm_matches,
|
||||
{key: mm_count
|
||||
for key in repl_by_key},
|
||||
)
|
||||
|
||||
# Only displayed on error
|
||||
print("mm_matches:", mm_matches)
|
||||
print("result:", result)
|
||||
# Only displayed on error
|
||||
print("update_type:", update_type)
|
||||
print("mm_count:", mm_count)
|
||||
print("mm_matches:", mm_matches)
|
||||
print("result:", result)
|
||||
|
||||
# Manually constructed results
|
||||
assert result == expected
|
||||
# Manually constructed results
|
||||
assert result == expected
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@@ -524,22 +552,24 @@ def test_find_replace_tokens(
|
||||
),
|
||||
]
|
||||
)
|
||||
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
|
||||
# yapf: enable
|
||||
def test_find_mm_placeholders(
|
||||
repl_by_key,
|
||||
prompt,
|
||||
expected,
|
||||
update_type,
|
||||
):
|
||||
# Should not be used since there is nothing to convert to tokens
|
||||
mock_tokenizer = cast(AnyTokenizer, object())
|
||||
|
||||
mm_prompt_repls = {
|
||||
key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
|
||||
mm_prompt_updates = {
|
||||
key: [update_type(key, [], repl).bind(mock_tokenizer)]
|
||||
for key, repl in repl_by_key.items()
|
||||
}
|
||||
|
||||
result = find_mm_placeholders(
|
||||
mm_prompt_repls,
|
||||
mm_prompt_updates,
|
||||
prompt,
|
||||
# Effectively match all occurrences in the prompt
|
||||
{key: 3
|
||||
|
||||
Reference in New Issue
Block a user