[Misc] Clean up MiniCPM-V/O code (#15337)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import copy
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -29,7 +28,7 @@ def _test_processing_correctness(
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
ignore_mm_keys: Optional[list[str]] = None,
|
||||
ignore_mm_keys: Optional[set[str]] = None,
|
||||
):
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@@ -145,7 +144,7 @@ def _test_processing_correctness_hf(
|
||||
baseline_processor: BaseMultiModalProcessor,
|
||||
cached_processor: BaseMultiModalProcessor,
|
||||
batch_idx: int,
|
||||
ignore_mm_keys: Optional[list[str]] = None,
|
||||
ignore_mm_keys: Optional[set[str]] = None,
|
||||
):
|
||||
if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
|
||||
# For some multimodal models, tokenizer will always add bos_token
|
||||
@@ -167,11 +166,12 @@ def _test_processing_correctness_hf(
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
assert _inputs_equal(
|
||||
_assert_inputs_equal(
|
||||
baseline_result,
|
||||
cached_result,
|
||||
ignore_mm_keys,
|
||||
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
baseline_tokenized_result = baseline_processor.apply(
|
||||
token_prompt,
|
||||
@@ -179,11 +179,12 @@ def _test_processing_correctness_hf(
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
assert _inputs_equal(
|
||||
_assert_inputs_equal(
|
||||
baseline_result,
|
||||
baseline_tokenized_result,
|
||||
ignore_mm_keys,
|
||||
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
cached_tokenized_result = cached_processor.apply(
|
||||
token_prompt,
|
||||
@@ -191,11 +192,12 @@ def _test_processing_correctness_hf(
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
assert _inputs_equal(
|
||||
_assert_inputs_equal(
|
||||
cached_result,
|
||||
cached_tokenized_result,
|
||||
ignore_mm_keys,
|
||||
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
|
||||
def _test_processing_correctness_mistral(
|
||||
@@ -206,7 +208,7 @@ def _test_processing_correctness_mistral(
|
||||
baseline_processor: BaseMultiModalProcessor,
|
||||
cached_processor: BaseMultiModalProcessor,
|
||||
batch_idx: int,
|
||||
ignore_mm_keys: Optional[list[str]] = None,
|
||||
ignore_mm_keys: Optional[set[str]] = None,
|
||||
):
|
||||
images = mm_data.get("image", [])
|
||||
if not isinstance(images, list):
|
||||
@@ -233,11 +235,12 @@ def _test_processing_correctness_mistral(
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
assert _inputs_equal(
|
||||
_assert_inputs_equal(
|
||||
baseline_tokenized_result,
|
||||
cached_tokenized_result,
|
||||
ignore_mm_keys,
|
||||
), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@@ -261,6 +264,7 @@ def _test_processing_correctness_mistral(
|
||||
"TIGER-Lab/Mantis-8B-siglip-llama3",
|
||||
"mistralai/Pixtral-12B-2409",
|
||||
"mistral-community/pixtral-12b",
|
||||
"openbmb/MiniCPM-Llama3-V-2_5",
|
||||
"openbmb/MiniCPM-o-2_6",
|
||||
"openbmb/MiniCPM-V-2_6",
|
||||
"allenai/Molmo-7B-D-0924",
|
||||
@@ -290,7 +294,7 @@ def test_processing_correctness(
|
||||
# In Ultravox, the audio_features can be different depending on padding
|
||||
# The slight difference should not be a problem though, since
|
||||
# attention_mask lets us ignore the difference.
|
||||
ignore_mm_keys = ['audio_features']
|
||||
ignore_mm_keys = {"audio_features"}
|
||||
|
||||
_test_processing_correctness(
|
||||
model_id,
|
||||
@@ -328,38 +332,26 @@ def test_processing_correctness_phi3v(
|
||||
)
|
||||
|
||||
|
||||
def _inputs_equal(
|
||||
def _assert_inputs_equal(
|
||||
a: MultiModalInputs,
|
||||
b: MultiModalInputs,
|
||||
ignore_mm_keys: Optional[list[str]] = None,
|
||||
*,
|
||||
ignore_mm_keys: Optional[set[str]] = None,
|
||||
msg: str = "",
|
||||
):
|
||||
return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
|
||||
b, ignore_mm_keys)
|
||||
if ignore_mm_keys is None:
|
||||
ignore_mm_keys = set()
|
||||
|
||||
if msg is None:
|
||||
assert "mm_kwargs" in a and "mm_kwargs" in b
|
||||
else:
|
||||
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
|
||||
|
||||
def _drop_mm_kwargs_keys(
|
||||
result: MultiModalInputs,
|
||||
ignore_mm_keys: Optional[list[str]] = None,
|
||||
) -> MultiModalInputs:
|
||||
"""Drop specified keys from result['mm_kwargs'].
|
||||
for key in ignore_mm_keys:
|
||||
a["mm_kwargs"].pop(key, None)
|
||||
b["mm_kwargs"].pop(key, None)
|
||||
|
||||
This is mainly to avoid doing exact match of audio_features in ultravox.
|
||||
|
||||
Args:
|
||||
result: Result to drop keys from
|
||||
ignore_mm_keys: List of keys to ignore, e.g. ['audio_features']
|
||||
"""
|
||||
if not ignore_mm_keys:
|
||||
return result
|
||||
|
||||
if 'mm_kwargs' in result:
|
||||
result = copy.deepcopy(result)
|
||||
mm_kwargs = result['mm_kwargs']
|
||||
for key in ignore_mm_keys:
|
||||
mm_kwargs.pop(key, None)
|
||||
for items in mm_kwargs._items_by_modality.values():
|
||||
for item in items:
|
||||
for key in ignore_mm_keys:
|
||||
item.pop(key, None)
|
||||
|
||||
return result
|
||||
if msg is None:
|
||||
assert a == b
|
||||
else:
|
||||
assert a == b, msg
|
||||
|
||||
Reference in New Issue
Block a user