From cbbae38f9368b6c35d9b9295bf4ceee1e6452750 Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Thu, 15 Jan 2026 17:02:30 +0800
Subject: [PATCH] [2/N] Move cache factories to MM registry (#32382)

Signed-off-by: DarkLight1337
---
 tests/multimodal/test_cache.py                |  10 +-
 .../engine/test_process_multi_modal_uuids.py  |   7 +-
 vllm/lora/model_manager.py                    |  20 ++--
 vllm/multimodal/cache.py                      | 109 ------------------
 vllm/multimodal/registry.py                   |  96 ++++++++++++++-
 vllm/v1/engine/core.py                        |   5 +-
 vllm/v1/engine/input_processor.py             |   3 +-
 vllm/v1/worker/gpu_model_runner.py            |   6 +-
 vllm/v1/worker/utils.py                       |  12 +-
 vllm/v1/worker/worker_base.py                 |  10 +-
 10 files changed, 124 insertions(+), 154 deletions(-)

diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py
index 0a8d4f737..36220e8f3 100644
--- a/tests/multimodal/test_cache.py
+++ b/tests/multimodal/test_cache.py
@@ -19,8 +19,6 @@ from vllm.multimodal.cache import (
     MultiModalReceiverCache,
     ShmObjectStoreReceiverCache,
     ShmObjectStoreSenderCache,
-    engine_receiver_cache_from_config,
-    processor_cache_from_config,
 )
 from vllm.multimodal.hasher import MultiModalHasher
 from vllm.multimodal.inputs import (
@@ -132,10 +130,10 @@ def _compare_caches(
     n_iter: int = 100,
     seed: int = 0,
 ):
-    cache_0_p0 = processor_cache_from_config(config_0, MULTIMODAL_REGISTRY)
-    cache_0_p1 = engine_receiver_cache_from_config(config_0, MULTIMODAL_REGISTRY)
-    cache_1_p0 = processor_cache_from_config(config_1, MULTIMODAL_REGISTRY)
-    cache_1_p1 = engine_receiver_cache_from_config(config_1, MULTIMODAL_REGISTRY)
+    cache_0_p0 = MULTIMODAL_REGISTRY.processor_cache_from_config(config_0)
+    cache_0_p1 = MULTIMODAL_REGISTRY.engine_receiver_cache_from_config(config_0)
+    cache_1_p0 = MULTIMODAL_REGISTRY.processor_cache_from_config(config_1)
+    cache_1_p1 = MULTIMODAL_REGISTRY.engine_receiver_cache_from_config(config_1)
 
     cache_size_gb = max(
         config_0.model_config.multimodal_config.mm_processor_cache_gb,
diff --git a/tests/v1/engine/test_process_multi_modal_uuids.py b/tests/v1/engine/test_process_multi_modal_uuids.py
index 1a16e3913..dbf9ffd29 100644
--- a/tests/v1/engine/test_process_multi_modal_uuids.py
+++ b/tests/v1/engine/test_process_multi_modal_uuids.py
@@ -6,9 +6,8 @@ import pytest
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
 from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
-from vllm.multimodal import MultiModalUUIDDict
+from vllm.multimodal import MultiModalRegistry, MultiModalUUIDDict
 from vllm.sampling_params import SamplingParams
-from vllm.v1.engine import input_processor as input_processor_mod
 from vllm.v1.engine.input_processor import InputProcessor
 
 cherry_pil_image = ImageAsset("cherry_blossom").pil_image
@@ -36,9 +35,9 @@ def _mock_input_processor(
         raising=True,
     )
     monkeypatch.setattr(
-        input_processor_mod,
+        MultiModalRegistry,
         "processor_cache_from_config",
-        lambda vllm_config, mm_registry: None,
+        lambda self, vllm_config: None,
         raising=True,
     )
 
diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py
index 70da246f2..eb9552e0a 100644
--- a/vllm/lora/model_manager.py
+++ b/vllm/lora/model_manager.py
@@ -135,9 +135,15 @@ class LoRAModelManager:
             llm_punica_wrapper
         )
 
-    def _maybe_init_mm(self, vllm_config: VllmConfig, max_num_batched_tokens) -> None:
-        self.supports_tower_connector_lora = False
+    def _maybe_init_mm(
+        self,
+        vllm_config: VllmConfig,
+        max_num_batched_tokens: int,
+    ) -> None:
         model_config: ModelConfig = vllm_config.model_config
+        mm_registry = MULTIMODAL_REGISTRY
+
+        self.supports_tower_connector_lora = False
         self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
 
         # Only one language model can be included in the model.
@@ -154,9 +160,7 @@ class LoRAModelManager:
             self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
 
         if self.lora_config.enable_tower_connector_lora:
-            self.mm_processor_info = MULTIMODAL_REGISTRY.create_processor(
-                model_config
-            ).info
+            self.mm_processor_info = mm_registry.create_processor(model_config).info
             self.supports_tower_connector_lora = self.supports_mm and hasattr(
                 self.model, "get_num_mm_encoder_tokens"
             )
@@ -169,11 +173,7 @@ class LoRAModelManager:
                 "GitHub if you encounter them."
            )
 
-            mm_budget = MultiModalBudget(
-                model_config,
-                vllm_config.scheduler_config,
-                MULTIMODAL_REGISTRY,
-            )
+            mm_budget = MultiModalBudget(vllm_config, mm_registry)
             limit_per_prompt: int = max(
                 self.mm_processor_info.get_allowed_mm_limits().values()
             )
diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py
index 7f97ae359..cb17f7fdd 100644
--- a/vllm/multimodal/cache.py
+++ b/vllm/multimodal/cache.py
@@ -35,7 +35,6 @@ if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
 
     from .processing.processor import ResolvedPromptUpdate
-    from .registry import MultiModalRegistry
 
 logger = init_logger(__name__)
 
@@ -561,67 +560,6 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
         return mm_item
-
-
-def _enable_processor_cache(
-    model_config: "ModelConfig",
-    mm_registry: "MultiModalRegistry",
-) -> bool:
-    if not mm_registry.supports_multimodal_inputs(model_config):
-        return False
-
-    mm_config = model_config.get_multimodal_config()
-    return mm_config.mm_processor_cache_gb > 0
-
-
-def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
-    parallel_config = vllm_config.parallel_config
-    supports_ipc_cache = (
-        parallel_config._api_process_count == 1
-        and parallel_config.data_parallel_size == 1
-    ) or parallel_config.data_parallel_external_lb
-
-    return supports_ipc_cache
-
-
-def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
-    """Whether the shared memory based cache should be enabled."""
-
-    if not _enable_ipc_cache(vllm_config):
-        return False
-
-    mm_config = vllm_config.model_config.get_multimodal_config()
-
-    return mm_config.mm_processor_cache_type == "shm"
-
-
-def processor_cache_from_config(
-    vllm_config: "VllmConfig",
-    mm_registry: "MultiModalRegistry",
-) -> BaseMultiModalProcessorCache | None:
-    """Return a `BaseMultiModalProcessorCache`, if enabled."""
-    model_config = vllm_config.model_config
-
-    if not _enable_processor_cache(model_config, mm_registry):
-        return None
-
-    if not _enable_ipc_cache(vllm_config):
-        return MultiModalProcessorOnlyCache(model_config)
-
-    if not _enable_mm_input_shm_cache(vllm_config):
-        return MultiModalProcessorSenderCache(model_config)
-    return ShmObjectStoreSenderCache(vllm_config)
-
-
-def processor_only_cache_from_config(
-    model_config: "ModelConfig",
-    mm_registry: "MultiModalRegistry",
-):
-    """Return a `MultiModalProcessorOnlyCache`, if enabled."""
-    if not _enable_processor_cache(model_config, mm_registry):
-        return None
-
-    return MultiModalProcessorOnlyCache(model_config)
 
 
 class BaseMultiModalReceiverCache(
     BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem]
 ):
@@ -780,50 +718,3 @@ class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
     @override
     def clear_cache(self) -> None:
         self._shm_cache.clear()
-
-
-def engine_receiver_cache_from_config(
-    vllm_config: "VllmConfig",
-    mm_registry: "MultiModalRegistry",
-) -> BaseMultiModalReceiverCache | None:
-    """
-    This is used in the engine process.
-    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
-    mm_processor_cache_type=="lru".
-    """
-    model_config = vllm_config.model_config
-
-    if not _enable_processor_cache(model_config, mm_registry):
-        return None
-
-    if not _enable_ipc_cache(vllm_config):
-        return None
-
-    if not _enable_mm_input_shm_cache(vllm_config):
-        return MultiModalReceiverCache(model_config)
-
-    return None
-
-
-def worker_receiver_cache_from_config(
-    vllm_config: "VllmConfig",
-    mm_registry: "MultiModalRegistry",
-    shared_worker_lock: LockType,
-) -> BaseMultiModalReceiverCache | None:
-    """
-    This is used in the worker process.
-    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
-    mm_processor_cache_type=="shm".
-    """
-    model_config = vllm_config.model_config
-
-    if not _enable_processor_cache(model_config, mm_registry):
-        return None
-
-    if not _enable_ipc_cache(vllm_config):
-        return None
-
-    if not _enable_mm_input_shm_cache(vllm_config):
-        return None
-
-    return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index e29c7bda0..117279369 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -2,14 +2,23 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
+from multiprocessing.synchronize import Lock as LockType
+from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
 
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.config.observability import ObservabilityConfig
 from vllm.logger import init_logger
 from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
 
-from .cache import BaseMultiModalProcessorCache
+from .cache import (
+    BaseMultiModalProcessorCache,
+    BaseMultiModalReceiverCache,
+    MultiModalProcessorOnlyCache,
+    MultiModalProcessorSenderCache,
+    MultiModalReceiverCache,
+    ShmObjectStoreReceiverCache,
+    ShmObjectStoreSenderCache,
+)
 from .inputs import MultiModalInputs
 from .processing import (
     BaseDummyInputsBuilder,
@@ -19,7 +28,7 @@ from .processing import (
 )
 
 if TYPE_CHECKING:
-    from vllm.config import ModelConfig, ObservabilityConfig
+    from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
     from vllm.model_executor.models.interfaces import SupportsMultiModal
 
 logger = init_logger(__name__)
@@ -355,3 +364,84 @@ class MultiModalRegistry:
 
         first_modality = next(iter(max_tokens))
         return max_tokens[first_modality]
+
+    def _get_cache_type(
+        self,
+        vllm_config: "VllmConfig",
+    ) -> Literal[None, "processor_only", "lru", "shm"]:
+        model_config = vllm_config.model_config
+        if not self.supports_multimodal_inputs(model_config):
+            return None
+
+        # Check if the cache is disabled.
+        mm_config = model_config.get_multimodal_config()
+        if mm_config.mm_processor_cache_gb <= 0:
+            return None
+
+        # Check if IPC caching is supported.
+        parallel_config = vllm_config.parallel_config
+        is_ipc_supported = parallel_config._api_process_count == 1 and (
+            parallel_config.data_parallel_size == 1
+            or parallel_config.data_parallel_external_lb
+        )
+
+        if not is_ipc_supported:
+            return "processor_only"
+
+        mm_config = model_config.get_multimodal_config()
+        return mm_config.mm_processor_cache_type
+
+    def processor_cache_from_config(
+        self,
+        vllm_config: "VllmConfig",
+    ) -> BaseMultiModalProcessorCache | None:
+        """Return a `BaseMultiModalProcessorCache`, if enabled."""
+        cache_type = self._get_cache_type(vllm_config)
+        if cache_type is None:
+            return None
+        elif cache_type == "processor_only":
+            return MultiModalProcessorOnlyCache(vllm_config.model_config)
+        elif cache_type == "lru":
+            return MultiModalProcessorSenderCache(vllm_config.model_config)
+        elif cache_type == "shm":
+            return ShmObjectStoreSenderCache(vllm_config)
+        else:
+            raise ValueError(f"Unknown cache type: {cache_type!r}")
+
+    def processor_only_cache_from_config(
+        self,
+        vllm_config: "VllmConfig",
+    ) -> MultiModalProcessorOnlyCache | None:
+        """Return a `MultiModalProcessorOnlyCache`, if enabled."""
+        cache_type = self._get_cache_type(vllm_config)
+        if cache_type is None:
+            return None
+
+        return MultiModalProcessorOnlyCache(vllm_config.model_config)
+
+    def engine_receiver_cache_from_config(
+        self,
+        vllm_config: "VllmConfig",
+    ) -> BaseMultiModalReceiverCache | None:
+        """Return a `BaseMultiModalReceiverCache` for the engine process."""
+        cache_type = self._get_cache_type(vllm_config)
+        if cache_type in (None, "processor_only", "shm"):
+            return None
+        elif cache_type == "lru":
+            return MultiModalReceiverCache(vllm_config.model_config)
+        else:
+            raise ValueError(f"Unknown cache type: {cache_type!r}")
+
+    def worker_receiver_cache_from_config(
+        self,
+        vllm_config: "VllmConfig",
+        shared_worker_lock: LockType,
+    ) -> BaseMultiModalReceiverCache | None:
+        """Return a `BaseMultiModalReceiverCache` for the worker process."""
+        cache_type = self._get_cache_type(vllm_config)
+        if cache_type in (None, "processor_only", "lru"):
+            return None
+        elif cache_type == "shm":
+            return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
+        else:
+            raise ValueError(f"Unknown cache type: {cache_type!r}")
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 3ae0b3179..141e5a459 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -23,7 +23,6 @@ from vllm.logger import init_logger
 from vllm.logging_utils.dump_input import dump_engine_exception
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.cache import engine_receiver_cache_from_config
 from vllm.tasks import POOLING_TASKS, SupportedTask
 from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
 from vllm.utils.gc_utils import (
@@ -149,8 +148,8 @@ class EngineCore:
         self.model_executor.init_kv_output_aggregator(self.scheduler.connector)  # type: ignore
 
         self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
-        self.mm_receiver_cache = engine_receiver_cache_from_config(
-            vllm_config, mm_registry
+        self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config(
+            vllm_config
         )
 
         # If a KV connector is initialized for scheduler, we want to collect
diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py
index 6404cb67e..318aa51ce 100644
--- a/vllm/v1/engine/input_processor.py
+++ b/vllm/v1/engine/input_processor.py
@@ -14,7 +14,6 @@ from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.multimodal.cache import processor_cache_from_config
 from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
 from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing.context import set_request_id
@@ -58,7 +57,7 @@ class InputProcessor:
         self.generation_config_fields = self.model_config.try_get_generation_config()
 
         self.mm_registry = mm_registry
-        self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry)
+        self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)
 
         self.input_preprocessor = InputPreprocessor(
             self.model_config,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 525cfede1..00e401f41 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -623,11 +623,7 @@ class GPUModelRunner(
         self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
 
         self.mm_budget = (
-            MultiModalBudget(
-                self.model_config,
-                self.scheduler_config,
-                self.mm_registry,
-            )
+            MultiModalBudget(self.vllm_config, self.mm_registry)
             if self.supports_mm_inputs
             else None
         )
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index 7fd6161a9..810160046 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -8,11 +8,10 @@ import torch
 from typing_extensions import deprecated
 
 from vllm.attention.layer import Attention
-from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
+from vllm.config import CacheConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.utils import extract_layer_index
-from vllm.multimodal.cache import processor_only_cache_from_config
 from vllm.multimodal.registry import MultiModalRegistry
 from vllm.platforms import current_platform
 from vllm.utils.mem_utils import MemorySnapshot, format_gib
@@ -28,16 +27,15 @@ class MultiModalBudget:
 
     def __init__(
         self,
-        model_config: ModelConfig,
-        scheduler_config: SchedulerConfig,
+        vllm_config: VllmConfig,
         mm_registry: MultiModalRegistry,
     ) -> None:
         super().__init__()
 
-        self.model_config = model_config
-        self.scheduler_config = scheduler_config
+        self.model_config = model_config = vllm_config.model_config
+        self.scheduler_config = scheduler_config = vllm_config.scheduler_config
         self.mm_registry = mm_registry
-        self.cache = cache = processor_only_cache_from_config(model_config, mm_registry)
+        self.cache = cache = mm_registry.processor_only_cache_from_config(vllm_config)
 
         self.max_model_len = model_config.max_model_len
         self.max_num_reqs = scheduler_config.max_num_seqs
diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py
index df53b87af..d34eb5253 100644
--- a/vllm/v1/worker/worker_base.py
+++ b/vllm/v1/worker/worker_base.py
@@ -12,7 +12,6 @@ from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.cache import worker_receiver_cache_from_config
 from vllm.utils.import_utils import resolve_obj_by_qualname
 from vllm.utils.system_utils import update_environment_variables
 from vllm.v1.kv_cache_interface import KVCacheSpec
@@ -303,10 +302,11 @@ class WorkerWrapperBase:
             self.mm_receiver_cache = None
         else:
-            self.mm_receiver_cache = worker_receiver_cache_from_config(
-                vllm_config,
-                MULTIMODAL_REGISTRY,
-                shared_worker_lock,
+            self.mm_receiver_cache = (
+                MULTIMODAL_REGISTRY.worker_receiver_cache_from_config(
+                    vllm_config,
+                    shared_worker_lock,
+                )
             )
 
         with set_current_vllm_config(self.vllm_config):
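
Usage sketch (supplementary; not part of the patch itself): after this change, the process-specific multimodal caches are obtained from `MultiModalRegistry` instead of the removed module-level factories in `vllm/multimodal/cache.py`. A minimal sketch of the new call sites, assuming an already-constructed `VllmConfig` and, for the worker path, a multiprocessing lock that in real deployments is shared across workers:

    from multiprocessing import Lock

    from vllm.config import VllmConfig
    from vllm.multimodal import MULTIMODAL_REGISTRY

    def build_mm_caches(vllm_config: VllmConfig):
        # API process: None unless the model accepts multimodal inputs
        # and mm_processor_cache_gb > 0.
        processor_cache = MULTIMODAL_REGISTRY.processor_cache_from_config(vllm_config)

        # Engine process: non-None only when IPC caching applies and
        # mm_processor_cache_type == "lru".
        engine_cache = MULTIMODAL_REGISTRY.engine_receiver_cache_from_config(vllm_config)

        # Worker process: non-None only when IPC caching applies and
        # mm_processor_cache_type == "shm". Lock() here is illustrative;
        # the real caller passes the lock shared by all worker processes.
        worker_cache = MULTIMODAL_REGISTRY.worker_receiver_cache_from_config(
            vllm_config, Lock()
        )

        return processor_cache, engine_cache, worker_cache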