diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index f7f1608ec..5aa7211fe 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -35,6 +35,7 @@ from vllm.tokenizers import TokenizerLike from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid from vllm.utils.torch_utils import set_default_torch_num_threads +from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.engine import EngineCoreRequest from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.structured_output.backend_guidance import ( @@ -68,6 +69,17 @@ class InputProcessor: self.mm_registry = mm_registry self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config) + self.mm_encoder_cache_size = None + if ( + self.mm_registry.supports_multimodal_inputs(self.model_config) + and not self.model_config.skip_tokenizer_init + ): + max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality( + self.model_config + ) + _, self.mm_encoder_cache_size = compute_mm_encoder_budget( + self.vllm_config.scheduler_config, max_tokens_by_modality + ) self.input_preprocessor = InputPreprocessor( self.model_config, @@ -743,6 +755,25 @@ class InputProcessor: f"model length of {max_prompt_len}. {suggestion}" ) + if ( + prompt_type == "decoder" + and prompt_inputs["type"] == "multimodal" + and self.mm_encoder_cache_size is not None + ): + decoder_mm_positions = prompt_inputs["mm_placeholders"] + for modality, mm_positions in decoder_mm_positions.items(): + for mm_position in mm_positions: + embed_length = mm_position.get_num_embeds + if embed_length > self.mm_encoder_cache_size: + raise ValueError( + f"The {prompt_type} prompt contains a(n) {modality} item " + f"with length {embed_length}, which exceeds the " + f"pre-allocated encoder cache size " + f"{self.mm_encoder_cache_size}. Please reduce the input " + f"size or increase the encoder cache size " + f"by setting --limit-mm-per-prompt at startup." + ) + def stat_mm_cache(self) -> MultiModalCacheStats | None: return self.input_preprocessor.stat_mm_cache()