[VLM] Limit multimodal input cache by memory (#14805)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -3,11 +3,11 @@
 from typing import Any, Optional

 from vllm.config import ModelConfig
-from vllm.envs import VLLM_MM_INPUT_CACHE_SIZE
+from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
 from vllm.logger import init_logger
 from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
                              MultiModalKwargs, MultiModalRegistry)
-from vllm.utils import LRUCache
+from vllm.multimodal.processing import ProcessingCache

 logger = init_logger(__name__)

@@ -30,7 +30,7 @@ logger = init_logger(__name__)

 # Both Client and Server must use the same cache size
 # (to perform mirrored caching). This cache size is set by the environment
-# variable VLLM_MM_INPUT_CACHE_SIZE.
+# variable VLLM_MM_INPUT_CACHE_GIB.


 # TODO(ywang96): Deprecate this class once all multimodal models migrate to use
@@ -50,18 +50,20 @@ class MMInputCacheClient:

         # Init cache
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUCache[str,
-                                 MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE)
+        self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
+                                                      MultiModalKwargs)

         # DEBUG: Set to None to disable
         self.mm_debug_cache_hit_ratio_steps = None
-        self.mm_cache_hits = 0
-        self.mm_cache_total = 0
+        self.mm_debug_cache_hits = 0
+        self.mm_debug_cache_total = 0

     def cache_hit_ratio(self, steps):
-        if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
+        total = self.mm_debug_cache_total
+
+        if total > 0 and total % steps == 0:
             logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
-                         self.mm_cache_hits / self.mm_cache_total)
+                         self.mm_debug_cache_hits / total)

     # NOTE: process_inputs only supports image inputs since all multimodal
     # models with other modalities have migrated to use merged preprocessor.
@@ -71,7 +73,7 @@ class MMInputCacheClient:
         mm_hashes: Optional[list[str]],
         mm_processor_kwargs: Optional[dict[str, Any]],
         precomputed_mm_inputs: Optional[list[MultiModalKwargs]],
-    ) -> list[MultiModalKwargs]:
+    ) -> list[Optional[MultiModalKwargs]]:
         if precomputed_mm_inputs is None:
             image_inputs = mm_data["image"]
             if not isinstance(image_inputs, list):
@@ -88,7 +90,7 @@ class MMInputCacheClient:
         # Process each image input separately, so that later we can schedule
         # them in a fine-grained manner.
         # Apply caching (if enabled) and reuse precomputed inputs (if provided)
-        ret_inputs: list[MultiModalKwargs] = []
+        ret_inputs: list[Optional[MultiModalKwargs]] = []
         for input_id in range(num_inputs):
             if self.mm_debug_cache_hit_ratio_steps is not None:
                 self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
@@ -99,7 +101,7 @@ class MMInputCacheClient:
                 mm_hash = mm_hashes[input_id]
                 mm_input = self.mm_cache.get(mm_hash)

-            self.mm_cache_total += 1
+            self.mm_debug_cache_total += 1
             if mm_input is None:
                 if precomputed_mm_inputs is not None:
                     # Reuse precomputed input (for merged preprocessor)
@@ -114,9 +116,9 @@ class MMInputCacheClient:
                 if self.use_cache:
                     # Add to cache
                     assert mm_hash is not None
-                    self.mm_cache.put(mm_hash, mm_input)
+                    self.mm_cache[mm_hash] = mm_input
             else:
-                self.mm_cache_hits += 1
+                self.mm_debug_cache_hits += 1
                 mm_input = None  # Avoids sending mm_input to Server

             ret_inputs.append(mm_input)
@@ -128,14 +130,14 @@ class MMInputCacheServer:

     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
-        self.mm_cache = LRUCache[str,
-                                 MultiModalKwargs](VLLM_MM_INPUT_CACHE_SIZE)
+        self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
+                                                      MultiModalKwargs)

     def get_and_update(
         self,
         mm_inputs: list[Optional[MultiModalKwargs]],
         mm_hashes: list[str],
-    ) -> list[MultiModalKwargs]:
+    ) -> list[Optional[MultiModalKwargs]]:
         assert len(mm_inputs) == len(mm_hashes)

         if not self.use_cache:
@@ -148,7 +150,7 @@ class MMInputCacheServer:
                 mm_input = self.mm_cache.get(mm_hash)
                 assert mm_input is not None
             else:
-                self.mm_cache.put(mm_hash, mm_input)
+                self.mm_cache[mm_hash] = mm_input

             full_mm_inputs.append(mm_input)

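The headline change is the cache bound: the entry-count limit VLLM_MM_INPUT_CACHE_SIZE is replaced by a memory limit in gibibytes, VLLM_MM_INPUT_CACHE_GIB, and both caches are now built through ProcessingCache.get_lru_cache(...), which also explains the switch from mm_cache.put(k, v) to mm_cache[k] = v (a mapping-style interface). As a rough sketch of what bounding an LRU cache by memory rather than by number of entries involves (illustration only, not vLLM's ProcessingCache; the class name and the sys.getsizeof-based sizing below are assumptions):

import sys
from collections import OrderedDict

GiB = 1 << 30  # bytes per gibibyte


class MemoryBoundedLRUCache:
    """Illustrative only: an LRU cache evicted by total value size, not entry count."""

    def __init__(self, capacity_gib, sizeof=sys.getsizeof):
        # NOTE: sys.getsizeof is a shallow measure; a real multimodal cache
        # would need to account for tensor storage to get a meaningful count.
        self.capacity_bytes = int(capacity_gib * GiB)
        self.sizeof = sizeof
        self._data = OrderedDict()
        self._total_bytes = 0

    def get(self, key, default=None):
        if key not in self._data:
            return default
        self._data.move_to_end(key)  # mark as most recently used
        return self._data[key]

    def __setitem__(self, key, value):
        if key in self._data:
            self._total_bytes -= self.sizeof(self._data.pop(key))
        self._data[key] = value
        self._total_bytes += self.sizeof(value)
        # Evict least-recently-used entries until the memory budget is met.
        while self._total_bytes > self.capacity_bytes and len(self._data) > 1:
            _, evicted = self._data.popitem(last=False)
            self._total_bytes -= self.sizeof(evicted)

The len(self._data) > 1 guard keeps at least the newest entry even when a single value exceeds the budget, so a lookup immediately after insertion can still hit.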
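The "mirrored caching" comment is why the two caches must share the same budget: the client replaces inputs the server already holds with None, and the server restores them from its own identically-sized copy (hence the assert mm_input is not None on the server side). A minimal sketch of that round trip, with plain dicts standing in for the real caches and all function names illustrative rather than vLLM APIs:

def client_send(cache: dict, mm_hashes: list, mm_inputs: list) -> list:
    """Replace inputs the server is known to hold with None to avoid re-sending them."""
    payload = []
    for mm_hash, mm_input in zip(mm_hashes, mm_inputs):
        if mm_hash in cache:
            payload.append(None)       # cache hit: server already holds this input
        else:
            cache[mm_hash] = mm_input  # mirror the insertion the server will make
            payload.append(mm_input)   # cache miss: ship the full input
    return payload


def server_receive(cache: dict, mm_hashes: list, payload: list) -> list:
    """Rebuild the full input list, filling None entries from the mirrored cache."""
    full_inputs = []
    for mm_hash, mm_input in zip(mm_hashes, payload):
        if mm_input is None:
            mm_input = cache[mm_hash]  # must be present if the caches stay in sync
        else:
            cache[mm_hash] = mm_input
        full_inputs.append(mm_input)
    return full_inputs

Because both sides insert and evict with the same limit and the same policy, a hit on the client implies the entry is still resident on the server, which is what the server-side assertion relies on.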