diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index b7348dc4a..bb1c8b478 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -176,7 +176,7 @@ def get_text_token_prompts( if model_type in MM_DATA_PATCHES: mm_data = MM_DATA_PATCHES[model_type](mm_data) - parsed_data = processor.data_parser.parse_mm_data(mm_data) + parsed_data = processor.info.parse_mm_data(mm_data) mm_counts = {k: len(vs) for k, vs in parsed_data.items()} text_prompt: str | None @@ -336,17 +336,18 @@ def _test_processing_correctness_one( model_type = model_config.hf_config.model_type text_prompt, token_prompt = get_text_token_prompts(baseline_processor, mm_data) + mm_items = baseline_processor.info.parse_mm_data(mm_data) ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]()) baseline_tokenized_result = baseline_processor.apply( token_prompt, - mm_data=mm_data, + mm_items=mm_items, hf_processor_mm_kwargs={}, ) cached_tokenized_result = cached_processor.apply( token_prompt, - mm_data=mm_data, + mm_items=mm_items, hf_processor_mm_kwargs={}, ) @@ -360,12 +361,12 @@ def _test_processing_correctness_one( if text_prompt is not None: baseline_text_result = baseline_processor.apply( text_prompt, - mm_data=mm_data, + mm_items=mm_items, hf_processor_mm_kwargs={}, ) cached_text_result = cached_processor.apply( text_prompt, - mm_data=mm_data, + mm_items=mm_items, hf_processor_mm_kwargs={}, ) diff --git a/tests/models/multimodal/processing/test_gemma3.py b/tests/models/multimodal/processing/test_gemma3.py index e252be894..5a3271e07 100644 --- a/tests/models/multimodal/processing/test_gemma3.py +++ b/tests/models/multimodal/processing/test_gemma3.py @@ -175,7 +175,11 @@ def test_get_image_size_with_most_features( for asset in image_assets: mm_data = {"image": [asset.pil_image]} - processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) mm_kwargs_data = processed_inputs["mm_kwargs"].get_data() num_patches_tensor = mm_kwargs_data["num_patches"] tokens = int(num_patches_tensor.item()) * image_seq_length diff --git a/tests/models/multimodal/processing/test_glm4_1v.py b/tests/models/multimodal/processing/test_glm4_1v.py index 51071c935..909020d15 100644 --- a/tests/models/multimodal/processing/test_glm4_1v.py +++ b/tests/models/multimodal/processing/test_glm4_1v.py @@ -52,7 +52,11 @@ def test_processor_override( metadata["fps"] = fps mm_data = {"video": [(video, metadata)]} - processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) # Ensure we have the right number of placeholders per num_crops size hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -100,8 +104,16 @@ def test_video_loader_consistency( static_mm_data = {"video": [(static_video, static_metadata)]} dynamic_mm_data = {"video": [(dynamic_video, dynamic_metadata)]} - static_outputs = processor.apply(prompt, static_mm_data, hf_processor_mm_kwargs) - dynamic_outputs = processor.apply(prompt, dynamic_mm_data, hf_processor_mm_kwargs) + static_outputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(static_mm_data), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) + dynamic_outputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(dynamic_mm_data), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) assert static_outputs["prompt_token_ids"] == dynamic_outputs["prompt_token_ids"] assert batched_tensors_equal( diff --git a/tests/models/multimodal/processing/test_h2ovl.py b/tests/models/multimodal/processing/test_h2ovl.py index 1701d9dd8..7cbc4a284 100644 --- a/tests/models/multimodal/processing/test_h2ovl.py +++ b/tests/models/multimodal/processing/test_h2ovl.py @@ -106,7 +106,11 @@ def _run_check( for image in images ) - processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=mm_processor_kwargs, + ) # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("") diff --git a/tests/models/multimodal/processing/test_idefics3.py b/tests/models/multimodal/processing/test_idefics3.py index 351b9d018..d88d37f0b 100644 --- a/tests/models/multimodal/processing/test_idefics3.py +++ b/tests/models/multimodal/processing/test_idefics3.py @@ -55,7 +55,11 @@ def test_processor_override( dummy_image = image_assets[0].pil_image.resize(dummy_image_size) mm_data = {"image": [dummy_image] * num_imgs} - processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) # Ensure the placeholders format are correct hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) diff --git a/tests/models/multimodal/processing/test_internvl.py b/tests/models/multimodal/processing/test_internvl.py index b4994295d..a66095e9d 100644 --- a/tests/models/multimodal/processing/test_internvl.py +++ b/tests/models/multimodal/processing/test_internvl.py @@ -66,7 +66,11 @@ def _run_check( for image in images ) - processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=mm_processor_kwargs, + ) # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("") diff --git a/tests/models/multimodal/processing/test_llama4.py b/tests/models/multimodal/processing/test_llama4.py index b73246b68..721cf627d 100644 --- a/tests/models/multimodal/processing/test_llama4.py +++ b/tests/models/multimodal/processing/test_llama4.py @@ -49,7 +49,11 @@ def test_processor_override( if tokenized_prompt: prompt = tokenizer.encode(prompt) - processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=mm_processor_kwargs, + ) mm_data = processed_inputs["mm_kwargs"].get_data() # place holder replacements diff --git a/tests/models/multimodal/processing/test_llava_next.py b/tests/models/multimodal/processing/test_llava_next.py index ffe7ca17b..23f37b973 100644 --- a/tests/models/multimodal/processing/test_llava_next.py +++ b/tests/models/multimodal/processing/test_llava_next.py @@ -87,7 +87,11 @@ def _validate_image_prompt_replacements_one( try: # The processor will throw an error if there is a mismatch # in the prompt replacements - processed_inputs = processor.apply(prompt, mm_data, {}) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs={}, + ) image_placeholders = processed_inputs["mm_placeholders"]["image"] assert len(image_placeholders) == num_imgs diff --git a/tests/models/multimodal/processing/test_llava_onevision.py b/tests/models/multimodal/processing/test_llava_onevision.py index f5c552fe6..2ded093ca 100644 --- a/tests/models/multimodal/processing/test_llava_onevision.py +++ b/tests/models/multimodal/processing/test_llava_onevision.py @@ -87,7 +87,11 @@ def _validate_image_prompt_replacements_one( try: # The processor will throw an error if there is a mismatch # in the prompt replacements - processed_inputs = processor.apply(prompt, mm_data, {}) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs={}, + ) image_placeholders = processed_inputs["mm_placeholders"]["image"] assert len(image_placeholders) == num_imgs diff --git a/tests/models/multimodal/processing/test_minimax_vl_01.py b/tests/models/multimodal/processing/test_minimax_vl_01.py index 11e000123..cdd491294 100644 --- a/tests/models/multimodal/processing/test_minimax_vl_01.py +++ b/tests/models/multimodal/processing/test_minimax_vl_01.py @@ -29,7 +29,11 @@ def test_processor_override( image = Image.new("RGB", size=(364, 364)) mm_data = {"image": [image] * num_imgs} - processed_inputs = processor.apply(prompt, mm_data, {}) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs={}, + ) image_placeholders = processed_inputs["mm_placeholders"]["image"] assert len(image_placeholders) == num_imgs @@ -46,7 +50,11 @@ def _validate_image_prompt_replacements_one( mm_data = {"image": [image] * num_imgs} try: - processed_inputs = processor.apply(prompt, mm_data, {}) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs={}, + ) image_placeholders = processed_inputs["mm_placeholders"]["image"] assert len(image_placeholders) == num_imgs diff --git a/tests/models/multimodal/processing/test_nemotron_vl.py b/tests/models/multimodal/processing/test_nemotron_vl.py index 5311ab1b7..99f9438e4 100644 --- a/tests/models/multimodal/processing/test_nemotron_vl.py +++ b/tests/models/multimodal/processing/test_nemotron_vl.py @@ -68,7 +68,11 @@ def _run_check( for image in images ) print(total_expected_num_patches) - processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=mm_processor_kwargs, + ) # Ensure we have the right number of placeholders per num_crops size image_token_id = tokenizer.convert_tokens_to_ids("") diff --git a/tests/models/multimodal/processing/test_phi3v.py b/tests/models/multimodal/processing/test_phi3v.py index 8faff2611..c64426db6 100644 --- a/tests/models/multimodal/processing/test_phi3v.py +++ b/tests/models/multimodal/processing/test_phi3v.py @@ -47,7 +47,11 @@ def test_processor_override( prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" mm_data = {"image": [image_assets[0].pil_image] * num_imgs} - processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) # Ensure we have the right number of placeholders per num_crops size img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) diff --git a/tests/models/multimodal/processing/test_phi4mm.py b/tests/models/multimodal/processing/test_phi4mm.py index 5391555c2..157bfd876 100644 --- a/tests/models/multimodal/processing/test_phi4mm.py +++ b/tests/models/multimodal/processing/test_phi4mm.py @@ -51,7 +51,11 @@ def test_processor_override( dummy_image = image_assets[0].pil_image.resize(dummy_image_size) mm_data = {"image": [dummy_image] * num_imgs} - processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) # Ensure we have the right number of placeholders per num_crops size img_tok_count = processed_inputs["prompt_token_ids"].count( diff --git a/tests/models/multimodal/processing/test_qwen2_vl.py b/tests/models/multimodal/processing/test_qwen2_vl.py index 20beaa601..a0ecce5d8 100644 --- a/tests/models/multimodal/processing/test_qwen2_vl.py +++ b/tests/models/multimodal/processing/test_qwen2_vl.py @@ -42,7 +42,11 @@ def test_processor_override( prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs mm_data = {"image": [image_assets[0].pil_image] * num_imgs} - processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) # Ensure we have the right number of placeholders per num_crops size hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -83,7 +87,11 @@ def test_get_image_size_with_most_features( prompt = "<|vision_start|><|image_pad|><|vision_end|>" for asset in image_assets: mm_data = {"image": [asset.pil_image]} - processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) grid_thw = processed_inputs["mm_kwargs"].get_data()["image_grid_thw"].tolist() t, h, w = grid_thw[0] tokens = (t * h * w) // (merge_size**2) diff --git a/tests/models/multimodal/processing/test_qwen3_omni.py b/tests/models/multimodal/processing/test_qwen3_omni.py index d66283be4..05c0b5c61 100644 --- a/tests/models/multimodal/processing/test_qwen3_omni.py +++ b/tests/models/multimodal/processing/test_qwen3_omni.py @@ -51,7 +51,11 @@ def test_processor_with_audio_sample_rate( hf_processor_mm_kwargs: dict[str, Any] = { "audio_sample_rate": audio_sample_rate, } - processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) # Verify audio tokens are generated hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) @@ -90,7 +94,11 @@ def test_longer_audio_generates_more_tokens(model_id: str) -> None: hf_processor_mm_kwargs: dict[str, Any] = { "audio_sample_rate": audio_sample_rate, } - processed = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + processed = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) hf_proc = processor.info.get_hf_processor(**hf_processor_mm_kwargs) audio_token_id = tokenizer.convert_tokens_to_ids(hf_proc.audio_token) return processed["prompt_token_ids"].count(audio_token_id) diff --git a/tests/models/multimodal/processing/test_smolvlm.py b/tests/models/multimodal/processing/test_smolvlm.py index 6f77d5516..102563154 100644 --- a/tests/models/multimodal/processing/test_smolvlm.py +++ b/tests/models/multimodal/processing/test_smolvlm.py @@ -55,7 +55,11 @@ def test_processor_override( dummy_image = image_assets[0].pil_image.resize(dummy_image_size) mm_data = {"image": [dummy_image] * num_imgs} - processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs) + processed_inputs = processor.apply( + prompt, + mm_items=processor.info.parse_mm_data(mm_data), + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + ) # Ensure the placeholders format are correct hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs) diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py index 346649afd..8f7993647 100644 --- a/tests/models/multimodal/processing/test_tensor_schema.py +++ b/tests/models/multimodal/processing/test_tensor_schema.py @@ -24,10 +24,7 @@ from vllm.distributed import ( init_distributed_environment, initialize_model_parallel, ) -from vllm.model_executor.models.interfaces import ( - SupportsMultiModal, - supports_multimodal, -) +from vllm.model_executor.models.interfaces import supports_multimodal from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext from vllm.multimodal.utils import group_mm_kwargs_by_modality @@ -86,7 +83,6 @@ def resize_mm_data( def create_batched_mm_kwargs( - model_cls: type[SupportsMultiModal], model_config: ModelConfig, processor: BaseMultiModalProcessor, size_factors: tuple[float, ...] = (1.0, 0.5, 0.25), @@ -102,10 +98,10 @@ def create_batched_mm_kwargs( seq_len=model_config.max_model_len, mm_counts=mm_counts, ) - mm_data = processor_inputs.mm_data + mm_items = processor_inputs.mm_items resized_mm_data = { - modality: resize_mm_data(data, size_factors) - for modality, data in mm_data.items() + modality: resize_mm_data(items.data, size_factors) + for modality, items in mm_items.items() } # video metadata will be added back to the resized video data here. @@ -113,7 +109,7 @@ def create_batched_mm_kwargs( mm_kwargs = processor.apply( prompt=token_prompt if text_prompt is None else text_prompt, - mm_data=resized_mm_data, + mm_items=processor.info.parse_mm_data(resized_mm_data), hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, tokenization_kwargs=processor_inputs.tokenization_kwargs, )["mm_kwargs"].require_data() @@ -246,9 +242,7 @@ def test_model_tensor_schema(model_id: str): processor = factories.build_processor(ctx, cache=None) with initialize_dummy_model(model_cls, model_config) as model: - for modality, _, mm_kwargs in create_batched_mm_kwargs( - model_cls, model_config, processor - ): + for modality, _, mm_kwargs in create_batched_mm_kwargs(model_config, processor): for method_name in inputs_parse_methods: print( f"Testing `{method_name}` with modality={modality} " diff --git a/tests/models/multimodal/processing/test_transformers.py b/tests/models/multimodal/processing/test_transformers.py index e2a2186f4..7d38c3c14 100644 --- a/tests/models/multimodal/processing/test_transformers.py +++ b/tests/models/multimodal/processing/test_transformers.py @@ -21,7 +21,7 @@ def test_multimodal_processor(model_id): str_prompt = "<|im_start|>user \nWhat is the content of this image?<|im_end|><|im_start|>assistant\n" # noqa: E501 str_processed_inputs = mm_processor.apply( prompt=str_prompt, - mm_data=mm_data, + mm_items=mm_processor.info.parse_mm_data(mm_data), hf_processor_mm_kwargs={}, ) @@ -46,7 +46,7 @@ def test_multimodal_processor(model_id): ] ids_processed_inputs = mm_processor.apply( prompt=ids_prompt, - mm_data=mm_data, + mm_items=mm_processor.info.parse_mm_data(mm_data), hf_processor_mm_kwargs={}, ) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index d5ecbaf66..316234ba5 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -970,7 +970,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): with exc_ctx: processor.apply( "" * num_images, - mm_data=mm_data, + mm_items=processor.info.parse_mm_data(mm_data), hf_processor_mm_kwargs={}, ) diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index eb0a38f51..6edb26a4a 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -258,9 +258,10 @@ class InputPreprocessor: if mm_processor_kwargs is None: mm_processor_kwargs = {} + mm_items = mm_processor.info.parse_mm_data(mm_data) mm_input = mm_processor.apply( prompt, - mm_data, + mm_items, hf_processor_mm_kwargs=mm_processor_kwargs, tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, diff --git a/vllm/model_executor/models/aya_vision.py b/vllm/model_executor/models/aya_vision.py index fa85d952b..af72f0bc4 100644 --- a/vllm/model_executor/models/aya_vision.py +++ b/vllm/model_executor/models/aya_vision.py @@ -227,9 +227,8 @@ class AyaVisionMultiModalProcessor(BaseMultiModalProcessor[AyaVisionProcessingIn # HF processor pops the `num_patches` kwarg, which is needed by vLLM if (images := mm_data.get("images")) is not None: - parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items( - "image", ImageProcessorItems - ) + mm_items = self.info.parse_mm_data({"image": images}, validate=False) + parsed_images = mm_items.get_items("image", ImageProcessorItems) image_sizes = [ parsed_images.get_image_size(i) for i in range(len(parsed_images)) ] diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index a6a303348..4ffeedf46 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -201,20 +201,20 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]): def apply( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object] | None = None, *, mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalInputs: - if prompt and mm_data: + if prompt and mm_items: raise ValueError( "CLIP accepts text-only or image-only inputs, not both! " "Image-only inputs means passing an image with an empty text " "prompt." ) - if mm_data: + if mm_items: # For multi-modal data, the prompt after processing should # only contain the dummy image tokens tokenization_kwargs = { @@ -224,7 +224,7 @@ class CLIPMultiModalProcessor(BaseMultiModalProcessor[CLIPProcessingInfo]): return super().apply( prompt=prompt, - mm_data=mm_data, + mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, diff --git a/vllm/model_executor/models/cohere2_vision.py b/vllm/model_executor/models/cohere2_vision.py index f281a1d4b..ebdb4bcb8 100644 --- a/vllm/model_executor/models/cohere2_vision.py +++ b/vllm/model_executor/models/cohere2_vision.py @@ -262,9 +262,8 @@ class Cohere2VisionMultiModalProcessor( hf_processor = self.info.get_hf_processor(**mm_kwargs) # Fallback calculation if HF processor didn't provide num_patches - parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items( - "image", ImageProcessorItems - ) + mm_items = self.info.parse_mm_data({"image": images}, validate=False) + parsed_images = mm_items.get_items("image", ImageProcessorItems) num_patches = [ self.info.get_num_patches( diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index e436d2981..18437528e 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -290,9 +290,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): # HF processor pops the `num_crops` kwarg, which is needed by vLLM if (images := mm_data.get("images")) is not None: - parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items( - "image", ImageProcessorItems - ) + mm_items = self.info.parse_mm_data({"image": images}, validate=False) + parsed_images = mm_items.get_items("image", ImageProcessorItems) image_sizes = [ parsed_images.get_image_size(i) for i in range(len(parsed_images)) ] diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 7ea67d6b9..d51c50af0 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -349,9 +349,8 @@ class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo tok_kwargs, ) - parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items( - "image", ImageProcessorItems - ) + mm_items = self.info.parse_mm_data({"image": images}, validate=False) + parsed_images = mm_items.get_items("image", ImageProcessorItems) image_sizes = [ parsed_images.get_image_size(i) for i in range(len(parsed_images)) ] diff --git a/vllm/model_executor/models/lfm2_vl.py b/vllm/model_executor/models/lfm2_vl.py index 532a2a913..445ecdce7 100644 --- a/vllm/model_executor/models/lfm2_vl.py +++ b/vllm/model_executor/models/lfm2_vl.py @@ -357,9 +357,8 @@ class Lfm2VLMultiModalProcessor(BaseMultiModalProcessor[Lfm2VLProcessingInfo]): tok_kwargs, ) - parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items( - "image", ImageProcessorItems - ) + mm_items = self.info.parse_mm_data({"image": images}, validate=False) + parsed_images = mm_items.get_items("image", ImageProcessorItems) image_sizes = [ parsed_images.get_image_size(i) for i in range(len(parsed_images)) ] diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 6afe64776..2f9aaa3f3 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -769,7 +769,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): def apply( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object] | None = None, mm_uuids: MultiModalUUIDDict | None = None, @@ -785,13 +785,12 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): result = super().apply( prompt, - mm_data, + mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids, ) - mm_items = self._to_mm_items(mm_data) mm_item_counts = mm_items.get_all_counts() mm_kwargs = result["mm_kwargs"] mm_hashes = result["mm_hashes"] diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 4bf004106..a405d8eb4 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -300,7 +300,8 @@ class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessing if (audios := mm_data.get("audios")) is None: return {} - parsed_audios = self.data_parser.parse_mm_data({"audio": audios}).get_items( + mm_items = self.info.parse_mm_data({"audio": audios}, validate=False) + parsed_audios = mm_items.get_items( "audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems) ) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 06998fc82..ebe2eca32 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -767,7 +767,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): if (images := mm_data.get("images")) is None: return {} - parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items( + mm_items = self.info.parse_mm_data({"image": images}, validate=False) + parsed_images = mm_items.get_items( "image", (MiniCPMVImageEmbeddingItems, ImageProcessorItems) ) @@ -793,7 +794,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): if (videos := mm_data.get("videos")) is None: return {} - parsed_videos = self.data_parser.parse_mm_data({"video": videos}).get_items( + mm_items = self.info.parse_mm_data({"video": videos}, validate=False) + parsed_videos = mm_items.get_items( "video", (MiniCPMVVideoEmbeddingItems, VideoProcessorItems) ) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index d59e444d0..54b58299b 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -609,9 +609,8 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]) ) images = mm_data["images"] - parsed_images = self.data_parser.parse_mm_data({"image": images}).get_items( - "image", ImageProcessorItems - ) + mm_items = self.info.parse_mm_data({"image": images}, validate=False) + parsed_images = mm_items.get_items("image", ImageProcessorItems) tile_size = vision_config.image_size possible_resolutions = find_supported_resolutions( diff --git a/vllm/model_executor/models/nemotron_parse.py b/vllm/model_executor/models/nemotron_parse.py index 5069bf239..f9acae3e0 100644 --- a/vllm/model_executor/models/nemotron_parse.py +++ b/vllm/model_executor/models/nemotron_parse.py @@ -660,7 +660,7 @@ class NemotronParseMultiModalProcessor( def create_encoder_prompt( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, ) -> str | list[int]: return [0] diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 9d67522e2..4ab0067f3 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -225,14 +225,14 @@ class PaliGemmaMultiModalProcessor(BaseMultiModalProcessor[PaliGemmaProcessingIn def apply( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object] | None = None, mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalInputs: mm_inputs = super().apply( prompt, - mm_data, + mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 684c32a3a..3a5dee3c2 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -303,9 +303,11 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]): res = tokenizer.mistral.encode_chat_completion(request) dummy_tokens = res.tokens + dummy_mm_items = self.info.parse_mm_data(dummy_mm_data) + return ProcessorInputs( prompt=dummy_tokens, - mm_data=dummy_mm_data, + mm_items=dummy_mm_items, tokenization_kwargs=tokenization_kwargs, ) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 4e63521bc..9f1bbd596 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -187,20 +187,20 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]): def apply( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object] | None = None, *, mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalInputs: - if prompt and mm_data: + if prompt and mm_items: raise ValueError( "Siglip accepts text-only or image-only inputs, not both! " "Image-only inputs means passing an image with an empty text " "prompt." ) - if mm_data: + if mm_items: # For multi-modal data, the prompt after processing should # only contain the image token tokenization_kwargs = { @@ -210,7 +210,7 @@ class SiglipMultiModalProcessor(BaseMultiModalProcessor[SiglipProcessingInfo]): return super().apply( prompt=prompt, - mm_data=mm_data, + mm_items=mm_items, hf_processor_mm_kwargs=hf_processor_mm_kwargs, tokenization_kwargs=tokenization_kwargs, mm_uuids=mm_uuids, diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py index 3565af74e..a4fc3a10b 100644 --- a/vllm/model_executor/models/terratorch.py +++ b/vllm/model_executor/models/terratorch.py @@ -180,20 +180,20 @@ class TerratorchMultiModalProcessor(BaseMultiModalProcessor[TerratorchProcessing def apply( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object] | None = None, mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalInputs: - mm_items = self._to_mm_items(mm_data) - tokenization_kwargs = tokenization_kwargs or {} + if tokenization_kwargs is None: + tokenization_kwargs = {} + mm_hashes = self._hash_mm_items( mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids ) - mm_processed_data = BatchFeature( - mm_data.get("image", mm_data), tensor_type="pt" - ) + _, passthrough_data = self._get_hf_mm_data(mm_items) + mm_processed_data = BatchFeature(dict(passthrough_data), tensor_type="pt") mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} mm_kwargs = MultiModalKwargsItems.from_hf_inputs( diff --git a/vllm/model_executor/models/transformers/multimodal.py b/vllm/model_executor/models/transformers/multimodal.py index 686649733..890b486b8 100644 --- a/vllm/model_executor/models/transformers/multimodal.py +++ b/vllm/model_executor/models/transformers/multimodal.py @@ -174,7 +174,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): def apply( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object] | None = None, mm_uuids: MultiModalUUIDDict | None = None, @@ -188,7 +188,6 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): if tokenization_kwargs is None: tokenization_kwargs = {} - mm_items = self._to_mm_items(mm_data) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) if not isinstance(prompt, str): # the prompt is the tokenized ids which is not supported diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py index e604ba4ad..c187dba14 100644 --- a/vllm/model_executor/models/voxtral.py +++ b/vllm/model_executor/models/voxtral.py @@ -262,11 +262,14 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): ) res = tokenizer.mistral.encode_chat_completion(request) dummy_tokens = res.tokens - # whixtral tokenizer adds padding to the audio - # so we need to update the audio arrays - dummy_mm_data["audio"] = [a.audio_array for a in res.audios] - return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data) + dummy_mm_inputs = self.info.parse_mm_data( + # whixtral tokenizer adds padding to the audio + # so we need to update the audio arrays + {**dummy_mm_data, "audio": [a.audio_array for a in res.audios]}, + ) + + return ProcessorInputs(prompt=dummy_tokens, mm_items=dummy_mm_inputs) class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]): diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index d9952ce43..58d24d0c9 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -705,7 +705,7 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo def create_encoder_prompt( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, ) -> str | list[int]: # Strictly speaking, whisper encoder only accept audio features. # We create a dummy encoder prompt here which will be padded to diff --git a/vllm/multimodal/processing/context.py b/vllm/multimodal/processing/context.py index 23ffc2cd4..9a98692b5 100644 --- a/vllm/multimodal/processing/context.py +++ b/vllm/multimodal/processing/context.py @@ -14,7 +14,13 @@ import torch from typing_extensions import TypeVar from vllm.logger import init_logger -from vllm.multimodal.parse import MultiModalDataParser +from vllm.multimodal.inputs import MultiModalDataDict +from vllm.multimodal.parse import ( + DictEmbeddingItems, + EmbeddingItems, + MultiModalDataItems, + MultiModalDataParser, +) from vllm.tokenizers import TokenizerLike from vllm.transformers_utils.processor import cached_processor_from_config from vllm.utils.func_utils import get_allowed_kwarg_only_overrides @@ -596,6 +602,10 @@ class BaseProcessingInfo: expected_hidden_size=self._get_expected_hidden_size(), ) + @cached_property + def data_parser(self) -> MultiModalDataParser: + return self.get_data_parser() + @property def skip_prompt_length_check(self) -> bool: return False @@ -655,6 +665,36 @@ class BaseProcessingInfo: raise ValueError(msg) + def parse_mm_data( + self, + mm_data: MultiModalDataDict, + *, + validate: bool = True, + ) -> MultiModalDataItems: + """ + Normalize + [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict] + to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems] + before passing them to + [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. + """ + mm_items = self.data_parser.parse_mm_data(mm_data) + + if validate: + mm_config = self.ctx.model_config.get_multimodal_config() + if not mm_config.enable_mm_embeds: + for modality, items in mm_items.items(): + if isinstance(items, (EmbeddingItems, DictEmbeddingItems)): + raise ValueError( + f"You must set `--enable-mm-embeds` to input " + f"`{modality}_embeds`" + ) + + for modality, items in mm_items.items(): + self.validate_num_items(modality, len(items)) + + return mm_items + def get_mm_max_tokens_per_item( self, seq_len: int, diff --git a/vllm/multimodal/processing/dummy_inputs.py b/vllm/multimodal/processing/dummy_inputs.py index 05e76fe7a..b23e2b86c 100644 --- a/vllm/multimodal/processing/dummy_inputs.py +++ b/vllm/multimodal/processing/dummy_inputs.py @@ -18,6 +18,7 @@ from vllm.config.multimodal import ( from vllm.logger import init_logger from ..inputs import MultiModalDataDict +from ..parse import MultiModalDataItems from .context import BaseProcessingInfo _I = TypeVar("_I", bound=BaseProcessingInfo) @@ -33,7 +34,7 @@ class ProcessorInputs: """ prompt: str | list[int] - mm_data: MultiModalDataDict + mm_items: MultiModalDataItems hf_processor_mm_kwargs: Mapping[str, object] = field(default_factory=dict) tokenization_kwargs: Mapping[str, object] = field(default_factory=dict) @@ -93,15 +94,14 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): mm_options: Configurable options per modality (optional) """ dummy_text = self.get_dummy_text(mm_counts) - - # Use the unified function for both legacy and configurable cases dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) + dummy_mm_items = self.info.parse_mm_data(dummy_mm_data) tokenization_kwargs = {"truncation": False} return ProcessorInputs( prompt=dummy_text, - mm_data=dummy_mm_data, + mm_items=dummy_mm_items, tokenization_kwargs=tokenization_kwargs, ) diff --git a/vllm/multimodal/processing/processor.py b/vllm/multimodal/processing/processor.py index 643e781a2..dfce4dab2 100644 --- a/vllm/multimodal/processing/processor.py +++ b/vllm/multimodal/processing/processor.py @@ -25,7 +25,6 @@ from vllm.utils.collection_utils import flatten_2d_lists, full_groupby from ..hasher import MultiModalHasher from ..inputs import ( - MultiModalDataDict, MultiModalEncDecInputs, MultiModalFieldConfig, MultiModalHashes, @@ -1013,39 +1012,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): def __call__( self, prompt: str, - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], *, mm_uuids: MultiModalUUIDDict | None = None, ) -> MultiModalInputs: - return self.apply(prompt, mm_data, hf_processor_mm_kwargs, mm_uuids=mm_uuids) - - def _to_mm_items( - self, - mm_data: MultiModalDataDict, - ) -> MultiModalDataItems: - """ - Normalize - [`MultiModalDataDict`][vllm.multimodal.inputs.MultiModalDataDict] - to [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems] - before passing them to - [`_get_hf_mm_data`][vllm.multimodal.processing.BaseMultiModalProcessor._get_hf_mm_data]. - """ - mm_items = self.data_parser.parse_mm_data(mm_data) - - mm_config = self.info.ctx.model_config.get_multimodal_config() - if not mm_config.enable_mm_embeds: - for modality, items in mm_items.items(): - if isinstance(items, (EmbeddingItems, DictEmbeddingItems)): - raise ValueError( - f"You must set `--enable-mm-embeds` to input " - f"`{modality}_embeds`" - ) - - for modality, items in mm_items.items(): - self.info.validate_num_items(modality, len(items)) - - return mm_items + return self.apply(prompt, mm_items, hf_processor_mm_kwargs, mm_uuids=mm_uuids) @abstractmethod def _get_mm_fields_config( @@ -1409,6 +1381,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ] for modality, items_is_cached in mm_is_cached.items() } + mm_missing_data = {} for modality, idxs in mm_missing_idxs.items(): missing_modality_data = [] @@ -1423,7 +1396,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): missing_modality_data.append(data) mm_missing_data[modality] = missing_modality_data - return mm_is_cached, self._to_mm_items(mm_missing_data) + mm_missing_items = self.info.parse_mm_data(mm_missing_data) + + return mm_is_cached, mm_missing_items def _recompute_cached_prompt_update( self, @@ -1774,7 +1749,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): def apply( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object] | None = None, *, @@ -1797,8 +1772,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): if request_id is not None: self.info.ctx.create_timing_stats(request_id) - mm_items = self._to_mm_items(mm_data) - if tokenization_kwargs is None: tokenization_kwargs = {} @@ -1843,7 +1816,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): def create_encoder_prompt( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, ) -> str | list[int]: """ Create input prompt for the encoder. HF processor will be applied on @@ -1854,7 +1827,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): def create_decoder_prompt( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, ) -> str | list[int]: """Create input prompt for the decoder.""" return prompt @@ -1862,11 +1835,11 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): def _get_enc_dec_inputs( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, encoder_inputs: MultiModalInputs, ): tokenizer = self.info.get_tokenizer() - decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_data) + decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_items) if isinstance(decoder_prompt_raw, str): decoder_prompt_ids = tokenizer.encode( decoder_prompt_raw, add_special_tokens=False @@ -1884,7 +1857,7 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): def apply( self, prompt: str | list[int], - mm_data: MultiModalDataDict, + mm_items: MultiModalDataItems, hf_processor_mm_kwargs: Mapping[str, object], tokenization_kwargs: Mapping[str, object] | None = None, *, @@ -1897,10 +1870,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): 2. Apply the HF processor on encoder prompt. 3. Copy the input prompt text as decoder prompt inputs. """ - encoder_prompt = self.create_encoder_prompt(prompt, mm_data) + encoder_prompt = self.create_encoder_prompt(prompt, mm_items) encoder_inputs = super().apply( encoder_prompt, - mm_data, + mm_items, hf_processor_mm_kwargs, tokenization_kwargs, mm_uuids=mm_uuids, @@ -1908,6 +1881,6 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): return self._get_enc_dec_inputs( prompt=prompt, - mm_data=mm_data, + mm_items=mm_items, encoder_inputs=encoder_inputs, ) diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 9ce4924cf..7fe68b36f 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -330,7 +330,7 @@ class MultiModalRegistry: ) mm_inputs = processor.apply( prompt=processor_inputs.prompt, - mm_data=processor_inputs.mm_data, + mm_items=processor_inputs.mm_items, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, tokenization_kwargs=processor_inputs.tokenization_kwargs, ) diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index 5aa7211fe..c3d86e819 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -212,7 +212,7 @@ class InputProcessor: def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems: mm_processor = self.input_preprocessor._get_mm_processor() - return mm_processor.data_parser.parse_mm_data(mm_data) + return mm_processor.info.parse_mm_data(mm_data) def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None: if isinstance(prompt, str):