[Refactor] Move profiling methods to MM budget (#33559)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-02-02 23:27:00 +08:00 (committed by GitHub)
Parent: 528e9b1490
Commit: d7e17aaacd
8 changed files with 201 additions and 296 deletions


@@ -10,7 +10,7 @@ import torch
from torch import nn

from vllm.config import VllmConfig
-from vllm.config.lora import LoRAConfig, ModelConfig
+from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import (
    BaseLayerWithLoRA,
@@ -35,9 +35,9 @@ from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.budget import MultiModalBudget
from vllm.utils.cache import LRUCache
from vllm.utils.platform_utils import is_pin_memory_available
-from vllm.v1.worker.utils import MultiModalBudget

logger = init_logger(__name__)
@@ -142,7 +142,6 @@ class LoRAModelManager:
        vllm_config: VllmConfig,
        max_num_batched_tokens: int,
    ) -> None:
-        model_config: ModelConfig = vllm_config.model_config
        mm_registry = MULTIMODAL_REGISTRY

        self.supports_tower_connector_lora = False
@@ -162,7 +161,6 @@ class LoRAModelManager:
        self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper

        if self.lora_config.enable_tower_connector_lora:
-            self.mm_processor_info = mm_registry.create_processor(model_config).info
            self.supports_tower_connector_lora = self.supports_mm and hasattr(
                self.model, "get_num_mm_encoder_tokens"
            )
@@ -176,7 +174,7 @@ class LoRAModelManager:
            )

            mm_budget = MultiModalBudget(vllm_config, mm_registry)
-            limit_per_prompt = max(self.mm_processor_info.allowed_mm_limits.values())
+            limit_per_prompt = max(mm_budget.mm_max_items_per_prompt.values())
            num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
                mm_budget.get_encoder_budget()
            )

vllm/multimodal/budget.py (new file, +154 lines)

@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Mapping

from vllm.config import ModelConfig, VllmConfig
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.registry import MultiModalRegistry
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget


def get_mm_max_toks_per_item(
    model_config: ModelConfig,
    mm_registry: MultiModalRegistry,
    processor: BaseMultiModalProcessor,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    """
    Get the maximum number of tokens per data item from each modality based
    on underlying model configuration.
    """
    max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
        seq_len=model_config.max_model_len,
        mm_counts=mm_counts,
    )
    if max_tokens_per_item is not None:
        return max_tokens_per_item

    mm_inputs = mm_registry.get_dummy_mm_inputs(
        model_config,
        mm_counts=mm_counts,
        processor=processor,
    )

    return {
        modality: sum(item.get_num_embeds for item in placeholders)
        for modality, placeholders in mm_inputs["mm_placeholders"].items()
    }


class MultiModalBudget:
    """Helper class to calculate budget information for multi-modal models."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        mm_registry: MultiModalRegistry,
    ) -> None:
        super().__init__()

        self.model_config = model_config = vllm_config.model_config
        self.scheduler_config = scheduler_config = vllm_config.scheduler_config

        self.max_model_len = model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs

        cache = mm_registry.processor_only_cache_from_config(vllm_config)
        processor = mm_registry.create_processor(model_config, cache=cache)
        self.cache = cache

        self.mm_limits = mm_limits = processor.info.allowed_mm_limits
        active_modalities = {
            modality for modality, limit in mm_limits.items() if limit > 0
        }

        with set_default_torch_num_threads():  # Avoid hang during startup
            all_mm_max_toks_per_item = get_mm_max_toks_per_item(
                model_config,
                mm_registry,
                processor,
                mm_counts=dict.fromkeys(active_modalities, 1),
            )

        mm_max_toks_per_item = {
            modality: all_mm_max_toks_per_item[modality]
            for modality in active_modalities
        }

        encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
            scheduler_config,
            mm_max_toks_per_item,
        )
        self.encoder_compute_budget = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        mm_max_items_per_prompt = dict[str, int]()
        mm_max_items_per_batch = dict[str, int]()

        for modality, max_toks_per_item in mm_max_toks_per_item.items():
            (
                mm_max_items_per_prompt[modality],
                mm_max_items_per_batch[modality],
            ) = self._get_max_items(modality, max_toks_per_item)

        self.mm_max_toks_per_item = mm_max_toks_per_item
        self.mm_max_items_per_prompt: Mapping[str, int] = mm_max_items_per_prompt
        self.mm_max_items_per_batch: Mapping[str, int] = mm_max_items_per_batch

    def _get_max_items(
        self,
        modality: str,
        max_tokens_per_item: int,
    ) -> tuple[int, int]:
        if max_tokens_per_item == 0:
            return 0, 0

        # Check how many items of this modality can be supported by
        # the encoder budget.
        if (encoder_budget := self.get_encoder_budget()) == 0:
            return 0, 0

        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item

        # Check how many items of this modality can be supported by
        # the decoder budget.
        mm_limit = self.mm_limits[modality]

        max_items_per_prompt = max(
            1,
            min(mm_limit, self.max_model_len // max_tokens_per_item),
        )

        scheduler_config = self.scheduler_config
        max_num_reqs = self.max_num_reqs

        if not scheduler_config.enable_chunked_prefill:
            max_num_reqs = min(
                max_num_reqs,
                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
            )

        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt

        max_items_per_batch = max(
            1,
            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
        )

        return max_items_per_prompt, max_items_per_batch

    def get_modality_with_max_tokens(self) -> str:
        mm_max_toks_per_item = self.mm_max_toks_per_item
        modality, _ = max(mm_max_toks_per_item.items(), key=lambda x: x[1])
        return modality

    def get_encoder_budget(self) -> int:
        return min(self.encoder_compute_budget, self.encoder_cache_size)

    def reset_cache(self) -> None:
        if self.cache is not None:
            self.cache.clear_cache()
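
For context, a minimal usage sketch of the new helper. It assumes an already-constructed VllmConfig (building that config is outside the scope of this change) and only touches attributes and methods defined in the file above:

from collections.abc import Mapping

from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.budget import MultiModalBudget


def summarize_mm_budget(vllm_config: VllmConfig) -> Mapping[str, tuple[int, int, int]]:
    """Illustrative only: report per-modality profiling limits."""
    mm_budget = MultiModalBudget(vllm_config, MULTIMODAL_REGISTRY)

    # The effective encoder budget is the tighter of compute budget and cache size.
    _ = mm_budget.get_encoder_budget()

    summary = {
        modality: (
            max_toks,  # max encoder tokens per item
            mm_budget.mm_max_items_per_prompt[modality],
            mm_budget.mm_max_items_per_batch[modality],
        )
        for modality, max_toks in mm_budget.mm_max_toks_per_item.items()
    }

    # Drop the processor-only cache once the figures have been read,
    # mirroring what the engine-side callers in this commit do.
    mm_budget.reset_cache()
    return summary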


@@ -130,15 +130,13 @@ class MultiModalRegistry:
        if not model_config.is_multimodal_model:
            return False

-        info = self._create_processing_info(model_config, tokenizer=None)
-        supported_modalities = info.get_supported_mm_limits()
-
        mm_config = model_config.get_multimodal_config()
+        info = self._create_processing_info(model_config, tokenizer=None)

        # Check if all supported modalities have limit == 0
        if all(
            mm_config.get_limit_per_prompt(modality) == 0
-            for modality in supported_modalities
+            for modality in info.supported_mm_limits
        ):
            logger.info_once(
                "All limits of multimodal modalities supported by the model "
@@ -148,70 +146,6 @@ class MultiModalRegistry:
        return True

-    def get_max_tokens_per_item_by_modality(
-        self,
-        model_config: "ModelConfig",
-        *,
-        cache: BaseMultiModalProcessorCache | None = None,
-        profiler_limits: Mapping[str, int] | None = None,
-        observability_config: ObservabilityConfig | None = None,
-    ) -> Mapping[str, int]:
-        """
-        Get the maximum number of tokens per data item from each modality based
-        on underlying model configuration.
-        """
-        if not model_config.is_multimodal_model:
-            return {}
-
-        processor = self.create_processor(
-            model_config, observability_config, cache=cache
-        )
-
-        if profiler_limits is None:
-            profiler_limits = processor.info.allowed_mm_limits
-
-        mm_counts = {
-            modality: 1 for modality, limit in profiler_limits.items() if limit > 0
-        }
-
-        max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
-            seq_len=model_config.max_model_len,
-            mm_counts=mm_counts,
-        )
-        if max_tokens_per_item is not None:
-            return {
-                modality: max_tokens
-                for modality, max_tokens in max_tokens_per_item.items()
-                if mm_counts.get(modality, 0) > 0
-            }
-
-        mm_inputs = self.get_dummy_mm_inputs(
-            model_config,
-            mm_counts=mm_counts,
-            processor=processor,
-        )
-
-        return {
-            modality: sum(item.get_num_embeds for item in placeholders)
-            for modality, placeholders in mm_inputs["mm_placeholders"].items()
-        }
-
-    def get_mm_limits_per_prompt(
-        self,
-        model_config: "ModelConfig",
-        *,
-        observability_config: ObservabilityConfig | None = None,
-    ) -> Mapping[str, int]:
-        """
-        Get the maximum number of multi-modal input instances for each modality
-        that are allowed per prompt for a model class.
-        """
-        if not model_config.is_multimodal_model:
-            return {}
-
-        info = self._create_processing_info(model_config, observability_config)
-        return info.allowed_mm_limits
-
    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
@@ -303,10 +237,9 @@ class MultiModalRegistry:
    def get_dummy_mm_inputs(
        self,
        model_config: "ModelConfig",
-        mm_counts: Mapping[str, int] | None = None,
+        mm_counts: Mapping[str, int],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
-        observability_config: ObservabilityConfig | None = None,
        processor: BaseMultiModalProcessor | None = None,
    ) -> MultiModalInputs:
        """
@@ -317,11 +250,7 @@ class MultiModalRegistry:
        seq_len = model_config.max_model_len

        if processor is None:
-            processor = self.create_processor(
-                model_config, observability_config, cache=cache
-            )
-        if mm_counts is None:
-            mm_counts = processor.info.allowed_mm_limits
+            processor = self.create_processor(model_config, cache=cache)

        processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
            seq_len=seq_len,
@@ -342,26 +271,6 @@ class MultiModalRegistry:
        return mm_inputs

-    def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
-        """
-        Get the maximum length of the encoder input for encoder-decoder models.
-        """
-        if not model_config.is_encoder_decoder:
-            return 0
-
-        max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
-        if not max_tokens:
-            # TODO - this function assumes encoder-decoder models are
-            # multimodal. This will need to change when adding support for more
-            # than whisper.
-            return 0
-
-        assert len(max_tokens) == 1, (
-            "Encoder-decoder models are expected "
-            "to implement the multimodal interface with at most one modality."
-        )
-
-        first_modality = next(iter(max_tokens))
-        return max_tokens[first_modality]
-
    def _get_cache_type(
        self,
        vllm_config: "VllmConfig",


@@ -6,11 +6,10 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING

from vllm.logger import init_logger
-from vllm.multimodal import MultiModalRegistry
from vllm.v1.request import Request

if TYPE_CHECKING:
-    from vllm.config import ModelConfig, SchedulerConfig
+    from vllm.config import SchedulerConfig

logger = init_logger(__name__)
@@ -267,60 +266,16 @@ class EncoderCacheManager:
        return freed

-def compute_encoder_budget(
-    model_config: "ModelConfig",
-    scheduler_config: "SchedulerConfig",
-    mm_registry: MultiModalRegistry,
-) -> tuple[int, int]:
-    """Compute the encoder cache budget based on the model and scheduler
-    configurations.
-
-    Returns:
-    - Compute budget for encoder execution, measured in number of tokens
-      from the input sequence.
-    - Space budget for encoder cache size, measured in number of tokens
-      from the input sequence.
-    """
-    if mm_registry.supports_multimodal_inputs(model_config):
-        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
-            model_config
-        )
-
-        return compute_mm_encoder_budget(
-            scheduler_config,
-            max_tokens_by_modality,
-        )
-
-    return compute_text_encoder_budget(scheduler_config)
-
-def compute_text_encoder_budget(scheduler_config: "SchedulerConfig") -> tuple[int, int]:
-    """Compute the encoder cache budget based on the model and scheduler
-    configurations for a text-only model.
-
-    Args:
-        scheduler_config: Scheduler configuration.
-
-    Returns:
-    - Compute budget for encoder execution, in unit of number of tokens
-      in the input sequence.
-    - Space budget for encoder cache size, in unit of number of tokens
-      in the input sequence.
-    """
-    # Currently text-only encoder-decoder models are not supported
-    return 0, 0
-
def compute_mm_encoder_budget(
    scheduler_config: "SchedulerConfig",
-    max_tokens_by_modality: Mapping[str, int],
+    mm_max_toks_per_item: Mapping[str, int],
) -> tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations for a multimodal model.

    Args:
        scheduler_config: Scheduler configuration.
-        max_tokens_by_modality: The maximum number of tokens for each
+        mm_max_toks_per_item: The maximum number of tokens per item for each
            non-text modality.

    Returns:
@@ -330,7 +285,7 @@ def compute_mm_encoder_budget(
      from the input sequence.
    """

-    if not max_tokens_by_modality:
+    if not mm_max_toks_per_item:
        logger.warning(
            "All non-text modalities supported by the model have been "
            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
@@ -338,7 +293,7 @@ def compute_mm_encoder_budget(
        )
        return 0, 0

-    max_tokens_per_mm_item = max(max_tokens_by_modality.values())
+    max_tokens_per_mm_item = max(mm_max_toks_per_item.values())

    if (
        scheduler_config.disable_chunked_mm_input
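
As a quick illustration of the renamed parameter, the sketch below calls the function with made-up per-item token counts; real values are supplied by MultiModalBudget, and the min() mirrors MultiModalBudget.get_encoder_budget():

from collections.abc import Mapping

from vllm.config import SchedulerConfig
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget


def effective_encoder_budget(
    scheduler_config: SchedulerConfig,
    mm_max_toks_per_item: Mapping[str, int],
) -> int:
    """Illustrative helper: the usable budget is the smaller of the two values."""
    compute_budget, cache_size = compute_mm_encoder_budget(
        scheduler_config, mm_max_toks_per_item
    )
    return min(compute_budget, cache_size)


# Example call with hypothetical token counts (not taken from this commit):
# effective_encoder_budget(scheduler_config, {"image": 2448, "audio": 1500})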


@@ -31,10 +31,10 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
    RoutedExpertsReader,
)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.multimodal.budget import MultiModalBudget
from vllm.v1.core.encoder_cache_manager import (
    EncoderCacheManager,
    EncoderDecoderCacheManager,
-    compute_encoder_budget,
)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
@@ -174,22 +174,22 @@ class Scheduler(SchedulerInterface):
        # Encoder-related.
        # Calculate encoder cache size if applicable
-        # NOTE: For now we use the same budget for both compute and space.
-        # This can be changed when we make encoder cache for embedding caching
-        # across requests.
-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            mm_registry=mm_registry,
-        )
-        # NOTE(woosuk): Here, "encoder" includes the vision encoder (and
-        # projector if needed) for MM models as well as encoder-decoder
-        # transformers.
-        self.max_num_encoder_input_tokens = encoder_compute_budget
-        # NOTE: For the models without encoder (e.g., text-only models),
-        # the encoder cache will not be initialized because cache size is 0
-        # for these models.
+        self.supports_mm_inputs = mm_registry.supports_multimodal_inputs(
+            vllm_config.model_config
+        )
+        self.mm_budget = mm_budget = (
+            MultiModalBudget(vllm_config, mm_registry)
+            if self.supports_mm_inputs
+            else None
+        )
+        # NOTE: Text-only encoder-decoder models are implemented as
+        # multi-modal models for convenience
+        # Example: https://github.com/neuralmagic/bart-plugin
+        self.max_num_encoder_input_tokens = (
+            mm_budget.encoder_compute_budget if mm_budget else 0
+        )
+        encoder_cache_size = mm_budget.encoder_cache_size if mm_budget else 0

        self.encoder_cache_manager = (
            EncoderDecoderCacheManager(cache_size=encoder_cache_size)
            if self.is_encoder_decoder
@@ -199,7 +199,9 @@ class Scheduler(SchedulerInterface):
        # Attn blocks, as for Whisper its input is always padded to the maximum length.
        # TODO (NickLucche): Generalize to models with variable-length encoder inputs.
        self._num_encoder_max_input_tokens = (
-            MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(vllm_config.model_config)
+            mm_budget.mm_max_toks_per_item[mm_budget.get_modality_with_max_tokens()]
+            if mm_budget
+            else 0
        )

        speculative_config = vllm_config.speculative_config
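
To make the scheduler wiring above concrete, here is an illustrative trace with invented numbers (an audio-only model with 1500 encoder tokens per item and an 8192-token budget; none of these figures come from this commit):

# Hypothetical inputs: one active modality, 1500 encoder tokens per item,
# and a budget of 8192 tokens for both compute and cache size.
mm_max_toks_per_item = {"audio": 1500}
encoder_compute_budget = 8192
encoder_cache_size = 8192

# Fields the scheduler now derives from MultiModalBudget:
max_num_encoder_input_tokens = encoder_compute_budget         # previously from compute_encoder_budget()
num_encoder_max_input_tokens = mm_max_toks_per_item["audio"]   # previously from get_encdec_max_encoder_len()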


@@ -19,6 +19,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.multimodal.budget import MultiModalBudget
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFeatureSpec,
@@ -34,7 +35,6 @@ from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.torch_utils import set_default_torch_num_threads
-from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import (
@@ -59,32 +59,30 @@ class InputProcessor:
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
+        self.model_config = model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
-        self.scheduler_config = vllm_config.scheduler_config
        self.structured_outputs_config = vllm_config.structured_outputs_config
+        self.observability_config = vllm_config.observability_config

-        self.generation_config_fields = self.model_config.try_get_generation_config()
+        self.generation_config_fields = model_config.try_get_generation_config()
        self.mm_registry = mm_registry
        self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)

-        self.mm_encoder_cache_size = None
-        if (
-            self.mm_registry.supports_multimodal_inputs(self.model_config)
-            and not self.model_config.skip_tokenizer_init
-        ):
-            with set_default_torch_num_threads():
-                max_tokens_by_modality = (
-                    mm_registry.get_max_tokens_per_item_by_modality(self.model_config)
-                )
-                _, self.mm_encoder_cache_size = compute_mm_encoder_budget(
-                    self.vllm_config.scheduler_config, max_tokens_by_modality
-                )
+        self.mm_encoder_cache_size: int | None = None
+        if (
+            mm_registry.supports_multimodal_inputs(model_config)
+            and not model_config.skip_tokenizer_init
+        ):
+            mm_budget = MultiModalBudget(vllm_config, mm_registry)
+            self.mm_encoder_cache_size = mm_budget.encoder_cache_size
+            mm_budget.reset_cache()  # Not used anymore

        self.input_preprocessor = InputPreprocessor(
-            self.model_config,
-            vllm_config.observability_config,
+            model_config,
+            self.observability_config,
            mm_registry,
            mm_processor_cache=self.mm_processor_cache,
        )


@@ -82,6 +82,7 @@ from vllm.model_executor.models.interfaces_base import (
    is_text_generation_model,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.budget import MultiModalBudget
from vllm.multimodal.inputs import (
    BatchedTensorInputs,
    MultiModalKwargsItem,
@@ -180,7 +181,6 @@ from vllm.v1.worker.workspace import lock_workspace
from .utils import (
    AttentionGroup,
-    MultiModalBudget,
    add_kv_sharing_layers_to_kv_cache_groups,
    bind_kv_cache,
    sanity_check_mm_encoder_outputs,
@@ -5104,7 +5104,7 @@ class GPUModelRunner(
            # modality with the max possible input tokens even when
            # it supports multiple.
            dummy_modality = mm_budget.get_modality_with_max_tokens()
-            max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[
+            max_mm_items_per_batch = mm_budget.mm_max_items_per_batch[
                dummy_modality
            ]


@@ -11,125 +11,14 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
-from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
-from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec

logger = init_logger(__name__)
-class MultiModalBudget:
-    """Helper class to calculate budget information for multi-modal models."""
-
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        mm_registry: MultiModalRegistry,
-    ) -> None:
-        super().__init__()
-
-        self.model_config = model_config = vllm_config.model_config
-        self.scheduler_config = scheduler_config = vllm_config.scheduler_config
-        self.mm_registry = mm_registry
-        self.cache = cache = mm_registry.processor_only_cache_from_config(vllm_config)
-
-        self.max_model_len = model_config.max_model_len
-        self.max_num_reqs = scheduler_config.max_num_seqs
-
-        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
-
-        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
-            model_config,
-            cache=cache,
-            profiler_limits=self.mm_limits,
-        )
-
-        encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
-            scheduler_config,
-            max_tokens_by_modality,
-        )
-        self.encoder_compute_budget = encoder_compute_budget
-        self.encoder_cache_size = encoder_cache_size
-
-        max_items_per_prompt_by_modality = dict[str, int]()
-        max_items_per_batch_by_modality = dict[str, int]()
-
-        for modality, max_tokens in max_tokens_by_modality.items():
-            (
-                max_items_per_prompt,
-                max_items_per_batch,
-            ) = self.get_max_items(modality, max_tokens)
-
-            max_items_per_prompt_by_modality[modality] = max_items_per_prompt
-            max_items_per_batch_by_modality[modality] = max_items_per_batch
-
-        self.max_tokens_by_modality = max_tokens_by_modality
-        self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
-        self.max_items_per_batch_by_modality = max_items_per_batch_by_modality
-
-    def get_modality_with_max_tokens(self) -> str:
-        max_tokens_by_modality = self.max_tokens_by_modality
-        modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])
-        return modality
-
-    def get_encoder_budget(self) -> int:
-        return min(self.encoder_compute_budget, self.encoder_cache_size)
-
-    def get_max_items(
-        self,
-        modality: str,
-        max_tokens_per_item: int,
-    ) -> tuple[int, int]:
-        if max_tokens_per_item == 0:
-            return 0, 0
-
-        # Check how many items of this modality can be supported by
-        # the encoder budget.
-        encoder_budget = self.get_encoder_budget()
-        # TODO: handle encoder-decoder models once we support them.
-        if encoder_budget == 0:
-            return 0, 0
-
-        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item
-
-        # Check how many items of this modality can be supported by
-        # the decoder budget.
-        mm_limit = self.mm_limits[modality]
-
-        max_items_per_prompt = max(
-            1,
-            min(mm_limit, self.max_model_len // max_tokens_per_item),
-        )
-
-        scheduler_config = self.scheduler_config
-        max_num_reqs = self.max_num_reqs
-
-        if not scheduler_config.enable_chunked_prefill:
-            max_num_reqs = min(
-                max_num_reqs,
-                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
-            )
-
-        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt
-
-        max_items_per_batch = max(
-            1,
-            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
-        )
-
-        return max_items_per_prompt, max_items_per_batch
-
-    def reset_cache(self) -> None:
-        if self.cache is not None:
-            self.cache.clear_cache()
@dataclass
class AttentionGroup:
    backend: type[AttentionBackend]