diff --git a/tests/entrypoints/llm/test_mm_embeds_only.py b/tests/entrypoints/llm/test_mm_embeds_only.py new file mode 100644 index 000000000..13d0fd58b --- /dev/null +++ b/tests/entrypoints/llm/test_mm_embeds_only.py @@ -0,0 +1,67 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import weakref + +import pytest + +from vllm import LLM, SamplingParams +from vllm.assets.image import ImageAsset +from vllm.distributed import cleanup_dist_env_and_memory + +MODEL = "llava-hf/llava-1.5-7b-hf" +PROMPT = "USER: \nDescribe this image briefly.\nASSISTANT:" +TEXT_ONLY_PROMPT = "USER: What is 2 + 2?\nASSISTANT:" + + +@pytest.fixture(scope="module") +def llm(): + """LLM with enable_mm_embeds=True and all modality limits zeroed out.""" + llm = LLM( + model=MODEL, + max_model_len=2048, + enforce_eager=True, + gpu_memory_utilization=0.8, + enable_mm_embeds=True, + limit_mm_per_prompt={"image": 0}, + ) + + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.skip_global_cleanup +def test_generate_with_embedding(llm: LLM): + """Pre-computed embedding produces tokens without hanging.""" + embedding = ImageAsset("stop_sign").image_embeds + outputs = llm.generate( + {"prompt": PROMPT, "multi_modal_data": {"image": embedding}}, + sampling_params=SamplingParams(max_tokens=32, temperature=0.0), + ) + assert len(outputs) == 1 + assert len(outputs[0].outputs[0].text) > 0 + + +@pytest.mark.skip_global_cleanup +def test_raw_image_rejected(llm: LLM): + """Raw image input is still rejected when limit=0.""" + raw_image = ImageAsset("stop_sign").pil_image + with pytest.raises(ValueError, match=r"At most 0 image\(s\)"): + llm.generate( + {"prompt": PROMPT, "multi_modal_data": {"image": raw_image}}, + sampling_params=SamplingParams(max_tokens=16), + ) + + +@pytest.mark.skip_global_cleanup +def test_text_only_prompt(llm: LLM): + """Text-only prompts still work under this config.""" + outputs = llm.generate( + TEXT_ONLY_PROMPT, + sampling_params=SamplingParams(max_tokens=16, temperature=0.0), + ) + assert len(outputs) == 1 + assert len(outputs[0].outputs[0].text) > 0 diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index 316234ba5..2ab20fe2c 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -899,40 +899,6 @@ def test_find_mm_placeholders( assert result == expected -@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) -@pytest.mark.parametrize( - ("limit", "num_supported", "is_valid"), - [ - (0, 0, True), - (0, 1, True), - (1, 0, False), - (1, 1, True), - (1, 2, True), - (2, 1, False), - (2, 2, True), - ], -) -def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid): - limit_mm_per_prompt = {"image": limit} - - model_config = ModelConfig( - model=model_id, - limit_mm_per_prompt=limit_mm_per_prompt, - ) - - processor = MULTIMODAL_REGISTRY.create_processor(model_config) - processor.info.get_supported_mm_limits = lambda: {"image": num_supported} - - exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most") - - with exc_ctx: - MULTIMODAL_REGISTRY.get_dummy_mm_inputs( - model_config, - mm_counts=limit_mm_per_prompt, - processor=processor, - ) - - @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize( ("num_images", "limit", "is_valid"), @@ -975,6 +941,50 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): ) +@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) +@pytest.mark.parametrize( + ("user_limit", "supported_limit"), + [ + (0, 0), + (0, 1), + (1, 0), # user wants 1, model supports 0 → capped to 0 + (1, 1), + (1, 2), + (2, 1), # user wants 2, model supports 1 → capped to 1 + (2, 2), + (5, 1), # large user limit, low model support → capped to 1 + (1, 5), + (10, 0), # large user limit, no model support → capped to 0 + ], +) +def test_budget_caps_prevent_dummy_input_validation_failure( + model_id, user_limit, supported_limit +): + limit_mm_per_prompt = {"image": user_limit} + + model_config = ModelConfig( + model=model_id, + limit_mm_per_prompt=limit_mm_per_prompt, + ) + + processor = MULTIMODAL_REGISTRY.create_processor(model_config) + processor.info.get_supported_mm_limits = lambda: {"image": supported_limit} + + # This is what budget.py uses to derive mm_counts + allowed = processor.info.allowed_mm_limits + + assert allowed["image"] <= supported_limit, ( + f"allowed_mm_limits['image']={allowed['image']} exceeds " + f"supported_limit={supported_limit}" + ) + + assert allowed["image"] <= user_limit, ( + f"allowed_mm_limits['image']={allowed['image']} exceeds user_limit={user_limit}" + ) + + assert allowed["image"] == min(user_limit, supported_limit) + + class DummyProcessor: def __init__(self, a: int = 0, b: int = 0) -> None: super().__init__() diff --git a/vllm/config/multimodal.py b/vllm/config/multimodal.py index 48eea6f4e..30305e4be 100644 --- a/vllm/config/multimodal.py +++ b/vllm/config/multimodal.py @@ -76,6 +76,11 @@ class MultiModalConfig: for the OpenAI-compatible server, this refers to chat messages with content `"type": "*_embeds"`. + When enabled with `--limit-mm-per-prompt` set to 0 for a modality, + precomputed embeddings skip count validation for that modality, + saving memory by not loading encoder modules while still enabling + embeddings as an input. Limits greater than 0 still apply to embeddings. + WARNING: The vLLM engine may crash if incorrect shape of embeddings is passed. Only enable this flag for trusted users!""" media_io_kwargs: dict[str, dict[str, Any]] = Field(default_factory=dict) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 0077a897d..c48d7bea9 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -528,7 +528,17 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): else: num_items = len(self._items_by_modality[original_modality]) + 1 - self.mm_processor.info.validate_num_items(input_modality, num_items) + mm_config = self.model_config.multimodal_config + if ( + mm_config is not None + and mm_config.enable_mm_embeds + and mm_config.get_limit_per_prompt(input_modality) == 0 + and original_modality.endswith("_embeds") + ): + # Skip validation: embeddings bypass limit when enable_mm_embeds=True + pass + else: + self.mm_processor.info.validate_num_items(input_modality, num_items) # Track original modality for vision_chunk items if use_vision_chunk: diff --git a/vllm/multimodal/budget.py b/vllm/multimodal/budget.py index 1dddc82b1..821c9e9b5 100644 --- a/vllm/multimodal/budget.py +++ b/vllm/multimodal/budget.py @@ -3,11 +3,14 @@ from collections.abc import Mapping from vllm.config import ModelConfig, VllmConfig +from vllm.logger import init_logger from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.registry import MultiModalRegistry from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget +logger = init_logger(__name__) + def get_mm_max_toks_per_item( model_config: ModelConfig, @@ -59,11 +62,26 @@ class MultiModalBudget: processor = mm_registry.create_processor(model_config, cache=cache) self.cache = cache + mm_config = model_config.get_multimodal_config() + enable_mm_embeds = mm_config is not None and mm_config.enable_mm_embeds + + supported_mm_limits = processor.info.supported_mm_limits self.mm_limits = mm_limits = processor.info.allowed_mm_limits - active_modalities = { - modality for modality, limit in mm_limits.items() if limit > 0 + # Modalities that pass through the MM encoder tower + tower_modalities = { + modality + for modality in supported_mm_limits + if mm_limits.get(modality, 0) > 0 } + # Modalities that bypass the tower (pre-computed embeddings only) + embed_only_modalities = { + modality + for modality in supported_mm_limits + if enable_mm_embeds and mm_limits.get(modality, 0) == 0 + } + + active_modalities = tower_modalities | embed_only_modalities all_mm_max_toks_per_item = get_mm_max_toks_per_item( model_config, @@ -72,19 +90,32 @@ class MultiModalBudget: mm_counts=dict.fromkeys(active_modalities, 1), ) + if embed_only_modalities: + logger.info_once( + "enable_mm_embeds is True; modalities handled as embedding-only: %s", + tuple(embed_only_modalities), + ) + # Some models (e.g., Qwen3Omni with use_audio_in_video=True) share # placeholders between modalities, so not all active modalities will # have their own entry in the returned dict. We filter to only include # modalities that have independent placeholder tokens. - mm_max_toks_per_item = { + active_mm_max_toks_per_item = { modality: all_mm_max_toks_per_item[modality] for modality in active_modalities if modality in all_mm_max_toks_per_item } + tower_mm_max_toks_per_item = { + modality: active_mm_max_toks_per_item[modality] + for modality in tower_modalities + if modality in active_mm_max_toks_per_item + } + # Encoder budget is computed from all active modalities (including + # embedding-only ones that need encoder cache space). encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( scheduler_config, - mm_max_toks_per_item, + active_mm_max_toks_per_item, ) self.encoder_compute_budget = encoder_compute_budget @@ -93,13 +124,15 @@ class MultiModalBudget: mm_max_items_per_prompt = dict[str, int]() mm_max_items_per_batch = dict[str, int]() - for modality, max_toks_per_item in mm_max_toks_per_item.items(): + # Per-prompt/per-batch limits are only relevant for tower modalities + # (embedding-only modalities don't go through the encoder tower). + for modality, max_toks_per_item in tower_mm_max_toks_per_item.items(): ( mm_max_items_per_prompt[modality], mm_max_items_per_batch[modality], ) = self._get_max_items(modality, max_toks_per_item) - self.mm_max_toks_per_item = mm_max_toks_per_item + self.mm_max_toks_per_item = tower_mm_max_toks_per_item self.mm_max_items_per_prompt: Mapping[str, int] = mm_max_items_per_prompt self.mm_max_items_per_batch: Mapping[str, int] = mm_max_items_per_batch diff --git a/vllm/multimodal/processing/context.py b/vllm/multimodal/processing/context.py index 9a98692b5..d5c14310c 100644 --- a/vllm/multimodal/processing/context.py +++ b/vllm/multimodal/processing/context.py @@ -681,16 +681,22 @@ class BaseProcessingInfo: mm_items = self.data_parser.parse_mm_data(mm_data) if validate: - mm_config = self.ctx.model_config.get_multimodal_config() - if not mm_config.enable_mm_embeds: - for modality, items in mm_items.items(): - if isinstance(items, (EmbeddingItems, DictEmbeddingItems)): + mm_config = self.ctx.get_mm_config() + + for modality, items in mm_items.items(): + if isinstance(items, (EmbeddingItems, DictEmbeddingItems)): + if not mm_config.enable_mm_embeds: raise ValueError( f"You must set `--enable-mm-embeds` to input " f"`{modality}_embeds`" ) - - for modality, items in mm_items.items(): + if mm_config.get_limit_per_prompt(modality) == 0: + logger.debug( + "Skipping count validation for modality " + "'%s' (embeddings with limit=0)", + modality, + ) + continue self.validate_num_items(modality, len(items)) return mm_items diff --git a/vllm/multimodal/processing/dummy_inputs.py b/vllm/multimodal/processing/dummy_inputs.py index b23e2b86c..a93fd2c24 100644 --- a/vllm/multimodal/processing/dummy_inputs.py +++ b/vllm/multimodal/processing/dummy_inputs.py @@ -95,7 +95,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]): """ dummy_text = self.get_dummy_text(mm_counts) dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options) - dummy_mm_items = self.info.parse_mm_data(dummy_mm_data) + dummy_mm_items = self.info.parse_mm_data(dummy_mm_data, validate=False) tokenization_kwargs = {"truncation": False} diff --git a/vllm/multimodal/processing/processor.py b/vllm/multimodal/processing/processor.py index fe697c5ce..5f98cce3d 100644 --- a/vllm/multimodal/processing/processor.py +++ b/vllm/multimodal/processing/processor.py @@ -1395,7 +1395,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): missing_modality_data.append(data) mm_missing_data[modality] = missing_modality_data - mm_missing_items = self.info.parse_mm_data(mm_missing_data) + mm_missing_items = self.info.parse_mm_data(mm_missing_data, validate=False) return mm_is_cached, mm_missing_items diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 82329d17f..6c7e86a4f 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -138,6 +138,11 @@ class MultiModalRegistry: mm_config.get_limit_per_prompt(modality) == 0 for modality in info.supported_mm_limits ): + # If enable_mm_embeds is True, we still need MM infrastructure + # to process pre-computed embeddings even though encoder won't run + if mm_config.enable_mm_embeds: + return True + logger.info_once( "All limits of multimodal modalities supported by the model " "are set to 0, running in text-only mode." diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2ca4866d9..1dbf96090 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1259,6 +1259,9 @@ class GPUModelRunner( mm_budget = self.mm_budget assert mm_budget is not None + if not mm_budget.mm_max_toks_per_item: + return {} # No tower modalities (embed-only mode) + dummy_modality = mm_budget.get_modality_with_max_tokens() return self._get_mm_dummy_batch(dummy_modality, num_seqs) @@ -5116,40 +5119,50 @@ class GPUModelRunner( assert mm_budget is not None if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget.mm_max_items_per_batch[ - dummy_modality - ] + if not mm_budget.mm_max_toks_per_item: + # All modality limits are 0 — embedding-only mode. + # Budget is non-zero for embedding storage, but + # there's no encoder to profile. + logger.info( + "Skipping encoder profiling for embedding-only " + "mode (all modality limits=0 with " + "enable_mm_embeds=True).", + ) + else: + # NOTE: Currently model is profiled with a single + # non-text modality with the max possible input + # tokens even when it supports multiple. + dummy_modality = mm_budget.get_modality_with_max_tokens() + max_mm_items_per_batch = mm_budget.mm_max_items_per_batch[ + dummy_modality + ] - logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the " - "maximum feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) + logger.info( + "Encoder cache will be initialized with a " + "budget of %s tokens, and profiled with " + "%s %s items of the maximum feature size.", + encoder_budget, + max_mm_items_per_batch, + dummy_modality, + ) - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) + # Create dummy batch of multimodal inputs. + batched_dummy_mm_inputs = self._get_mm_dummy_batch( + dummy_modality, + max_mm_items_per_batch, + ) - # Run multimodal encoder. - dummy_encoder_outputs = self.model.embed_multimodal( - **batched_dummy_mm_inputs - ) + # Run multimodal encoder. + dummy_encoder_outputs = self.model.embed_multimodal( + **batched_dummy_mm_inputs + ) - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) - for i, output in enumerate(dummy_encoder_outputs): - self.encoder_cache[f"tmp_{i}"] = output + sanity_check_mm_encoder_outputs( + dummy_encoder_outputs, + expected_num_items=max_mm_items_per_batch, + ) + for i, output in enumerate(dummy_encoder_outputs): + self.encoder_cache[f"tmp_{i}"] = output # Add `is_profile` here to pre-allocate communication buffers hidden_states, last_hidden_states = self._dummy_run(