[2/N] Move cache factories to MM registry (#32382)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-15 17:02:30 +08:00
committed by GitHub
parent cdba4c74b3
commit cbbae38f93
10 changed files with 124 additions and 154 deletions

View File

@@ -19,8 +19,6 @@ from vllm.multimodal.cache import (
MultiModalReceiverCache,
ShmObjectStoreReceiverCache,
ShmObjectStoreSenderCache,
engine_receiver_cache_from_config,
processor_cache_from_config,
)
from vllm.multimodal.hasher import MultiModalHasher
from vllm.multimodal.inputs import (
@@ -132,10 +130,10 @@ def _compare_caches(
n_iter: int = 100,
seed: int = 0,
):
cache_0_p0 = processor_cache_from_config(config_0, MULTIMODAL_REGISTRY)
cache_0_p1 = engine_receiver_cache_from_config(config_0, MULTIMODAL_REGISTRY)
cache_1_p0 = processor_cache_from_config(config_1, MULTIMODAL_REGISTRY)
cache_1_p1 = engine_receiver_cache_from_config(config_1, MULTIMODAL_REGISTRY)
cache_0_p0 = MULTIMODAL_REGISTRY.processor_cache_from_config(config_0)
cache_0_p1 = MULTIMODAL_REGISTRY.engine_receiver_cache_from_config(config_0)
cache_1_p0 = MULTIMODAL_REGISTRY.processor_cache_from_config(config_1)
cache_1_p1 = MULTIMODAL_REGISTRY.engine_receiver_cache_from_config(config_1)
cache_size_gb = max(
config_0.model_config.multimodal_config.mm_processor_cache_gb,

View File

@@ -6,9 +6,8 @@ import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
from vllm.multimodal import MultiModalUUIDDict
from vllm.multimodal import MultiModalRegistry, MultiModalUUIDDict
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import input_processor as input_processor_mod
from vllm.v1.engine.input_processor import InputProcessor
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
@@ -36,9 +35,9 @@ def _mock_input_processor(
raising=True,
)
monkeypatch.setattr(
input_processor_mod,
MultiModalRegistry,
"processor_cache_from_config",
lambda vllm_config, mm_registry: None,
lambda self, vllm_config: None,
raising=True,
)

View File

@@ -135,9 +135,15 @@ class LoRAModelManager:
llm_punica_wrapper
)
def _maybe_init_mm(self, vllm_config: VllmConfig, max_num_batched_tokens) -> None:
self.supports_tower_connector_lora = False
def _maybe_init_mm(
self,
vllm_config: VllmConfig,
max_num_batched_tokens: int,
) -> None:
model_config: ModelConfig = vllm_config.model_config
mm_registry = MULTIMODAL_REGISTRY
self.supports_tower_connector_lora = False
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
# Only one language model can be included in the model.
@@ -154,9 +160,7 @@ class LoRAModelManager:
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
if self.lora_config.enable_tower_connector_lora:
self.mm_processor_info = MULTIMODAL_REGISTRY.create_processor(
model_config
).info
self.mm_processor_info = mm_registry.create_processor(model_config).info
self.supports_tower_connector_lora = self.supports_mm and hasattr(
self.model, "get_num_mm_encoder_tokens"
)
@@ -169,11 +173,7 @@ class LoRAModelManager:
"GitHub if you encounter them."
)
mm_budget = MultiModalBudget(
model_config,
vllm_config.scheduler_config,
MULTIMODAL_REGISTRY,
)
mm_budget = MultiModalBudget(vllm_config, mm_registry)
limit_per_prompt: int = max(
self.mm_processor_info.get_allowed_mm_limits().values()
)

View File

@@ -35,7 +35,6 @@ if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from .processing.processor import ResolvedPromptUpdate
from .registry import MultiModalRegistry
logger = init_logger(__name__)
@@ -561,67 +560,6 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
return mm_item
def _enable_processor_cache(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> bool:
    """Whether the multi-modal processor cache should be enabled at all."""
    # Text-only models never produce multi-modal inputs, so no cache.
    if not mm_registry.supports_multimodal_inputs(model_config):
        return False
    mm_config = model_config.get_multimodal_config()
    # A non-positive cache budget (in GiB) disables the cache.
    return mm_config.mm_processor_cache_gb > 0
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
    """Whether cross-process (IPC) caching is possible for this topology."""
    parallel_config = vllm_config.parallel_config
    # IPC caching is allowed either for a single API process with DP size 1,
    # or whenever an external data-parallel load balancer is in use
    # (presumably each DP rank then pairs with its own API process — confirm).
    supports_ipc_cache = (
        parallel_config._api_process_count == 1
        and parallel_config.data_parallel_size == 1
    ) or parallel_config.data_parallel_external_lb
    return supports_ipc_cache
def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
    """Whether the shared memory based cache should be enabled."""
    # Shared-memory caching is a flavor of IPC caching, so it is only
    # available when IPC caching itself is possible.
    if not _enable_ipc_cache(vllm_config):
        return False
    mm_config = vllm_config.model_config.get_multimodal_config()
    return mm_config.mm_processor_cache_type == "shm"
def processor_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalProcessorCache | None:
    """Return a `BaseMultiModalProcessorCache`, if enabled."""
    model_config = vllm_config.model_config
    # Caching disabled (text-only model or zero cache budget).
    if not _enable_processor_cache(model_config, mm_registry):
        return None
    # No IPC possible: keep a cache local to the processor process only.
    if not _enable_ipc_cache(vllm_config):
        return MultiModalProcessorOnlyCache(model_config)
    # IPC with "lru" cache type: sender side of the LRU-based IPC cache.
    if not _enable_mm_input_shm_cache(vllm_config):
        return MultiModalProcessorSenderCache(model_config)
    # IPC with "shm" cache type: sender side of the shared-memory store.
    return ShmObjectStoreSenderCache(vllm_config)
def processor_only_cache_from_config(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> MultiModalProcessorOnlyCache | None:
    """Return a `MultiModalProcessorOnlyCache`, if enabled."""
    # Any enabled cache configuration maps to the processor-only variant
    # here; callers use this when no IPC counterpart is involved.
    if not _enable_processor_cache(model_config, mm_registry):
        return None
    return MultiModalProcessorOnlyCache(model_config)
class BaseMultiModalReceiverCache(
BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem]
):
@@ -780,50 +718,3 @@ class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
@override
def clear_cache(self) -> None:
self._shm_cache.clear()
def engine_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalReceiverCache | None:
    """
    This is used in the engine process.
    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
    mm_processor_cache_type=="lru".
    """
    model_config = vllm_config.model_config
    # Caching disabled entirely.
    if not _enable_processor_cache(model_config, mm_registry):
        return None
    # Without IPC, the processor-side cache is the only cache.
    if not _enable_ipc_cache(vllm_config):
        return None
    # "lru" IPC caching: the engine holds the receiver side.
    if not _enable_mm_input_shm_cache(vllm_config):
        return MultiModalReceiverCache(model_config)
    # "shm" caching is received in the worker process instead (see
    # worker_receiver_cache_from_config).
    return None
def worker_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
    shared_worker_lock: LockType,
) -> BaseMultiModalReceiverCache | None:
    """
    This is used in the worker process.
    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
    mm_processor_cache_type=="shm".
    """
    model_config = vllm_config.model_config
    # Caching disabled entirely.
    if not _enable_processor_cache(model_config, mm_registry):
        return None
    # Without IPC there is nothing to receive in the worker.
    if not _enable_ipc_cache(vllm_config):
        return None
    # "lru" IPC caching keeps the receiver in the engine process instead.
    if not _enable_mm_input_shm_cache(vllm_config):
        return None
    # "shm": worker-side handle to the shared-memory object store,
    # synchronized across workers via the shared lock.
    return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)

View File

@@ -2,14 +2,23 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions
from vllm.config.observability import ObservabilityConfig
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from .cache import BaseMultiModalProcessorCache
from .cache import (
BaseMultiModalProcessorCache,
BaseMultiModalReceiverCache,
MultiModalProcessorOnlyCache,
MultiModalProcessorSenderCache,
MultiModalReceiverCache,
ShmObjectStoreReceiverCache,
ShmObjectStoreSenderCache,
)
from .inputs import MultiModalInputs
from .processing import (
BaseDummyInputsBuilder,
@@ -19,7 +28,7 @@ from .processing import (
)
if TYPE_CHECKING:
from vllm.config import ModelConfig, ObservabilityConfig
from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
from vllm.model_executor.models.interfaces import SupportsMultiModal
logger = init_logger(__name__)
@@ -355,3 +364,84 @@ class MultiModalRegistry:
first_modality = next(iter(max_tokens))
return max_tokens[first_modality]
def _get_cache_type(
    self,
    vllm_config: "VllmConfig",
) -> Literal[None, "processor_only", "lru", "shm"]:
    """Classify which multi-modal processor cache (if any) to construct.

    Returns:
        ``None`` when caching is disabled (text-only model, or a
        non-positive ``mm_processor_cache_gb`` budget);
        ``"processor_only"`` when caching is enabled but IPC between the
        API and engine processes is not possible; otherwise the configured
        ``mm_processor_cache_type`` (``"lru"`` or ``"shm"``).
    """
    model_config = vllm_config.model_config
    if not self.supports_multimodal_inputs(model_config):
        return None

    # A non-positive cache budget disables the cache entirely.
    mm_config = model_config.get_multimodal_config()
    if mm_config.mm_processor_cache_gb <= 0:
        return None

    # IPC caching requires a single API process, plus either DP size 1 or
    # an external data-parallel load balancer.
    # NOTE(review): the pre-refactor `_enable_ipc_cache` allowed IPC
    # whenever `data_parallel_external_lb` was set, regardless of
    # `_api_process_count` — confirm this tightening is intentional.
    parallel_config = vllm_config.parallel_config
    is_ipc_supported = parallel_config._api_process_count == 1 and (
        parallel_config.data_parallel_size == 1
        or parallel_config.data_parallel_external_lb
    )
    if not is_ipc_supported:
        return "processor_only"

    # Reuse the config fetched above instead of re-querying it.
    return mm_config.mm_processor_cache_type
def processor_cache_from_config(
    self,
    vllm_config: "VllmConfig",
) -> "BaseMultiModalProcessorCache | None":
    """Return a `BaseMultiModalProcessorCache`, if enabled."""
    kind = self._get_cache_type(vllm_config)
    # Guard-clause dispatch on the resolved cache kind.
    if kind is None:
        return None
    if kind == "processor_only":
        return MultiModalProcessorOnlyCache(vllm_config.model_config)
    if kind == "lru":
        return MultiModalProcessorSenderCache(vllm_config.model_config)
    if kind == "shm":
        return ShmObjectStoreSenderCache(vllm_config)
    raise ValueError(f"Unknown cache type: {kind!r}")
def processor_only_cache_from_config(
    self,
    vllm_config: "VllmConfig",
) -> "MultiModalProcessorOnlyCache | None":
    """Return a `MultiModalProcessorOnlyCache`, if enabled."""
    # Any enabled cache kind maps to the processor-only variant here.
    if self._get_cache_type(vllm_config) is None:
        return None
    return MultiModalProcessorOnlyCache(vllm_config.model_config)
def engine_receiver_cache_from_config(
    self,
    vllm_config: "VllmConfig",
) -> "BaseMultiModalReceiverCache | None":
    """Return a `BaseMultiModalReceiverCache` for the engine process."""
    kind = self._get_cache_type(vllm_config)
    # Only "lru" IPC caching places a receiver cache in the engine;
    # "shm" is received in the worker process and the rest need none.
    if kind == "lru":
        return MultiModalReceiverCache(vllm_config.model_config)
    if kind in (None, "processor_only", "shm"):
        return None
    raise ValueError(f"Unknown cache type: {kind!r}")
def worker_receiver_cache_from_config(
    self,
    vllm_config: "VllmConfig",
    shared_worker_lock: LockType,
) -> "BaseMultiModalReceiverCache | None":
    """Return a `BaseMultiModalReceiverCache` for the worker process."""
    kind = self._get_cache_type(vllm_config)
    # Only the shared-memory cache type keeps receiver state in workers;
    # "lru" is received in the engine process and the rest need none.
    if kind == "shm":
        return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
    if kind in (None, "processor_only", "lru"):
        return None
    raise ValueError(f"Unknown cache type: {kind!r}")

View File

@@ -23,7 +23,6 @@ from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import engine_receiver_cache_from_config
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils.gc_utils import (
@@ -149,8 +148,8 @@ class EngineCore:
self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
self.mm_receiver_cache = engine_receiver_cache_from_config(
vllm_config, mm_registry
self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config(
vllm_config
)
# If a KV connector is initialized for scheduler, we want to collect

View File

@@ -14,7 +14,6 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing.context import set_request_id
@@ -58,7 +57,7 @@ class InputProcessor:
self.generation_config_fields = self.model_config.try_get_generation_config()
self.mm_registry = mm_registry
self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry)
self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)
self.input_preprocessor = InputPreprocessor(
self.model_config,

View File

@@ -623,11 +623,7 @@ class GPUModelRunner(
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self.mm_budget = (
MultiModalBudget(
self.model_config,
self.scheduler_config,
self.mm_registry,
)
MultiModalBudget(self.vllm_config, self.mm_registry)
if self.supports_mm_inputs
else None
)

View File

@@ -8,11 +8,10 @@ import torch
from typing_extensions import deprecated
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.utils.mem_utils import MemorySnapshot, format_gib
@@ -28,16 +27,15 @@ class MultiModalBudget:
def __init__(
self,
model_config: ModelConfig,
scheduler_config: SchedulerConfig,
vllm_config: VllmConfig,
mm_registry: MultiModalRegistry,
) -> None:
super().__init__()
self.model_config = model_config
self.scheduler_config = scheduler_config
self.model_config = model_config = vllm_config.model_config
self.scheduler_config = scheduler_config = vllm_config.scheduler_config
self.mm_registry = mm_registry
self.cache = cache = processor_only_cache_from_config(model_config, mm_registry)
self.cache = cache = mm_registry.processor_only_cache_from_config(vllm_config)
self.max_model_len = model_config.max_model_len
self.max_num_reqs = scheduler_config.max_num_seqs

View File

@@ -12,7 +12,6 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.system_utils import update_environment_variables
from vllm.v1.kv_cache_interface import KVCacheSpec
@@ -303,10 +302,11 @@ class WorkerWrapperBase:
self.mm_receiver_cache = None
else:
self.mm_receiver_cache = worker_receiver_cache_from_config(
vllm_config,
MULTIMODAL_REGISTRY,
shared_worker_lock,
self.mm_receiver_cache = (
MULTIMODAL_REGISTRY.worker_receiver_cache_from_config(
vllm_config,
shared_worker_lock,
)
)
with set_current_vllm_config(self.vllm_config):