diff --git a/tests/models/multimodal/processing/test_mllama4.py b/tests/models/multimodal/processing/test_mllama4.py
index b54ce39a9..1b9b00732 100644
--- a/tests/models/multimodal/processing/test_mllama4.py
+++ b/tests/models/multimodal/processing/test_mllama4.py
@@ -24,32 +24,20 @@ def test_profiling(model_id: str, max_model_len: int):
         limit_mm_per_prompt=mm_counts,
     )
 
-    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
-    decoder_dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
-        processor,
-        max_model_len,
-        mm_counts=mm_counts,
-    )
-    dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
-        max_model_len,
+    mm_inputs = MULTIMODAL_REGISTRY.get_dummy_mm_inputs(
+        ctx.model_config,
         mm_counts=mm_counts,
     )
 
     hf_config = ctx.get_hf_config(Llama4Config)
-
-    mm_inputs = processor.apply(
-        prompt=dummy_mm_data.prompt,
-        mm_data=dummy_mm_data.mm_data,
-        hf_processor_mm_kwargs=dict(),
-    )
-    mm_data = mm_inputs["mm_kwargs"].get_data()
-
     image_size = hf_config.vision_config.image_size
     patch_size = hf_config.vision_config.patch_size
     downsample_ratio = int(
         round(1.0 / (hf_config.vision_config.pixel_shuffle_ratio**2))
     )
     tokens_per_patch = ((image_size // patch_size) ** 2) // downsample_ratio
+
+    mm_data = mm_inputs["mm_kwargs"].get_data()
     chunks_per_image = prod(mm_data["patches_per_image"])
     total_num_patches = chunks_per_image * tokens_per_patch
     num_tiles = (
@@ -63,6 +51,5 @@ def test_profiling(model_id: str, max_model_len: int):
         item.get_num_embeds for item in mm_inputs["mm_placeholders"]["image"]
     )
     assert total_tokens == sum(
-        placeholder.length
-        for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
+        placeholder.length for placeholder in mm_inputs["mm_placeholders"]["image"]
     )
diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index b50ab289b..7af0815d0 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -926,10 +926,10 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
     exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
 
     with exc_ctx:
-        processor.dummy_inputs.get_decoder_dummy_data(
-            processor,
-            model_config.max_model_len,
+        MULTIMODAL_REGISTRY.get_dummy_mm_inputs(
+            model_config,
             mm_counts=limit_mm_per_prompt,
+            processor=processor,
         )
 
 
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 3ef445f07..f32b938af 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -50,6 +50,7 @@ from .parse import (
     MultiModalDataItems,
     MultiModalDataParser,
 )
+from .profiling import BaseDummyInputsBuilder
 
 if TYPE_CHECKING:
     from transformers.configuration_utils import PretrainedConfig
@@ -59,7 +60,6 @@ if TYPE_CHECKING:
     from vllm.config import ModelConfig, ObservabilityConfig
 
     from .cache import BaseMultiModalProcessorCache
-    from .profiling import BaseDummyInputsBuilder
 else:
     PretrainedConfig = object
     BatchFeature = object
diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index 6ef84278e..56f0dfc77 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -3,7 +3,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Generic, NamedTuple, TypeVar
+from typing import TYPE_CHECKING, Generic
 
 import numpy as np
 import numpy.typing as npt
@@ -17,16 +17,14 @@ from vllm.config.multimodal import (
 )
 from vllm.logger import init_logger
 
-from .inputs import (
-    MultiModalDataDict,
-    MultiModalInputs,
-    MultiModalKwargsItems,
-    MultiModalPlaceholderDict,
-)
-from .processing import (
-    BaseMultiModalProcessor,
-    BaseProcessingInfo,
-)
+from .inputs import MultiModalDataDict
+
+if TYPE_CHECKING:
+    from .processing import _I
+else:
+    from typing import TypeVar
+
+    _I = TypeVar("_I")
 
 logger = init_logger(__name__)
 
@@ -44,23 +42,6 @@ class ProcessorInputs:
     tokenization_kwargs: Mapping[str, object] = field(default_factory=dict)
 
 
-class DummyEncoderData(NamedTuple):
-    """Dummy data used for profiling."""
-
-    prompt_token_ids: list[int]
-
-
-class DummyDecoderData(NamedTuple):
-    """Dummy data used for profiling."""
-
-    prompt_token_ids: list[int]
-    multi_modal_data: MultiModalKwargsItems
-    multi_modal_placeholders: MultiModalPlaceholderDict
-
-
-_I = TypeVar("_I", bound=BaseProcessingInfo)
-
-
 class BaseDummyInputsBuilder(ABC, Generic[_I]):
     """
     Abstract base class that constructs the dummy data to profile
@@ -222,52 +203,3 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
             height = min(height, overrides.height)
         video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
         return [video] * num_videos
-
-    def get_dummy_mm_inputs(
-        self,
-        processor: BaseMultiModalProcessor[_I],
-        seq_len: int,
-        mm_counts: Mapping[str, int] | None = None,
-        mm_options: Mapping[str, BaseDummyOptions] | None = None,
-    ) -> MultiModalInputs:
-        if mm_counts is None:
-            mm_counts = processor.allowed_mm_limits
-
-        processor_inputs = self.get_dummy_processor_inputs(
-            seq_len,
-            mm_counts=mm_counts,
-            mm_options=mm_options,
-        )
-
-        return processor.apply(
-            prompt=processor_inputs.prompt,
-            mm_data=processor_inputs.mm_data,
-            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
-            tokenization_kwargs=processor_inputs.tokenization_kwargs,
-        )
-
-    def get_decoder_dummy_data(
-        self,
-        processor: BaseMultiModalProcessor[_I],
-        seq_len: int,
-        mm_counts: Mapping[str, int] | None = None,
-        mm_options: Mapping[str, BaseDummyOptions] | None = None,
-    ) -> DummyDecoderData:
-        mm_inputs = self.get_dummy_mm_inputs(
-            processor,
-            seq_len,
-            mm_counts=mm_counts,
-            mm_options=mm_options,
-        )
-
-        prompt_token_ids = mm_inputs["prompt_token_ids"]
-        total_len = len(prompt_token_ids)
-
-        if total_len < seq_len:
-            prompt_token_ids.extend([0] * (seq_len - total_len))
-
-        return DummyDecoderData(
-            prompt_token_ids=prompt_token_ids,
-            multi_modal_data=mm_inputs["mm_kwargs"].require_data(),
-            multi_modal_placeholders=mm_inputs["mm_placeholders"],
-        )
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 1e7c66e49..f696de6d3 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -10,15 +10,13 @@ from vllm.logger import init_logger
 from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
 
 from .cache import BaseMultiModalProcessorCache
+from .inputs import MultiModalInputs
 from .processing import (
     BaseMultiModalProcessor,
     BaseProcessingInfo,
     InputProcessingContext,
 )
-from .profiling import (
-    BaseDummyInputsBuilder,
-    DummyDecoderData,
-)
+from .profiling import BaseDummyInputsBuilder
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, ObservabilityConfig
@@ -160,7 +158,6 @@ class MultiModalRegistry:
             model_config, observability_config, cache=cache
         )
 
-        seq_len = model_config.max_model_len
         if profiler_limits is None:
             profiler_limits = processor.allowed_mm_limits
 
@@ -169,7 +166,7 @@ class MultiModalRegistry:
         }
 
         max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
-            seq_len=seq_len,
+            seq_len=model_config.max_model_len,
             mm_counts=mm_counts,
         )
         if max_tokens_per_item is not None:
@@ -179,11 +176,10 @@ class MultiModalRegistry:
                 if mm_counts.get(modality, 0) > 0
             }
 
-        mm_inputs = processor.dummy_inputs.get_dummy_mm_inputs(
-            processor,
-            seq_len,
+        mm_inputs = self.get_dummy_mm_inputs(
+            model_config,
             mm_counts=mm_counts,
-            mm_options=self._extract_mm_options(model_config),
+            processor=processor,
         )
 
         return {
@@ -298,39 +294,47 @@ class MultiModalRegistry:
 
         return factories.build_processor(ctx, cache=cache)
 
-    def get_decoder_dummy_data(
+    def get_dummy_mm_inputs(
         self,
         model_config: "ModelConfig",
-        seq_len: int,
         mm_counts: Mapping[str, int] | None = None,
         *,
         cache: BaseMultiModalProcessorCache | None = None,
         observability_config: ObservabilityConfig | None = None,
-    ) -> DummyDecoderData:
+        processor: BaseMultiModalProcessor | None = None,
+    ) -> MultiModalInputs:
         """
         Create dummy data for profiling the memory usage of a model.
 
         The model is identified by `model_config`.
         """
-        processor = self.create_processor(
-            model_config, observability_config, cache=cache
-        )
-        dummy_data = processor.dummy_inputs.get_decoder_dummy_data(
-            processor,
-            seq_len,
+        seq_len = model_config.max_model_len
+
+        if processor is None:
+            processor = self.create_processor(
+                model_config, observability_config, cache=cache
+            )
+        if mm_counts is None:
+            mm_counts = processor.allowed_mm_limits
+
+        processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
+            seq_len=seq_len,
             mm_counts=mm_counts,
             mm_options=self._extract_mm_options(model_config),
         )
+        mm_inputs = processor.apply(
+            prompt=processor_inputs.prompt,
+            mm_data=processor_inputs.mm_data,
+            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
+            tokenization_kwargs=processor_inputs.tokenization_kwargs,
+        )
 
-        # Having more tokens is over-conservative but otherwise fine
-        token_ids = dummy_data.prompt_token_ids
-        if len(token_ids) < seq_len:
-            raise AssertionError(
-                f"Expected at least {seq_len} dummy tokens for profiling, "
-                f"but found {len(token_ids)} tokens instead."
-            )
+        prompt_token_ids = mm_inputs["prompt_token_ids"]
+        total_len = len(prompt_token_ids)
+        if total_len < seq_len:
+            prompt_token_ids.extend([0] * (seq_len - total_len))
 
-        return dummy_data
+        return mm_inputs
 
     def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
         """
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c20dbbbd4..525cfede1 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -4192,16 +4192,18 @@ class GPUModelRunner(
         """Dummy data for profiling and precompiling multimodal models."""
         assert self.mm_budget is not None
 
-        dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
-            model_config=self.model_config,
-            seq_len=self.max_model_len,
+        # Don't use `max_items_per_batch` here to avoid redundant computation
+        dummy_mm_inputs = self.mm_registry.get_dummy_mm_inputs(
+            self.model_config,
             mm_counts={modality: 1},
             cache=self.mm_budget.cache,
         )
-        dummy_mm_data = dummy_decoder_data.multi_modal_data
+        dummy_mm_item = dummy_mm_inputs["mm_kwargs"][modality][0]
+
+        # We use the cache so that the item is saved to the cache,
+        # but not read from the cache
+        assert dummy_mm_item is not None, "Item should not already be cached"
 
-        # Result in the maximum GPU consumption of the model
-        dummy_mm_item = dummy_mm_data[modality][0]
         dummy_mm_items = [dummy_mm_item] * max_items_per_batch
 
         return next(