diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py
index d520be61d..f55bad569 100644
--- a/vllm/multimodal/profiling.py
+++ b/vllm/multimodal/profiling.py
@@ -355,7 +355,11 @@ class MultiModalProfiler(Generic[_I]):
             mm_counts=mm_counts,
         )
         if max_tokens_per_item is not None:
-            return max_tokens_per_item
+            return {
+                modality: max_tokens
+                for modality, max_tokens in max_tokens_per_item.items()
+                if mm_counts.get(modality, 0) > 0
+            }
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
         return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
@@ -375,5 +379,4 @@ class MultiModalProfiler(Generic[_I]):
         This is important to take into account when profiling and
         initializing the encoder cache size.
         """
-
         return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 2e4031bd5..8f9276e84 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -152,6 +152,7 @@ class MultiModalRegistry:
         model_config: "ModelConfig",
         *,
         cache: BaseMultiModalProcessorCache | None = None,
+        profiler_limits: Mapping[str, int] | None = None,
     ) -> Mapping[str, int]:
         """
         Get the maximum number of tokens per data item from each modality based
@@ -164,40 +165,15 @@ class MultiModalRegistry:
         profiler: MultiModalProfiler = MultiModalProfiler(processor)
 
         seq_len = model_config.max_model_len
-        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
+        profiler_limits = (
+            profiler.get_mm_limits() if profiler_limits is None else profiler_limits
+        )
 
         return profiler.get_mm_max_contiguous_tokens(
             seq_len,
-            {modality: 1 for modality, limit in mm_limits.items() if limit > 0},
+            {modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
         )
 
-    def get_max_tokens_per_item_by_nonzero_modality(
-        self,
-        model_config: "ModelConfig",
-        *,
-        cache: BaseMultiModalProcessorCache | None = None,
-    ) -> Mapping[str, int]:
-        """
-        Get the maximum number of tokens per data item from each modality based
-        on underlying model configuration, excluding modalities that user
-        explicitly disabled via `limit_mm_per_prompt`.
-
-        Note:
-            This is currently directly used only in V1 for profiling the memory
-            usage of a model.
-        """
-        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
-        max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
-            model_config,
-            cache=cache,
-        )
-
-        return {
-            key: max_tokens_per_mm_item
-            for key, max_tokens_per_mm_item in max_tokens_per_item.items()
-            if mm_limits[key] > 0
-        }
-
     def get_mm_limits_per_prompt(
         self,
         model_config: "ModelConfig",
@@ -369,7 +345,7 @@ class MultiModalRegistry:
         """
         if not model_config.is_encoder_decoder:
             return 0
-        max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(model_config)
+        max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
         if not max_tokens:
             # TODO - this function assumes encoder-decoder models are
             # multimodal. This will need to change when adding support for more
diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py
index c70025992..3959e9a59 100644
--- a/vllm/v1/core/encoder_cache_manager.py
+++ b/vllm/v1/core/encoder_cache_manager.py
@@ -264,8 +264,8 @@ def compute_encoder_budget(
         from the input sequence.
     """
     if mm_registry.supports_multimodal_inputs(model_config):
-        max_tokens_by_modality = (
-            mm_registry.get_max_tokens_per_item_by_nonzero_modality(model_config)
+        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
+            model_config
         )
 
         return compute_mm_encoder_budget(
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index f384ede06..92baf0cb7 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -42,10 +42,10 @@ class MultiModalBudget:
         self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache)
 
-        max_tokens_by_modality = (
-            mm_registry.get_max_tokens_per_item_by_nonzero_modality(
-                model_config, cache=cache
-            )
+        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
+            model_config,
+            cache=cache,
+            profiler_limits=self.mm_limits,
        )
 
         encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(