[Refactor] Move profiling methods to MM budget (#33559)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-02 23:27:00 +08:00
committed by GitHub
parent 528e9b1490
commit d7e17aaacd
8 changed files with 201 additions and 296 deletions

View File

@@ -10,7 +10,7 @@ import torch
from torch import nn
from vllm.config import VllmConfig
from vllm.config.lora import LoRAConfig, ModelConfig
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import (
BaseLayerWithLoRA,
@@ -35,9 +35,9 @@ from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.budget import MultiModalBudget
from vllm.utils.cache import LRUCache
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.worker.utils import MultiModalBudget
logger = init_logger(__name__)
@@ -142,7 +142,6 @@ class LoRAModelManager:
vllm_config: VllmConfig,
max_num_batched_tokens: int,
) -> None:
model_config: ModelConfig = vllm_config.model_config
mm_registry = MULTIMODAL_REGISTRY
self.supports_tower_connector_lora = False
@@ -162,7 +161,6 @@ class LoRAModelManager:
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
if self.lora_config.enable_tower_connector_lora:
self.mm_processor_info = mm_registry.create_processor(model_config).info
self.supports_tower_connector_lora = self.supports_mm and hasattr(
self.model, "get_num_mm_encoder_tokens"
)
@@ -176,7 +174,7 @@ class LoRAModelManager:
)
mm_budget = MultiModalBudget(vllm_config, mm_registry)
limit_per_prompt = max(self.mm_processor_info.allowed_mm_limits.values())
limit_per_prompt = max(mm_budget.mm_max_items_per_prompt.values())
num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
mm_budget.get_encoder_budget()
)

154
vllm/multimodal/budget.py Normal file
View File

@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from vllm.config import ModelConfig, VllmConfig
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.registry import MultiModalRegistry
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
def get_mm_max_toks_per_item(
    model_config: ModelConfig,
    mm_registry: MultiModalRegistry,
    processor: BaseMultiModalProcessor,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    """
    Get the maximum number of tokens per data item from each modality based
    on underlying model configuration.
    """
    # Fast path: the processor may be able to report the per-item maximum
    # directly without running any dummy data through the model.
    reported = processor.info.get_mm_max_tokens_per_item(
        seq_len=model_config.max_model_len,
        mm_counts=mm_counts,
    )
    if reported is not None:
        return reported

    # Slow path: build dummy multi-modal inputs and count the embedding
    # slots occupied by each modality's placeholders.
    dummy_inputs = mm_registry.get_dummy_mm_inputs(
        model_config,
        mm_counts=mm_counts,
        processor=processor,
    )
    token_counts: dict[str, int] = {}
    for modality, placeholders in dummy_inputs["mm_placeholders"].items():
        token_counts[modality] = sum(item.get_num_embeds for item in placeholders)
    return token_counts
class MultiModalBudget:
    """Helper class to calculate budget information for multi-modal models."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        mm_registry: MultiModalRegistry,
    ) -> None:
        super().__init__()

        self.model_config = model_config = vllm_config.model_config
        self.scheduler_config = scheduler_config = vllm_config.scheduler_config

        # Decoder-side limits consulted by _get_max_items().
        self.max_model_len = model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs

        cache = mm_registry.processor_only_cache_from_config(vllm_config)
        processor = mm_registry.create_processor(model_config, cache=cache)
        # Kept so callers can release the processor cache via reset_cache().
        self.cache = cache

        # Per-modality item limits allowed per prompt; a limit of 0 means the
        # modality is disabled.
        self.mm_limits = mm_limits = processor.info.allowed_mm_limits
        active_modalities = {
            modality for modality, limit in mm_limits.items() if limit > 0
        }

        with set_default_torch_num_threads():  # Avoid hang during startup
            # Profile with a single item per enabled modality.
            all_mm_max_toks_per_item = get_mm_max_toks_per_item(
                model_config,
                mm_registry,
                processor,
                mm_counts=dict.fromkeys(active_modalities, 1),
            )

        # Keep only the modalities that are actually enabled.
        # NOTE(review): assumes the profiling result contains an entry for
        # every active modality — a missing key raises KeyError here.
        mm_max_toks_per_item = {
            modality: all_mm_max_toks_per_item[modality]
            for modality in active_modalities
        }

        encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
            scheduler_config,
            mm_max_toks_per_item,
        )
        self.encoder_compute_budget = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        mm_max_items_per_prompt = dict[str, int]()
        mm_max_items_per_batch = dict[str, int]()
        for modality, max_toks_per_item in mm_max_toks_per_item.items():
            (
                mm_max_items_per_prompt[modality],
                mm_max_items_per_batch[modality],
            ) = self._get_max_items(modality, max_toks_per_item)

        # Per-modality maximum token/item counts, exposed for schedulers
        # and model runners.
        self.mm_max_toks_per_item = mm_max_toks_per_item
        self.mm_max_items_per_prompt: Mapping[str, int] = mm_max_items_per_prompt
        self.mm_max_items_per_batch: Mapping[str, int] = mm_max_items_per_batch

    def _get_max_items(
        self,
        modality: str,
        max_tokens_per_item: int,
    ) -> tuple[int, int]:
        """Return ``(max_items_per_prompt, max_items_per_batch)`` for the
        given modality, bounded by both the encoder and decoder budgets."""
        if max_tokens_per_item == 0:
            return 0, 0

        # Check how many items of this modality can be supported by
        # the encoder budget.
        if (encoder_budget := self.get_encoder_budget()) == 0:
            return 0, 0

        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item

        # Check how many items of this modality can be supported by
        # the decoder budget.
        mm_limit = self.mm_limits[modality]
        # At least one item per prompt even if a single item exceeds the
        # model's context length.
        max_items_per_prompt = max(
            1,
            min(mm_limit, self.max_model_len // max_tokens_per_item),
        )

        scheduler_config = self.scheduler_config
        max_num_reqs = self.max_num_reqs
        if not scheduler_config.enable_chunked_prefill:
            # Without chunked prefill, each request's items must fit in one
            # batch, which further caps the number of concurrent requests.
            max_num_reqs = min(
                max_num_reqs,
                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
            )
        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt

        max_items_per_batch = max(
            1,
            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
        )

        return max_items_per_prompt, max_items_per_batch

    def get_modality_with_max_tokens(self) -> str:
        """Return the modality with the largest per-item token count."""
        mm_max_toks_per_item = self.mm_max_toks_per_item
        modality, _ = max(mm_max_toks_per_item.items(), key=lambda x: x[1])
        return modality

    def get_encoder_budget(self) -> int:
        """Effective encoder budget: limited by both compute and cache space."""
        return min(self.encoder_compute_budget, self.encoder_cache_size)

    def reset_cache(self) -> None:
        """Release the processor-only cache created during profiling."""
        if self.cache is not None:
            self.cache.clear_cache()

View File

@@ -130,15 +130,13 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
return False
info = self._create_processing_info(model_config, tokenizer=None)
supported_modalities = info.get_supported_mm_limits()
mm_config = model_config.get_multimodal_config()
info = self._create_processing_info(model_config, tokenizer=None)
# Check if all supported modalities have limit == 0
if all(
mm_config.get_limit_per_prompt(modality) == 0
for modality in supported_modalities
for modality in info.supported_mm_limits
):
logger.info_once(
"All limits of multimodal modalities supported by the model "
@@ -148,70 +146,6 @@ class MultiModalRegistry:
return True
def get_max_tokens_per_item_by_modality(
    self,
    model_config: "ModelConfig",
    *,
    cache: BaseMultiModalProcessorCache | None = None,
    profiler_limits: Mapping[str, int] | None = None,
    observability_config: ObservabilityConfig | None = None,
) -> Mapping[str, int]:
    """
    Get the maximum number of tokens per data item from each modality based
    on underlying model configuration.
    """
    # Text-only models have no multi-modal items to budget for.
    if not model_config.is_multimodal_model:
        return {}
    processor = self.create_processor(
        model_config, observability_config, cache=cache
    )
    # Default to the per-prompt limits advertised by the processor.
    if profiler_limits is None:
        profiler_limits = processor.info.allowed_mm_limits
    # Profile with one item per enabled modality (limit == 0 disables it).
    mm_counts = {
        modality: 1 for modality, limit in profiler_limits.items() if limit > 0
    }
    # Fast path: the processor may report the per-item maximum directly.
    max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
        seq_len=model_config.max_model_len,
        mm_counts=mm_counts,
    )
    if max_tokens_per_item is not None:
        # Drop modalities that were disabled via the limits above.
        return {
            modality: max_tokens
            for modality, max_tokens in max_tokens_per_item.items()
            if mm_counts.get(modality, 0) > 0
        }
    # Slow path: run dummy inputs through the processor and count the
    # placeholder embeddings produced for each modality.
    mm_inputs = self.get_dummy_mm_inputs(
        model_config,
        mm_counts=mm_counts,
        processor=processor,
    )
    return {
        modality: sum(item.get_num_embeds for item in placeholders)
        for modality, placeholders in mm_inputs["mm_placeholders"].items()
    }
def get_mm_limits_per_prompt(
    self,
    model_config: "ModelConfig",
    *,
    observability_config: ObservabilityConfig | None = None,
) -> Mapping[str, int]:
    """
    Get the maximum number of multi-modal input instances for each modality
    that are allowed per prompt for a model class.
    """
    # Text-only models accept no multi-modal inputs at all.
    if model_config.is_multimodal_model:
        processing_info = self._create_processing_info(
            model_config, observability_config
        )
        return processing_info.allowed_mm_limits
    return {}
def register_processor(
self,
processor: MultiModalProcessorFactory[_I],
@@ -303,10 +237,9 @@ class MultiModalRegistry:
def get_dummy_mm_inputs(
self,
model_config: "ModelConfig",
mm_counts: Mapping[str, int] | None = None,
mm_counts: Mapping[str, int],
*,
cache: BaseMultiModalProcessorCache | None = None,
observability_config: ObservabilityConfig | None = None,
processor: BaseMultiModalProcessor | None = None,
) -> MultiModalInputs:
"""
@@ -317,11 +250,7 @@ class MultiModalRegistry:
seq_len = model_config.max_model_len
if processor is None:
processor = self.create_processor(
model_config, observability_config, cache=cache
)
if mm_counts is None:
mm_counts = processor.info.allowed_mm_limits
processor = self.create_processor(model_config, cache=cache)
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
seq_len=seq_len,
@@ -342,26 +271,6 @@ class MultiModalRegistry:
return mm_inputs
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
    """
    Get the maximum length of the encoder input for encoder-decoder models.
    """
    # Decoder-only models have no separate encoder input.
    if not model_config.is_encoder_decoder:
        return 0
    max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
    if not max_tokens:
        # TODO - this function assumes encoder-decoder models are
        # multimodal. This will need to change when adding support for more
        # than whisper.
        return 0
    assert len(max_tokens) == 1, (
        "Encoder-decoder models are expected "
        "to implement the multimodal interface with at most one modality."
    )
    # Single modality: its per-item token count is the encoder input length.
    first_modality = next(iter(max_tokens))
    return max_tokens[first_modality]
def _get_cache_type(
self,
vllm_config: "VllmConfig",

View File

@@ -6,11 +6,10 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.multimodal import MultiModalRegistry
from vllm.v1.request import Request
if TYPE_CHECKING:
from vllm.config import ModelConfig, SchedulerConfig
from vllm.config import SchedulerConfig
logger = init_logger(__name__)
@@ -267,60 +266,16 @@ class EncoderCacheManager:
return freed
def compute_encoder_budget(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
    mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations.

    Returns:
        - Compute budget for encoder execution, measured in number of tokens
          from the input sequence.
        - Space budget for encoder cache size, measured in number of tokens
          from the input sequence.
    """
    # Text-only models get the (currently zero) text-encoder budget.
    if not mm_registry.supports_multimodal_inputs(model_config):
        return compute_text_encoder_budget(scheduler_config)

    per_item_max_tokens = mm_registry.get_max_tokens_per_item_by_modality(
        model_config
    )
    return compute_mm_encoder_budget(scheduler_config, per_item_max_tokens)
def compute_text_encoder_budget(scheduler_config: "SchedulerConfig") -> tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations for a text-only model.

    Args:
        scheduler_config: Scheduler configuration.

    Returns:
        - Compute budget for encoder execution, in unit of number of tokens
          in the input sequence.
        - Space budget for encoder cache size, in unit of number of tokens
          in the input sequence.
    """
    # Text-only encoder-decoder models are not supported yet, so both the
    # compute budget and the cache-size budget are zero.
    return (0, 0)
def compute_mm_encoder_budget(
scheduler_config: "SchedulerConfig",
max_tokens_by_modality: Mapping[str, int],
mm_max_toks_per_item: Mapping[str, int],
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model.
Args:
scheduler_config: Scheduler configuration.
max_tokens_by_modality: The maximum number of tokens for each
mm_max_toks_per_item: The maximum number of tokens per item for each
non-text modality.
Returns:
@@ -330,7 +285,7 @@ def compute_mm_encoder_budget(
from the input sequence.
"""
if not max_tokens_by_modality:
if not mm_max_toks_per_item:
logger.warning(
"All non-text modalities supported by the model have been "
"explicitly disabled via limit_mm_per_prompt. Encoder cache will "
@@ -338,7 +293,7 @@ def compute_mm_encoder_budget(
)
return 0, 0
max_tokens_per_mm_item = max(max_tokens_by_modality.values())
max_tokens_per_mm_item = max(mm_max_toks_per_item.values())
if (
scheduler_config.disable_chunked_mm_input

View File

@@ -31,10 +31,10 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
RoutedExpertsReader,
)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.budget import MultiModalBudget
from vllm.v1.core.encoder_cache_manager import (
EncoderCacheManager,
EncoderDecoderCacheManager,
compute_encoder_budget,
)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
@@ -174,22 +174,22 @@ class Scheduler(SchedulerInterface):
# Encoder-related.
# Calculate encoder cache size if applicable
# NOTE: For now we use the same budget for both compute and space.
# This can be changed when we make encoder cache for embedding caching
# across requests.
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
mm_registry=mm_registry,
self.supports_mm_inputs = mm_registry.supports_multimodal_inputs(
vllm_config.model_config
)
self.mm_budget = mm_budget = (
MultiModalBudget(vllm_config, mm_registry)
if self.supports_mm_inputs
else None
)
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
# projector if needed) for MM models as well as encoder-decoder
# transformers.
self.max_num_encoder_input_tokens = encoder_compute_budget
# NOTE: For the models without encoder (e.g., text-only models),
# the encoder cache will not be initialized because cache size is 0
# for these models.
# NOTE: Text-only encoder-decoder models are implemented as
# multi-modal models for convenience
# Example: https://github.com/neuralmagic/bart-plugin
self.max_num_encoder_input_tokens = (
mm_budget.encoder_compute_budget if mm_budget else 0
)
encoder_cache_size = mm_budget.encoder_cache_size if mm_budget else 0
self.encoder_cache_manager = (
EncoderDecoderCacheManager(cache_size=encoder_cache_size)
if self.is_encoder_decoder
@@ -199,7 +199,9 @@ class Scheduler(SchedulerInterface):
# Attn blocks, as for Whisper its input is always padded to the maximum length.
# TODO (NickLucche): Generalize to models with variable-length encoder inputs.
self._num_encoder_max_input_tokens = (
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(vllm_config.model_config)
mm_budget.mm_max_toks_per_item[mm_budget.get_modality_with_max_tokens()]
if mm_budget
else 0
)
speculative_config = vllm_config.speculative_config

View File

@@ -19,6 +19,7 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.budget import MultiModalBudget
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFeatureSpec,
@@ -34,7 +35,6 @@ from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import (
@@ -59,32 +59,30 @@ class InputProcessor:
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.model_config = model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.scheduler_config = vllm_config.scheduler_config
self.structured_outputs_config = vllm_config.structured_outputs_config
self.observability_config = vllm_config.observability_config
self.generation_config_fields = self.model_config.try_get_generation_config()
self.generation_config_fields = model_config.try_get_generation_config()
self.mm_registry = mm_registry
self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)
self.mm_encoder_cache_size = None
if (
self.mm_registry.supports_multimodal_inputs(self.model_config)
and not self.model_config.skip_tokenizer_init
):
with set_default_torch_num_threads():
max_tokens_by_modality = (
mm_registry.get_max_tokens_per_item_by_modality(self.model_config)
)
_, self.mm_encoder_cache_size = compute_mm_encoder_budget(
self.vllm_config.scheduler_config, max_tokens_by_modality
)
self.mm_encoder_cache_size: int | None = None
if (
mm_registry.supports_multimodal_inputs(model_config)
and not model_config.skip_tokenizer_init
):
mm_budget = MultiModalBudget(vllm_config, mm_registry)
self.mm_encoder_cache_size = mm_budget.encoder_cache_size
mm_budget.reset_cache() # Not used anymore
self.input_preprocessor = InputPreprocessor(
self.model_config,
vllm_config.observability_config,
model_config,
self.observability_config,
mm_registry,
mm_processor_cache=self.mm_processor_cache,
)

View File

@@ -82,6 +82,7 @@ from vllm.model_executor.models.interfaces_base import (
is_text_generation_model,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.budget import MultiModalBudget
from vllm.multimodal.inputs import (
BatchedTensorInputs,
MultiModalKwargsItem,
@@ -180,7 +181,6 @@ from vllm.v1.worker.workspace import lock_workspace
from .utils import (
AttentionGroup,
MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache,
sanity_check_mm_encoder_outputs,
@@ -5104,7 +5104,7 @@ class GPUModelRunner(
# modality with the max possible input tokens even when
# it supports multiple.
dummy_modality = mm_budget.get_modality_with_max_tokens()
max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[
max_mm_items_per_batch = mm_budget.mm_max_items_per_batch[
dummy_modality
]

View File

@@ -11,125 +11,14 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform
from vllm.utils.mem_utils import MemorySnapshot, format_gib
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
logger = init_logger(__name__)
class MultiModalBudget:
    """Helper class to calculate budget information for multi-modal models."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        mm_registry: MultiModalRegistry,
    ) -> None:
        super().__init__()

        self.model_config = model_config = vllm_config.model_config
        self.scheduler_config = scheduler_config = vllm_config.scheduler_config
        self.mm_registry = mm_registry
        # Processor-only cache; released later via reset_cache().
        self.cache = cache = mm_registry.processor_only_cache_from_config(vllm_config)

        # Decoder-side limits consulted by get_max_items().
        self.max_model_len = model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Per-modality item limits allowed per prompt.
        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)

        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
            model_config,
            cache=cache,
            profiler_limits=self.mm_limits,
        )

        encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
            scheduler_config,
            max_tokens_by_modality,
        )
        self.encoder_compute_budget = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        max_items_per_prompt_by_modality = dict[str, int]()
        max_items_per_batch_by_modality = dict[str, int]()
        for modality, max_tokens in max_tokens_by_modality.items():
            (
                max_items_per_prompt,
                max_items_per_batch,
            ) = self.get_max_items(modality, max_tokens)
            max_items_per_prompt_by_modality[modality] = max_items_per_prompt
            max_items_per_batch_by_modality[modality] = max_items_per_batch

        # Per-modality maximum token/item counts, exposed for schedulers
        # and model runners.
        self.max_tokens_by_modality = max_tokens_by_modality
        self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
        self.max_items_per_batch_by_modality = max_items_per_batch_by_modality

    def get_modality_with_max_tokens(self) -> str:
        """Return the modality with the largest per-item token count."""
        max_tokens_by_modality = self.max_tokens_by_modality
        modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])
        return modality

    def get_encoder_budget(self) -> int:
        """Effective encoder budget: limited by both compute and cache space."""
        return min(self.encoder_compute_budget, self.encoder_cache_size)

    def get_max_items(
        self,
        modality: str,
        max_tokens_per_item: int,
    ) -> tuple[int, int]:
        """Return ``(max_items_per_prompt, max_items_per_batch)`` for the
        given modality, bounded by both the encoder and decoder budgets."""
        if max_tokens_per_item == 0:
            return 0, 0

        # Check how many items of this modality can be supported by
        # the encoder budget.
        encoder_budget = self.get_encoder_budget()
        # TODO: handle encoder-decoder models once we support them.
        if encoder_budget == 0:
            return 0, 0

        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item

        # Check how many items of this modality can be supported by
        # the decoder budget.
        mm_limit = self.mm_limits[modality]
        # At least one item per prompt even if a single item exceeds the
        # model's context length.
        max_items_per_prompt = max(
            1,
            min(mm_limit, self.max_model_len // max_tokens_per_item),
        )

        scheduler_config = self.scheduler_config
        max_num_reqs = self.max_num_reqs
        if not scheduler_config.enable_chunked_prefill:
            # Without chunked prefill, each request's items must fit in one
            # batch, which further caps the number of concurrent requests.
            max_num_reqs = min(
                max_num_reqs,
                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
            )
        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt

        max_items_per_batch = max(
            1,
            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
        )

        return max_items_per_prompt, max_items_per_batch

    def reset_cache(self) -> None:
        """Release the processor-only cache created during profiling."""
        if self.cache is not None:
            self.cache.clear_cache()
@dataclass
class AttentionGroup:
backend: type[AttentionBackend]