[Fix] Enable mm_processor_cache with vision LoRA (#31927)
Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
This commit is contained in:
@@ -24,10 +24,12 @@ from vllm.multimodal.cache import (
|
||||
)
|
||||
from vllm.multimodal.hasher import MultiModalHasher
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalFeatureSpec,
|
||||
MultiModalFieldElem,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalSharedField,
|
||||
PlaceholderRange,
|
||||
)
|
||||
from vllm.multimodal.processing import PromptInsertion
|
||||
from vllm.utils.mem_constants import GiB_bytes, MiB_bytes
|
||||
@@ -518,3 +520,40 @@ def test_cache_eviction_shm_cache():
|
||||
receiver_cache = ShmObjectStoreReceiverCache(vllm_config, mp.Lock())
|
||||
|
||||
_run_test_cache_eviction_shm(sender_cache, receiver_cache, base_item_size=MiB_bytes)
|
||||
|
||||
|
||||
def test_processor_cache_shared_across_loras():
    """Verify the receiver cache is keyed by ``mm_hash`` rather than by the
    LoRA-prefixed identifier.

    Two features are built that share one ``mm_hash`` but carry different
    identifiers. The first one carries the item data; after it is passed
    through the cache, the shared hash must be present as a cache key. The
    second one carries no data and must get the same item back from the cache.
    """
    cache = MultiModalReceiverCache(
        ModelConfig(
            model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
            mm_processor_cache_gb=1,
        )
    )

    shared_hash = "image_hash_abc123"
    identifier_a = f"12345:{shared_hash}"
    identifier_b = f"67890:{shared_hash}"

    cached_item = MultiModalKwargsItem.dummy("test_image", nbytes=1024)

    def build_feature(identifier, data):
        # Both features differ only in identifier and payload; the position,
        # modality and base hash are deliberately identical.
        return MultiModalFeatureSpec(
            data=data,
            modality="image",
            identifier=identifier,
            mm_position=PlaceholderRange(offset=0, length=100),
            mm_hash=shared_hash,
        )

    # First request: the feature carries its data and is handed to the cache.
    populated = build_feature(identifier_a, cached_item)
    cache.get_and_update_features([populated])
    assert shared_hash in cache._cache

    # Second request: different identifier, no data attached.
    empty = build_feature(identifier_b, None)
    cache.get_and_update_features([empty])
    assert empty.data == cached_item
|
||||
|
||||
@@ -1649,19 +1649,6 @@ class EngineArgs:
|
||||
else None
|
||||
)
|
||||
|
||||
if (
|
||||
lora_config is not None
|
||||
and lora_config.enable_tower_connector_lora
|
||||
and self.mm_processor_cache_gb != 0
|
||||
):
|
||||
raise ValueError(
|
||||
"Currently, enable_tower_connector_lora is "
|
||||
"incompatible with the multi-modal processor cache. "
|
||||
"When enable_tower_connector_lora is set, "
|
||||
"mm_processor_cache_gb must be 0, got %s",
|
||||
self.mm_processor_cache_gb,
|
||||
)
|
||||
|
||||
if (
|
||||
lora_config is not None
|
||||
and speculative_config is not None
|
||||
|
||||
@@ -635,12 +635,17 @@ class BaseMultiModalReceiverCache(
|
||||
Update multimodal features with cached encoder outputs.
|
||||
Touch all identifier at first before update to avoid
|
||||
item in updated list evict during update.
|
||||
|
||||
Uses mm_hash for cache key to share across LoRAs (falls back to
|
||||
identifier for backward compatibility).
|
||||
"""
|
||||
for feature in mm_features:
|
||||
self.touch_receiver_cache_item(feature.identifier, feature.data)
|
||||
cache_key = feature.mm_hash or feature.identifier
|
||||
self.touch_receiver_cache_item(cache_key, feature.data)
|
||||
|
||||
for feature in mm_features:
|
||||
feature.data = self.get_and_update_item(feature.data, feature.identifier)
|
||||
cache_key = feature.mm_hash or feature.identifier
|
||||
feature.data = self.get_and_update_item(feature.data, cache_key)
|
||||
return mm_features
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -330,6 +330,9 @@ class MultiModalFeatureSpec:
|
||||
mm_position: PlaceholderRange
|
||||
"""e.g., PlaceholderRange(offset=2, length=336)"""
|
||||
|
||||
mm_hash: str | None = None
|
||||
"""Base mm_hash for processor cache (without LoRA prefix)."""
|
||||
|
||||
@staticmethod
|
||||
def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
|
||||
kwargs = defaultdict[str, list[NestedTensors]](list)
|
||||
|
||||
@@ -562,15 +562,17 @@ class InputProcessor:
|
||||
|
||||
mm_features = []
|
||||
for modality, idx in sorted_mm_idxs:
|
||||
base_mm_hash = decoder_mm_hashes[modality][idx]
|
||||
mm_features.append(
|
||||
MultiModalFeatureSpec(
|
||||
data=decoder_mm_inputs[modality][idx],
|
||||
modality=modality,
|
||||
identifier=self._get_mm_identifier(
|
||||
decoder_mm_hashes[modality][idx],
|
||||
base_mm_hash,
|
||||
lora_request,
|
||||
),
|
||||
mm_position=decoder_mm_positions[modality][idx],
|
||||
mm_hash=base_mm_hash,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user