diff --git a/tests/models/multimodal/processing/test_mllama4.py b/tests/models/multimodal/processing/test_mllama4.py
index 325159965..b54ce39a9 100644
--- a/tests/models/multimodal/processing/test_mllama4.py
+++ b/tests/models/multimodal/processing/test_mllama4.py
@@ -7,7 +7,6 @@ from torch import prod
 from transformers import Llama4Config
 
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.profiling import MultiModalProfiler
 
 from ...utils import build_model_context
 
@@ -26,9 +25,8 @@ def test_profiling(model_id: str, max_model_len: int):
     )
 
     processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
-    profiler = MultiModalProfiler(processor)
-
-    decoder_dummy_data = profiler.get_decoder_dummy_data(
+    decoder_dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
+        processor,
         max_model_len,
         mm_counts=mm_counts,
     )
@@ -39,11 +37,12 @@ def test_profiling(model_id: str, max_model_len: int):
 
     hf_config = ctx.get_hf_config(Llama4Config)
 
-    mm_data = processor.apply(
+    mm_inputs = processor.apply(
         prompt=dummy_mm_data.prompt,
         mm_data=dummy_mm_data.mm_data,
         hf_processor_mm_kwargs=dict(),
-    )["mm_kwargs"].get_data()
+    )
+    mm_data = mm_inputs["mm_kwargs"].get_data()
 
     image_size = hf_config.vision_config.image_size
     patch_size = hf_config.vision_config.patch_size
@@ -60,12 +59,9 @@ def test_profiling(model_id: str, max_model_len: int):
         total_num_patches.item() + num_tiles.item() + 3
     )  # image start, image, image end
 
-    profiled_tokens = profiler.get_mm_max_tokens(
-        max_model_len,
-        mm_counts=mm_counts,
+    assert total_num_patches == sum(
+        item.get_num_embeds for item in mm_inputs["mm_placeholders"]["image"]
     )
-
-    assert total_num_patches == profiled_tokens["image"]
     assert total_tokens == sum(
         placeholder.length
         for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 64bb88960..b50ab289b 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -22,7 +22,6 @@ from vllm.multimodal.processing import (
     iter_token_matches,
     replace_token_matches,
 )
-from vllm.multimodal.profiling import MultiModalProfiler
 
 from .utils import random_image
 
@@ -924,12 +923,11 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
     processor = MULTIMODAL_REGISTRY.create_processor(model_config)
     processor._supported_mm_limits = {"image": num_supported}
 
-    profiler = MultiModalProfiler(processor)
-
     exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
 
     with exc_ctx:
-        profiler.get_decoder_dummy_data(
+        processor.dummy_inputs.get_decoder_dummy_data(
+            processor,
             model_config.max_model_len,
             mm_counts=limit_mm_per_prompt,
         )
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 4209c0014..6ef84278e 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -223,70 +223,42 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
         video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
         return [video] * num_videos
 
-
-class MultiModalProfiler(Generic[_I]):
-    """
-    Contains code for running memory profiling for multi-modal models.
-    """
-
-    def __init__(
+    def get_dummy_mm_inputs(
         self,
         processor: BaseMultiModalProcessor[_I],
-    ) -> None:
-        super().__init__()
-
-        self.processor = processor
-
-    @property
-    def processing_info(self) -> BaseProcessingInfo:
-        return self.processor.info
-
-    @property
-    def dummy_inputs(self) -> BaseDummyInputsBuilder[_I]:
-        return self.processor.dummy_inputs
-
-    def get_mm_limits(self) -> Mapping[str, int]:
-        return self.processor.allowed_mm_limits
-
-    def _get_dummy_mm_inputs(
-        self,
         seq_len: int,
         mm_counts: Mapping[str, int] | None = None,
         mm_options: Mapping[str, BaseDummyOptions] | None = None,
     ) -> MultiModalInputs:
         if mm_counts is None:
-            mm_counts = self.get_mm_limits()
+            mm_counts = processor.allowed_mm_limits
 
-        factory = self.dummy_inputs
-        processor_inputs = factory.get_dummy_processor_inputs(
-            seq_len, mm_counts, mm_options
+        processor_inputs = self.get_dummy_processor_inputs(
+            seq_len,
+            mm_counts=mm_counts,
+            mm_options=mm_options,
         )
 
-        return self.processor.apply(
+        return processor.apply(
             prompt=processor_inputs.prompt,
             mm_data=processor_inputs.mm_data,
             hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
             tokenization_kwargs=processor_inputs.tokenization_kwargs,
         )
 
-    def _get_mm_num_tokens(
-        self,
-        mm_inputs: MultiModalInputs,
-    ) -> Mapping[str, int]:
-        placeholders_by_modality = mm_inputs["mm_placeholders"]
-
-        return {
-            modality: sum(item.get_num_embeds for item in placeholders)
-            for modality, placeholders in placeholders_by_modality.items()
-        }
-
     def get_decoder_dummy_data(
         self,
+        processor: BaseMultiModalProcessor[_I],
         seq_len: int,
         mm_counts: Mapping[str, int] | None = None,
         mm_options: Mapping[str, BaseDummyOptions] | None = None,
     ) -> DummyDecoderData:
-        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts, mm_options)
+        mm_inputs = self.get_dummy_mm_inputs(
+            processor,
+            seq_len,
+            mm_counts=mm_counts,
+            mm_options=mm_options,
+        )
 
         prompt_token_ids = mm_inputs["prompt_token_ids"]
         total_len = len(prompt_token_ids)
@@ -299,29 +271,3 @@ class MultiModalProfiler(Generic[_I]):
             multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
             multi_modal_placeholders=mm_inputs["mm_placeholders"],
         )
-
-    def get_mm_max_tokens(
-        self,
-        seq_len: int,
-        mm_counts: Mapping[str, int] | None = None,
-    ) -> Mapping[str, int]:
-        """
-        Returns the maximum number of embeddings per item of each modality, excluding
-        any break/text tokens in-between multimodal embeddings/encoder outputs.
-        """
-        if mm_counts is None:
-            mm_counts = self.get_mm_limits()
-
-        max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item(
-            seq_len=seq_len,
-            mm_counts=mm_counts,
-        )
-        if max_tokens_per_item is not None:
-            return {
-                modality: max_tokens
-                for modality, max_tokens in max_tokens_per_item.items()
-                if mm_counts.get(modality, 0) > 0
-            }
-
-        mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
-        return self._get_mm_num_tokens(mm_inputs)
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index d70c2f614..1e7c66e49 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -18,7 +18,6 @@ from .processing import (
 from .profiling import (
     BaseDummyInputsBuilder,
     DummyDecoderData,
-    MultiModalProfiler,
 )
 
 if TYPE_CHECKING:
@@ -160,17 +159,37 @@ class MultiModalRegistry:
         processor = self.create_processor(
             model_config, observability_config, cache=cache
         )
-        profiler: MultiModalProfiler = MultiModalProfiler(processor)
 
         seq_len = model_config.max_model_len
-        profiler_limits = (
-            profiler.get_mm_limits() if profiler_limits is None else profiler_limits
+        if profiler_limits is None:
+            profiler_limits = processor.allowed_mm_limits
+
+        mm_counts = {
+            modality: 1 for modality, limit in profiler_limits.items() if limit > 0
+        }
+
+        max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
+            seq_len=seq_len,
+            mm_counts=mm_counts,
+        )
+        if max_tokens_per_item is not None:
+            return {
+                modality: max_tokens
+                for modality, max_tokens in max_tokens_per_item.items()
+                if mm_counts.get(modality, 0) > 0
+            }
+
+        mm_inputs = processor.dummy_inputs.get_dummy_mm_inputs(
+            processor,
+            seq_len,
+            mm_counts=mm_counts,
+            mm_options=self._extract_mm_options(model_config),
         )
 
-        return profiler.get_mm_max_tokens(
-            seq_len,
-            {modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
-        )
+        return {
+            modality: sum(item.get_num_embeds for item in placeholders)
+            for modality, placeholders in mm_inputs["mm_placeholders"].items()
+        }
 
     def get_mm_limits_per_prompt(
         self,
@@ -189,8 +208,7 @@ class MultiModalRegistry:
         processor = self.create_processor(
             model_config, observability_config, cache=cache
         )
-        profiler: MultiModalProfiler = MultiModalProfiler(processor)
-        return profiler.get_mm_limits()
+        return processor.allowed_mm_limits
 
     def register_processor(
         self,
@@ -297,14 +315,12 @@ class MultiModalRegistry:
         processor = self.create_processor(
             model_config, observability_config, cache=cache
         )
-        profiler: MultiModalProfiler = MultiModalProfiler(processor)
-
-        # Extract configurable options from multimodal config.
-        # Only include modalities that use advanced option types so legacy
-        # count-only behavior remains unchanged.
-        mm_options = self._extract_mm_options(model_config)
-
-        dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts, mm_options)
+        dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
+            processor,
+            seq_len,
+            mm_counts=mm_counts,
+            mm_options=self._extract_mm_options(model_config),
+        )
 
         # Having more tokens is over-conservative but otherwise fine
         token_ids = dummy_data.prompt_token_ids