[2/N] Move cache factories to MM registry (#32382)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -19,8 +19,6 @@ from vllm.multimodal.cache import (
|
||||
MultiModalReceiverCache,
|
||||
ShmObjectStoreReceiverCache,
|
||||
ShmObjectStoreSenderCache,
|
||||
engine_receiver_cache_from_config,
|
||||
processor_cache_from_config,
|
||||
)
|
||||
from vllm.multimodal.hasher import MultiModalHasher
|
||||
from vllm.multimodal.inputs import (
|
||||
@@ -132,10 +130,10 @@ def _compare_caches(
|
||||
n_iter: int = 100,
|
||||
seed: int = 0,
|
||||
):
|
||||
cache_0_p0 = processor_cache_from_config(config_0, MULTIMODAL_REGISTRY)
|
||||
cache_0_p1 = engine_receiver_cache_from_config(config_0, MULTIMODAL_REGISTRY)
|
||||
cache_1_p0 = processor_cache_from_config(config_1, MULTIMODAL_REGISTRY)
|
||||
cache_1_p1 = engine_receiver_cache_from_config(config_1, MULTIMODAL_REGISTRY)
|
||||
cache_0_p0 = MULTIMODAL_REGISTRY.processor_cache_from_config(config_0)
|
||||
cache_0_p1 = MULTIMODAL_REGISTRY.engine_receiver_cache_from_config(config_0)
|
||||
cache_1_p0 = MULTIMODAL_REGISTRY.processor_cache_from_config(config_1)
|
||||
cache_1_p1 = MULTIMODAL_REGISTRY.engine_receiver_cache_from_config(config_1)
|
||||
|
||||
cache_size_gb = max(
|
||||
config_0.model_config.multimodal_config.mm_processor_cache_gb,
|
||||
|
||||
@@ -6,9 +6,8 @@ import pytest
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
|
||||
from vllm.multimodal import MultiModalUUIDDict
|
||||
from vllm.multimodal import MultiModalRegistry, MultiModalUUIDDict
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.engine import input_processor as input_processor_mod
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
|
||||
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
|
||||
@@ -36,9 +35,9 @@ def _mock_input_processor(
|
||||
raising=True,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
input_processor_mod,
|
||||
MultiModalRegistry,
|
||||
"processor_cache_from_config",
|
||||
lambda vllm_config, mm_registry: None,
|
||||
lambda self, vllm_config: None,
|
||||
raising=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -135,9 +135,15 @@ class LoRAModelManager:
|
||||
llm_punica_wrapper
|
||||
)
|
||||
|
||||
def _maybe_init_mm(self, vllm_config: VllmConfig, max_num_batched_tokens) -> None:
|
||||
self.supports_tower_connector_lora = False
|
||||
def _maybe_init_mm(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
max_num_batched_tokens: int,
|
||||
) -> None:
|
||||
model_config: ModelConfig = vllm_config.model_config
|
||||
mm_registry = MULTIMODAL_REGISTRY
|
||||
|
||||
self.supports_tower_connector_lora = False
|
||||
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
|
||||
|
||||
# Only one language model can be included in the model.
|
||||
@@ -154,9 +160,7 @@ class LoRAModelManager:
|
||||
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
|
||||
|
||||
if self.lora_config.enable_tower_connector_lora:
|
||||
self.mm_processor_info = MULTIMODAL_REGISTRY.create_processor(
|
||||
model_config
|
||||
).info
|
||||
self.mm_processor_info = mm_registry.create_processor(model_config).info
|
||||
self.supports_tower_connector_lora = self.supports_mm and hasattr(
|
||||
self.model, "get_num_mm_encoder_tokens"
|
||||
)
|
||||
@@ -169,11 +173,7 @@ class LoRAModelManager:
|
||||
"GitHub if you encounter them."
|
||||
)
|
||||
|
||||
mm_budget = MultiModalBudget(
|
||||
model_config,
|
||||
vllm_config.scheduler_config,
|
||||
MULTIMODAL_REGISTRY,
|
||||
)
|
||||
mm_budget = MultiModalBudget(vllm_config, mm_registry)
|
||||
limit_per_prompt: int = max(
|
||||
self.mm_processor_info.get_allowed_mm_limits().values()
|
||||
)
|
||||
|
||||
@@ -35,7 +35,6 @@ if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
from .processing.processor import ResolvedPromptUpdate
|
||||
from .registry import MultiModalRegistry
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -561,67 +560,6 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
|
||||
return mm_item
|
||||
|
||||
|
||||
def _enable_processor_cache(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> bool:
    """Whether the multi-modal processor cache should be enabled.

    The cache is disabled when `mm_registry` reports that the model does not
    support multi-modal inputs. Otherwise it is enabled only if the
    configured `mm_processor_cache_gb` is greater than zero.
    """
    if mm_registry.supports_multimodal_inputs(model_config):
        cache_gb = model_config.get_multimodal_config().mm_processor_cache_gb
        return cache_gb > 0

    return False
|
||||
|
||||
|
||||
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
    """Whether the IPC-based cache can be used for this parallel setup.

    True when there is exactly one API process and a data-parallel size of
    one; in every other case the result is the value of
    `data_parallel_external_lb`.
    """
    cfg = vllm_config.parallel_config

    if cfg._api_process_count == 1 and cfg.data_parallel_size == 1:
        return True

    return cfg.data_parallel_external_lb
|
||||
|
||||
|
||||
def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
    """Whether the shared memory based cache should be enabled.

    Requires IPC caching to be available and the configured
    `mm_processor_cache_type` to equal `"shm"`.
    """
    if _enable_ipc_cache(vllm_config):
        mm_cfg = vllm_config.model_config.get_multimodal_config()
        return mm_cfg.mm_processor_cache_type == "shm"

    return False
|
||||
|
||||
|
||||
def processor_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalProcessorCache | None:
    """Return a `BaseMultiModalProcessorCache`, if enabled.

    Selection, in order:

    - `None` when the processor cache is disabled.
    - `MultiModalProcessorOnlyCache` when IPC caching is unavailable.
    - `ShmObjectStoreSenderCache` when the shared memory cache is enabled.
    - `MultiModalProcessorSenderCache` otherwise.
    """
    mm_model_config = vllm_config.model_config

    if not _enable_processor_cache(mm_model_config, mm_registry):
        cache = None
    elif not _enable_ipc_cache(vllm_config):
        cache = MultiModalProcessorOnlyCache(mm_model_config)
    elif _enable_mm_input_shm_cache(vllm_config):
        cache = ShmObjectStoreSenderCache(vllm_config)
    else:
        cache = MultiModalProcessorSenderCache(mm_model_config)

    return cache
|
||||
|
||||
|
||||
def processor_only_cache_from_config(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
):
    """Return a `MultiModalProcessorOnlyCache`, if enabled.

    Yields `None` when the processor cache is disabled for `model_config`.
    """
    cache_enabled = _enable_processor_cache(model_config, mm_registry)
    return MultiModalProcessorOnlyCache(model_config) if cache_enabled else None
|
||||
|
||||
|
||||
class BaseMultiModalReceiverCache(
|
||||
BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem]
|
||||
):
|
||||
@@ -780,50 +718,3 @@ class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._shm_cache.clear()
|
||||
|
||||
|
||||
def engine_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalReceiverCache | None:
    """
    Build the receiver cache used in the engine process.

    A `MultiModalReceiverCache` is returned only when the processor cache is
    enabled, IPC caching is available, and the shared memory cache is not
    selected. In every other case the result is `None`.
    """
    mm_model_config = vllm_config.model_config

    # Short-circuiting keeps the checks in their original order.
    wants_receiver = (
        _enable_processor_cache(mm_model_config, mm_registry)
        and _enable_ipc_cache(vllm_config)
        and not _enable_mm_input_shm_cache(vllm_config)
    )

    return MultiModalReceiverCache(mm_model_config) if wants_receiver else None
|
||||
|
||||
|
||||
def worker_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
    shared_worker_lock: LockType,
) -> BaseMultiModalReceiverCache | None:
    """
    Build the receiver cache used in the worker process.

    A `ShmObjectStoreReceiverCache` (constructed with `shared_worker_lock`)
    is returned only when the processor cache is enabled, IPC caching is
    available, and the shared memory cache is selected. In every other case
    the result is `None`.
    """
    mm_model_config = vllm_config.model_config

    # Short-circuiting keeps the checks in their original order.
    wants_shm_receiver = (
        _enable_processor_cache(mm_model_config, mm_registry)
        and _enable_ipc_cache(vllm_config)
        and _enable_mm_input_shm_cache(vllm_config)
    )
    if wants_shm_receiver:
        return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)

    return None
|
||||
|
||||
@@ -2,14 +2,23 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
|
||||
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.config.observability import ObservabilityConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
|
||||
|
||||
from .cache import BaseMultiModalProcessorCache
|
||||
from .cache import (
|
||||
BaseMultiModalProcessorCache,
|
||||
BaseMultiModalReceiverCache,
|
||||
MultiModalProcessorOnlyCache,
|
||||
MultiModalProcessorSenderCache,
|
||||
MultiModalReceiverCache,
|
||||
ShmObjectStoreReceiverCache,
|
||||
ShmObjectStoreSenderCache,
|
||||
)
|
||||
from .inputs import MultiModalInputs
|
||||
from .processing import (
|
||||
BaseDummyInputsBuilder,
|
||||
@@ -19,7 +28,7 @@ from .processing import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, ObservabilityConfig
|
||||
from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -355,3 +364,84 @@ class MultiModalRegistry:
|
||||
|
||||
first_modality = next(iter(max_tokens))
|
||||
return max_tokens[first_modality]
|
||||
|
||||
def _get_cache_type(
    self,
    vllm_config: "VllmConfig",
) -> Literal[None, "processor_only", "lru", "shm"]:
    """Determine which kind of multi-modal cache applies to `vllm_config`.

    Returns:
        `None` if the model does not support multi-modal inputs or the
        configured `mm_processor_cache_gb` is not positive;
        `"processor_only"` if the cache is enabled but IPC caching is not
        supported by the parallel configuration; otherwise the configured
        `mm_processor_cache_type`.
    """
    model_config = vllm_config.model_config
    if not self.supports_multimodal_inputs(model_config):
        return None

    # Check if the cache is disabled.
    mm_config = model_config.get_multimodal_config()
    if mm_config.mm_processor_cache_gb <= 0:
        return None

    # Check if IPC caching is supported: a single API process is required,
    # together with either a data-parallel size of one or an external
    # load balancer.
    parallel_config = vllm_config.parallel_config
    is_ipc_supported = parallel_config._api_process_count == 1 and (
        parallel_config.data_parallel_size == 1
        or parallel_config.data_parallel_external_lb
    )

    if not is_ipc_supported:
        return "processor_only"

    # `mm_config` was already fetched above; no need to look it up again.
    return mm_config.mm_processor_cache_type
|
||||
|
||||
def processor_cache_from_config(
    self,
    vllm_config: "VllmConfig",
) -> BaseMultiModalProcessorCache | None:
    """Return a `BaseMultiModalProcessorCache`, if enabled.

    The cache class is chosen from the result of `_get_cache_type`:
    `None` gives no cache, `"processor_only"` gives
    `MultiModalProcessorOnlyCache`, `"lru"` gives
    `MultiModalProcessorSenderCache`, and `"shm"` gives
    `ShmObjectStoreSenderCache`.

    Raises:
        ValueError: If the cache type is not one of the values above.
    """
    kind = self._get_cache_type(vllm_config)

    if kind is None:
        return None
    if kind == "processor_only":
        return MultiModalProcessorOnlyCache(vllm_config.model_config)
    if kind == "lru":
        return MultiModalProcessorSenderCache(vllm_config.model_config)
    if kind == "shm":
        return ShmObjectStoreSenderCache(vllm_config)

    raise ValueError(f"Unknown cache type: {kind!r}")
|
||||
|
||||
def processor_only_cache_from_config(
    self,
    vllm_config: "VllmConfig",
) -> MultiModalProcessorOnlyCache | None:
    """Return a `MultiModalProcessorOnlyCache`, if enabled.

    Any non-`None` cache type reported by `_get_cache_type` yields a
    processor-only cache built from `vllm_config.model_config`.
    """
    caching_disabled = self._get_cache_type(vllm_config) is None

    return (
        None
        if caching_disabled
        else MultiModalProcessorOnlyCache(vllm_config.model_config)
    )
|
||||
|
||||
def engine_receiver_cache_from_config(
    self,
    vllm_config: "VllmConfig",
) -> BaseMultiModalReceiverCache | None:
    """Return a `BaseMultiModalReceiverCache` for the engine process.

    Only the `"lru"` cache type produces a cache (`MultiModalReceiverCache`);
    `None`, `"processor_only"` and `"shm"` all yield `None`.

    Raises:
        ValueError: If the cache type is not one of the values above.
    """
    kind = self._get_cache_type(vllm_config)

    if kind == "lru":
        return MultiModalReceiverCache(vllm_config.model_config)
    if kind in (None, "processor_only", "shm"):
        return None

    raise ValueError(f"Unknown cache type: {kind!r}")
|
||||
|
||||
def worker_receiver_cache_from_config(
    self,
    vllm_config: "VllmConfig",
    shared_worker_lock: LockType,
) -> BaseMultiModalReceiverCache | None:
    """Return a `BaseMultiModalReceiverCache` for the worker process.

    Only the `"shm"` cache type produces a cache
    (`ShmObjectStoreReceiverCache`, constructed with `shared_worker_lock`);
    `None`, `"processor_only"` and `"lru"` all yield `None`.

    Raises:
        ValueError: If the cache type is not one of the values above.
    """
    kind = self._get_cache_type(vllm_config)

    if kind == "shm":
        return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
    if kind in (None, "processor_only", "lru"):
        return None

    raise ValueError(f"Unknown cache type: {kind!r}")
|
||||
|
||||
@@ -23,7 +23,6 @@ from vllm.logger import init_logger
|
||||
from vllm.logging_utils.dump_input import dump_engine_exception
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import engine_receiver_cache_from_config
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
||||
from vllm.utils.gc_utils import (
|
||||
@@ -149,8 +148,8 @@ class EngineCore:
|
||||
self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore
|
||||
|
||||
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
|
||||
self.mm_receiver_cache = engine_receiver_cache_from_config(
|
||||
vllm_config, mm_registry
|
||||
self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config(
|
||||
vllm_config
|
||||
)
|
||||
|
||||
# If a KV connector is initialized for scheduler, we want to collect
|
||||
|
||||
@@ -14,7 +14,6 @@ from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.cache import processor_cache_from_config
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
from vllm.multimodal.processing.context import set_request_id
|
||||
@@ -58,7 +57,7 @@ class InputProcessor:
|
||||
self.generation_config_fields = self.model_config.try_get_generation_config()
|
||||
|
||||
self.mm_registry = mm_registry
|
||||
self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry)
|
||||
self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)
|
||||
|
||||
self.input_preprocessor = InputPreprocessor(
|
||||
self.model_config,
|
||||
|
||||
@@ -623,11 +623,7 @@ class GPUModelRunner(
|
||||
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
|
||||
|
||||
self.mm_budget = (
|
||||
MultiModalBudget(
|
||||
self.model_config,
|
||||
self.scheduler_config,
|
||||
self.mm_registry,
|
||||
)
|
||||
MultiModalBudget(self.vllm_config, self.mm_registry)
|
||||
if self.supports_mm_inputs
|
||||
else None
|
||||
)
|
||||
|
||||
@@ -8,11 +8,10 @@ import torch
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.multimodal.cache import processor_only_cache_from_config
|
||||
from vllm.multimodal.registry import MultiModalRegistry
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
||||
@@ -28,16 +27,15 @@ class MultiModalBudget:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
vllm_config: VllmConfig,
|
||||
mm_registry: MultiModalRegistry,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.model_config = model_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.model_config = model_config = vllm_config.model_config
|
||||
self.scheduler_config = scheduler_config = vllm_config.scheduler_config
|
||||
self.mm_registry = mm_registry
|
||||
self.cache = cache = processor_only_cache_from_config(model_config, mm_registry)
|
||||
self.cache = cache = mm_registry.processor_only_cache_from_config(vllm_config)
|
||||
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
||||
|
||||
@@ -12,7 +12,6 @@ from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import worker_receiver_cache_from_config
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||
@@ -303,11 +302,12 @@ class WorkerWrapperBase:
|
||||
|
||||
self.mm_receiver_cache = None
|
||||
else:
|
||||
self.mm_receiver_cache = worker_receiver_cache_from_config(
|
||||
self.mm_receiver_cache = (
|
||||
MULTIMODAL_REGISTRY.worker_receiver_cache_from_config(
|
||||
vllm_config,
|
||||
MULTIMODAL_REGISTRY,
|
||||
shared_worker_lock,
|
||||
)
|
||||
)
|
||||
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
# To make vLLM config available during worker initialization
|
||||
|
||||
Reference in New Issue
Block a user