diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py
index 709f2f286..d5ecbaf66 100644
--- a/tests/multimodal/test_processing.py
+++ b/tests/multimodal/test_processing.py
@@ -921,7 +921,7 @@ def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
     )
     processor = MULTIMODAL_REGISTRY.create_processor(model_config)
 
-    processor._supported_mm_limits = {"image": num_supported}
+    processor.info.get_supported_mm_limits = lambda: {"image": num_supported}
 
     exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
 
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 5c8683cbd..35dbed006 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -528,7 +528,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         else:
             num_items = len(self._items_by_modality[original_modality]) + 1
 
-        self.mm_processor.validate_num_items(input_modality, num_items)
+        self.mm_processor.info.validate_num_items(input_modality, num_items)
 
         # Track original modality for vision_chunk items
         if use_vision_chunk:
diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py
index f11a3b98a..89bcff3f8 100644
--- a/vllm/lora/model_manager.py
+++ b/vllm/lora/model_manager.py
@@ -176,9 +176,7 @@ class LoRAModelManager:
         )
         mm_budget = MultiModalBudget(vllm_config, mm_registry)
 
-        limit_per_prompt: int = max(
-            self.mm_processor_info.get_allowed_mm_limits().values()
-        )
+        limit_per_prompt = max(self.mm_processor_info.allowed_mm_limits.values())
 
         num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
             mm_budget.get_encoder_budget()
         )
diff --git a/vllm/multimodal/processing/context.py b/vllm/multimodal/processing/context.py
index 62ad1dc3e..23ffc2cd4 100644
--- a/vllm/multimodal/processing/context.py
+++ b/vllm/multimodal/processing/context.py
@@ -7,11 +7,8 @@
 from abc import abstractmethod
 from collections.abc import Generator, Mapping
 from contextlib import contextmanager
 from dataclasses import dataclass, field
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    overload,
-)
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, overload
 
 import torch
 from typing_extensions import TypeVar
@@ -615,13 +612,18 @@ class BaseProcessingInfo:
         """
         raise NotImplementedError
 
-    def get_allowed_mm_limits(self) -> Mapping[str, int]:
-        """Return the maximum allowed number of items for each modality."""
-        supported_mm_limits = self.get_supported_mm_limits()
+    @cached_property
+    def supported_mm_limits(self) -> Mapping[str, int | None]:
+        """The maximum supported number of items for each modality."""
+        return self.get_supported_mm_limits()
+
+    @cached_property
+    def allowed_mm_limits(self) -> Mapping[str, int]:
+        """The maximum allowed number of items for each modality."""
         mm_config = self.ctx.get_mm_config()
 
         allowed_limits = dict[str, int]()
-        for modality, supported_limit in supported_mm_limits.items():
+        for modality, supported_limit in self.supported_mm_limits.items():
             user_limit = mm_config.get_limit_per_prompt(modality)
 
             allowed_limits[modality] = (
@@ -632,6 +634,27 @@ class BaseProcessingInfo:
 
         return allowed_limits
 
+    def validate_num_items(self, modality: str, num_items: int) -> None:
+        """
+        Raise `ValueError` if the number of input items for the given modality
+        is invalid.
+        """
+        supported_limit = self.supported_mm_limits.get(modality, 0)
+        allowed_limit = self.allowed_mm_limits.get(modality, 0)
+
+        if supported_limit is None:
+            supported_limit = allowed_limit
+
+        limit = min(supported_limit, allowed_limit)
+
+        if num_items > limit:
+            msg = f"At most {limit} {modality}(s) may be provided in one prompt."
+
+            if num_items <= supported_limit:
+                msg += " Set `--limit-mm-per-prompt` to increase this limit."
+
+            raise ValueError(msg)
+
     def get_mm_max_tokens_per_item(
         self,
         seq_len: int,
diff --git a/vllm/multimodal/processing/processor.py b/vllm/multimodal/processing/processor.py
index c2776f7f8..643e781a2 100644
--- a/vllm/multimodal/processing/processor.py
+++ b/vllm/multimodal/processing/processor.py
@@ -17,7 +17,7 @@ from typing import (
 
 import regex as re
 import torch
-from typing_extensions import TypeVar, assert_never
+from typing_extensions import TypeVar, assert_never, deprecated
 
 from vllm.logger import init_logger
 from vllm.tokenizers import TokenizerLike
@@ -1000,17 +1000,15 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         else:
             self.data_parser = self.info.get_data_parser()
 
-        # Avoid unnecessary recomputation
-        self._supported_mm_limits = self.info.get_supported_mm_limits()
-        self._allowed_mm_limits = self.info.get_allowed_mm_limits()
-
     @property
+    @deprecated("Will be removed in v0.17. Use `info.supported_mm_limits` instead.")
     def supported_mm_limits(self):
-        return self._supported_mm_limits
+        return self.info.supported_mm_limits
 
     @property
+    @deprecated("Will be removed in v0.17. Use `info.allowed_mm_limits` instead.")
     def allowed_mm_limits(self):
-        return self._allowed_mm_limits
+        return self.info.allowed_mm_limits
 
     def __call__(
         self,
@@ -1022,27 +1020,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
     ) -> MultiModalInputs:
         return self.apply(prompt, mm_data, hf_processor_mm_kwargs, mm_uuids=mm_uuids)
 
-    def validate_num_items(
-        self,
-        modality: str,
-        num_items: int,
-    ) -> None:
-        supported_limit = self.supported_mm_limits.get(modality, 0)
-        allowed_limit = self.allowed_mm_limits.get(modality, 0)
-
-        if supported_limit is None:
-            supported_limit = allowed_limit
-
-        limit = min(supported_limit, allowed_limit)
-
-        if num_items > limit:
-            msg = f"At most {limit} {modality}(s) may be provided in one prompt."
-
-            if num_items <= supported_limit:
-                msg += " Set `--limit-mm-per-prompt` to increase this limit."
-
-            raise ValueError(msg)
-
     def _to_mm_items(
         self,
         mm_data: MultiModalDataDict,
@@ -1066,7 +1043,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         )
 
         for modality, items in mm_items.items():
-            self.validate_num_items(modality, len(items))
+            self.info.validate_num_items(modality, len(items))
 
         return mm_items
 
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 117279369..9ce4924cf 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -168,7 +168,7 @@ class MultiModalRegistry:
         )
         if profiler_limits is None:
-            profiler_limits = processor.allowed_mm_limits
+            profiler_limits = processor.info.allowed_mm_limits
 
         mm_counts = {
             modality: 1 for modality, limit in profiler_limits.items() if limit > 0
         }
@@ -200,7 +200,6 @@ class MultiModalRegistry:
         self,
         model_config: "ModelConfig",
         *,
-        cache: BaseMultiModalProcessorCache | None = None,
         observability_config: ObservabilityConfig | None = None,
     ) -> Mapping[str, int]:
         """
@@ -210,10 +209,8 @@ class MultiModalRegistry:
         if not model_config.is_multimodal_model:
             return {}
 
-        processor = self.create_processor(
-            model_config, observability_config, cache=cache
-        )
-        return processor.allowed_mm_limits
+        info = self._create_processing_info(model_config, observability_config)
+        return info.allowed_mm_limits
 
     def register_processor(
         self,
@@ -324,7 +321,7 @@ class MultiModalRegistry:
             model_config, observability_config, cache=cache
         )
         if mm_counts is None:
-            mm_counts = processor.allowed_mm_limits
+            mm_counts = processor.info.allowed_mm_limits
 
         processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
             seq_len=seq_len,
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 8af17e270..6c217bab8 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -40,7 +40,7 @@ class MultiModalBudget:
         self.max_model_len = model_config.max_model_len
         self.max_num_reqs = scheduler_config.max_num_seqs
 
-        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache)
+        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
 
         max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
             model_config,
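
Reviewer note: the sketch below mirrors the limit-resolution logic that `validate_num_items` now carries on `BaseProcessingInfo`. It is a self-contained approximation for illustration only; `MockProcessingInfo` and its hard-coded limits are stand-ins, not vLLM classes. It shows why the error message suggests `--limit-mm-per-prompt` only when the model itself could accept more items.

```python
from collections.abc import Mapping
from functools import cached_property


class MockProcessingInfo:
    """Stand-in for BaseProcessingInfo with hard-coded limits."""

    @cached_property
    def supported_mm_limits(self) -> Mapping[str, int | None]:
        # What the model architecture supports; None means "no inherent cap".
        return {"image": 8, "video": None}

    @cached_property
    def allowed_mm_limits(self) -> Mapping[str, int]:
        # In vLLM this folds in the user's --limit-mm-per-prompt setting;
        # here we pretend the user capped images at 2.
        return {"image": 2, "video": 1}

    def validate_num_items(self, modality: str, num_items: int) -> None:
        supported_limit = self.supported_mm_limits.get(modality, 0)
        allowed_limit = self.allowed_mm_limits.get(modality, 0)

        if supported_limit is None:
            supported_limit = allowed_limit

        limit = min(supported_limit, allowed_limit)

        if num_items > limit:
            msg = f"At most {limit} {modality}(s) may be provided in one prompt."
            # Only suggest raising the user limit when the model could
            # actually accept that many items.
            if num_items <= supported_limit:
                msg += " Set `--limit-mm-per-prompt` to increase this limit."
            raise ValueError(msg)


info = MockProcessingInfo()
info.validate_num_items("image", 2)  # OK: within both limits

try:
    info.validate_num_items("image", 4)  # over user limit, under model limit
except ValueError as e:
    print(e)  # suggests `--limit-mm-per-prompt`

try:
    info.validate_num_items("image", 9)  # over the model's own limit of 8
except ValueError as e:
    print(e)  # no flag suggestion: raising the user limit would not help
```

Hosting this on `BaseProcessingInfo` (with `cached_property` replacing the processor-level caching that the constructor used to do) lets callers such as `MultiModalRegistry.get_mm_limits_per_prompt` resolve limits without building a full processor, which is what the registry hunk above switches to via `_create_processing_info`.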