diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index e641b1111..0a8d4f737 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -24,10 +24,12 @@ from vllm.multimodal.cache import ( ) from vllm.multimodal.hasher import MultiModalHasher from vllm.multimodal.inputs import ( + MultiModalFeatureSpec, MultiModalFieldElem, MultiModalKwargsItem, MultiModalKwargsItems, MultiModalSharedField, + PlaceholderRange, ) from vllm.multimodal.processing import PromptInsertion from vllm.utils.mem_constants import GiB_bytes, MiB_bytes @@ -518,3 +520,40 @@ def test_cache_eviction_shm_cache(): receiver_cache = ShmObjectStoreReceiverCache(vllm_config, mp.Lock()) _run_test_cache_eviction_shm(sender_cache, receiver_cache, base_item_size=MiB_bytes) + + +def test_processor_cache_shared_across_loras(): + """Test that processor cache uses mm_hash to share data across LoRAs.""" + model_config = ModelConfig( + model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", + mm_processor_cache_gb=1, + ) + receiver_cache = MultiModalReceiverCache(model_config) + + base_mm_hash = "image_hash_abc123" + lora_a_identifier = f"12345:{base_mm_hash}" + lora_b_identifier = f"67890:{base_mm_hash}" + + item_data = MultiModalKwargsItem.dummy("test_image", nbytes=1024) + + feature_lora_a = MultiModalFeatureSpec( + data=item_data, + modality="image", + identifier=lora_a_identifier, + mm_position=PlaceholderRange(offset=0, length=100), + mm_hash=base_mm_hash, + ) + + receiver_cache.get_and_update_features([feature_lora_a]) + assert base_mm_hash in receiver_cache._cache + + feature_lora_b = MultiModalFeatureSpec( + data=None, + modality="image", + identifier=lora_b_identifier, + mm_position=PlaceholderRange(offset=0, length=100), + mm_hash=base_mm_hash, + ) + + receiver_cache.get_and_update_features([feature_lora_b]) + assert feature_lora_b.data == item_data diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 98f1cfbd5..77201f668 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1649,19 +1649,6 @@ class EngineArgs: else None ) - if ( - lora_config is not None - and lora_config.enable_tower_connector_lora - and self.mm_processor_cache_gb != 0 - ): - raise ValueError( - "Currently, enable_tower_connector_lora is " - "incompatible with the multi-modal processor cache. " - "When enable_tower_connector_lora is set, " - "mm_processor_cache_gb must be 0, got %s", - self.mm_processor_cache_gb, - ) - if ( lora_config is not None and speculative_config is not None diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index f22a14a1d..41397a26e 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -635,12 +635,17 @@ class BaseMultiModalReceiverCache( Update multimodal features with cached encoder outputs. Touch all identifier at first before update to avoid item in updated list evict during update. + + Uses mm_hash for cache key to share across LoRAs (falls back to + identifier for backward compatibility). """ for feature in mm_features: - self.touch_receiver_cache_item(feature.identifier, feature.data) + cache_key = feature.mm_hash or feature.identifier + self.touch_receiver_cache_item(cache_key, feature.data) for feature in mm_features: - feature.data = self.get_and_update_item(feature.data, feature.identifier) + cache_key = feature.mm_hash or feature.identifier + feature.data = self.get_and_update_item(feature.data, cache_key) return mm_features @abstractmethod diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index afd782870..bd49d7192 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -330,6 +330,9 @@ class MultiModalFeatureSpec: mm_position: PlaceholderRange """e.g., PlaceholderRange(offset=2, length=336)""" + mm_hash: str | None = None + """Base mm_hash for processor cache (without LoRA prefix).""" + @staticmethod def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]): kwargs = defaultdict[str, list[NestedTensors]](list) diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index 85f2f4053..8275dc409 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -562,15 +562,17 @@ class InputProcessor: mm_features = [] for modality, idx in sorted_mm_idxs: + base_mm_hash = decoder_mm_hashes[modality][idx] mm_features.append( MultiModalFeatureSpec( data=decoder_mm_inputs[modality][idx], modality=modality, identifier=self._get_mm_identifier( - decoder_mm_hashes[modality][idx], + base_mm_hash, lora_request, ), mm_position=decoder_mm_positions[modality][idx], + mm_hash=base_mm_hash, ) )