[VLM] Enable tokenized inputs for merged multi-modal processor (#11900)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-10 11:24:00 +08:00
committed by GitHub
parent c3cf54dda4
commit b844b99ad3
12 changed files with 207 additions and 77 deletions

View File

@@ -649,7 +649,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
)
def _test_processing_cache_correctness(
def _test_processing_correctness(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
@@ -691,6 +691,7 @@ def _test_processing_cache_correctness(
baseline_processor = factories.build_processor(ctx, cache=None)
cached_processor = factories.build_processor(ctx, cache=cache)
dummy_inputs = baseline_processor.dummy_inputs
tokenizer = baseline_processor.info.get_tokenizer()
rng = np.random.RandomState(0)
@@ -747,7 +748,25 @@ def _test_processing_cache_correctness(
)
assert baseline_result == cached_result, (
f"Failed ({batch_idx=}, {mm_data=})")
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
baseline_tokenized_result = baseline_processor.apply(
tokenizer.encode(prompt),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert baseline_result == baseline_tokenized_result, (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
cached_tokenized_result = cached_processor.apply(
tokenizer.encode(prompt),
mm_data=mm_data,
hf_processor_mm_kwargs={},
)
assert cached_result == cached_tokenized_result, (
f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
# yapf: disable
@@ -771,14 +790,14 @@ def _test_processing_cache_correctness(
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness(
def test_processing_correctness(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
_test_processing_cache_correctness(
_test_processing_correctness(
model_id,
modalities,
hit_rate=hit_rate,
@@ -795,7 +814,7 @@ def test_processing_cache_correctness(
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness_phi3v(
def test_processing_correctness_phi3v(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
@@ -809,7 +828,7 @@ def test_processing_cache_correctness_phi3v(
AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
_test_processing_cache_correctness(
_test_processing_correctness(
model_id,
modalities,
hit_rate=hit_rate,