[VLM] Enable tokenized inputs for merged multi-modal processor (#11900)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -649,7 +649,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
|
||||
)
|
||||
|
||||
|
||||
-def _test_processing_cache_correctness(
+def _test_processing_correctness(
|
||||
model_id: str,
|
||||
modalities: dict[str, bool],
|
||||
hit_rate: float,
|
||||
@@ -691,6 +691,7 @@ def _test_processing_cache_correctness(
|
||||
baseline_processor = factories.build_processor(ctx, cache=None)
|
||||
cached_processor = factories.build_processor(ctx, cache=cache)
|
||||
dummy_inputs = baseline_processor.dummy_inputs
|
||||
+tokenizer = baseline_processor.info.get_tokenizer()
|
||||
|
||||
rng = np.random.RandomState(0)
|
||||
|
||||
@@ -747,7 +748,25 @@ def _test_processing_cache_correctness(
|
||||
)
|
||||
|
||||
assert baseline_result == cached_result, (
|
||||
-f"Failed ({batch_idx=}, {mm_data=})")
+f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
||||
|
||||
+baseline_tokenized_result = baseline_processor.apply(
+tokenizer.encode(prompt),
+mm_data=mm_data,
+hf_processor_mm_kwargs={},
+)
+
+assert baseline_result == baseline_tokenized_result, (
+f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
+
+cached_tokenized_result = cached_processor.apply(
+tokenizer.encode(prompt),
+mm_data=mm_data,
+hf_processor_mm_kwargs={},
+)
+
+assert cached_result == cached_tokenized_result, (
+f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@@ -771,14 +790,14 @@ def _test_processing_cache_correctness(
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
@pytest.mark.parametrize("simplify_rate", [1.0])
|
||||
# yapf: enable
|
||||
-def test_processing_cache_correctness(
+def test_processing_correctness(
|
||||
model_id: str,
|
||||
modalities: dict[str, bool],
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
):
|
||||
-_test_processing_cache_correctness(
+_test_processing_correctness(
|
||||
model_id,
|
||||
modalities,
|
||||
hit_rate=hit_rate,
|
||||
@@ -795,7 +814,7 @@ def test_processing_cache_correctness(
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
@pytest.mark.parametrize("simplify_rate", [1.0])
|
||||
# yapf: enable
|
||||
-def test_processing_cache_correctness_phi3v(
+def test_processing_correctness_phi3v(
|
||||
model_id: str,
|
||||
modalities: dict[str, bool],
|
||||
hit_rate: float,
|
||||
@@ -809,7 +828,7 @@ def test_processing_cache_correctness_phi3v(
|
||||
|
||||
AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
|
||||
|
||||
-_test_processing_cache_correctness(
+_test_processing_correctness(
|
||||
model_id,
|
||||
modalities,
|
||||
hit_rate=hit_rate,
|
||||
|
||||
Reference in New Issue
Block a user