[Refactor] Move profiling methods to MM budget (#33559)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -10,7 +10,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.lora import LoRAConfig, ModelConfig
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import (
|
||||
BaseLayerWithLoRA,
|
||||
@@ -35,9 +35,9 @@ from vllm.model_executor.models.interfaces import is_pooling_model
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.utils import PPMissingLayer
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.budget import MultiModalBudget
|
||||
from vllm.utils.cache import LRUCache
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.worker.utils import MultiModalBudget
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -142,7 +142,6 @@ class LoRAModelManager:
|
||||
vllm_config: VllmConfig,
|
||||
max_num_batched_tokens: int,
|
||||
) -> None:
|
||||
model_config: ModelConfig = vllm_config.model_config
|
||||
mm_registry = MULTIMODAL_REGISTRY
|
||||
|
||||
self.supports_tower_connector_lora = False
|
||||
@@ -162,7 +161,6 @@ class LoRAModelManager:
|
||||
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
|
||||
|
||||
if self.lora_config.enable_tower_connector_lora:
|
||||
self.mm_processor_info = mm_registry.create_processor(model_config).info
|
||||
self.supports_tower_connector_lora = self.supports_mm and hasattr(
|
||||
self.model, "get_num_mm_encoder_tokens"
|
||||
)
|
||||
@@ -176,7 +174,7 @@ class LoRAModelManager:
|
||||
)
|
||||
|
||||
mm_budget = MultiModalBudget(vllm_config, mm_registry)
|
||||
limit_per_prompt = max(self.mm_processor_info.allowed_mm_limits.values())
|
||||
limit_per_prompt = max(mm_budget.mm_max_items_per_prompt.values())
|
||||
num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
|
||||
mm_budget.get_encoder_budget()
|
||||
)
|
||||
|
||||
154
vllm/multimodal/budget.py
Normal file
154
vllm/multimodal/budget.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Mapping
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
from vllm.multimodal.registry import MultiModalRegistry
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||
|
||||
|
||||
def get_mm_max_toks_per_item(
    model_config: ModelConfig,
    mm_registry: MultiModalRegistry,
    processor: BaseMultiModalProcessor,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    """
    Get the maximum number of tokens per data item from each modality based
    on underlying model configuration.

    Fast path: ask the processor directly. If the processor cannot answer
    (returns ``None``), fall back to running dummy multi-modal inputs
    through the registry and counting the resulting placeholder embeddings.
    """
    direct_answer = processor.info.get_mm_max_tokens_per_item(
        seq_len=model_config.max_model_len,
        mm_counts=mm_counts,
    )
    if direct_answer is not None:
        return direct_answer

    # Slow path: profile with dummy inputs and count placeholder embeddings.
    dummy_inputs = mm_registry.get_dummy_mm_inputs(
        model_config,
        mm_counts=mm_counts,
        processor=processor,
    )
    placeholders_by_modality = dummy_inputs["mm_placeholders"]

    return {
        modality: sum(item.get_num_embeds for item in items)
        for modality, items in placeholders_by_modality.items()
    }
|
||||
|
||||
|
||||
class MultiModalBudget:
    """Helper class that derives budget information for multi-modal models.

    On construction it profiles the model's multi-modal processor to
    determine, per modality: the maximum tokens one item may occupy, and
    how many items fit per prompt and per scheduled batch.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        mm_registry: MultiModalRegistry,
    ) -> None:
        super().__init__()

        model_config = vllm_config.model_config
        scheduler_config = vllm_config.scheduler_config
        self.model_config = model_config
        self.scheduler_config = scheduler_config

        self.max_model_len = model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs

        cache = mm_registry.processor_only_cache_from_config(vllm_config)
        self.cache = cache
        processor = mm_registry.create_processor(model_config, cache=cache)

        mm_limits = processor.info.allowed_mm_limits
        self.mm_limits = mm_limits

        # Modalities explicitly disabled via a limit of 0 are excluded
        # from profiling entirely.
        active_modalities = {m for m, limit in mm_limits.items() if limit > 0}

        with set_default_torch_num_threads():  # Avoid hang during startup
            all_toks_per_item = get_mm_max_toks_per_item(
                model_config,
                mm_registry,
                processor,
                mm_counts=dict.fromkeys(active_modalities, 1),
            )

        mm_max_toks_per_item = {m: all_toks_per_item[m] for m in active_modalities}

        # NOTE: these must be assigned before the _get_max_items() loop
        # below, which reads them via get_encoder_budget().
        (
            self.encoder_compute_budget,
            self.encoder_cache_size,
        ) = compute_mm_encoder_budget(scheduler_config, mm_max_toks_per_item)

        per_prompt: dict[str, int] = {}
        per_batch: dict[str, int] = {}
        for modality, toks in mm_max_toks_per_item.items():
            per_prompt[modality], per_batch[modality] = self._get_max_items(
                modality, toks
            )

        self.mm_max_toks_per_item = mm_max_toks_per_item
        self.mm_max_items_per_prompt: Mapping[str, int] = per_prompt
        self.mm_max_items_per_batch: Mapping[str, int] = per_batch

    def _get_max_items(
        self,
        modality: str,
        max_tokens_per_item: int,
    ) -> tuple[int, int]:
        """Return (max items per prompt, max items per batch) for a modality."""
        if max_tokens_per_item == 0:
            return 0, 0

        # Items supported by the encoder budget.
        encoder_budget = self.get_encoder_budget()
        if encoder_budget == 0:
            return 0, 0
        items_by_encoder = encoder_budget // max_tokens_per_item

        # Items supported by the decoder budget: capped by both the
        # per-prompt limit and how many items fit into the context window.
        per_prompt_cap = min(
            self.mm_limits[modality],
            self.max_model_len // max_tokens_per_item,
        )
        max_items_per_prompt = max(1, per_prompt_cap)

        sched = self.scheduler_config
        num_reqs = self.max_num_reqs
        if not sched.enable_chunked_prefill:
            # Without chunked prefill, a whole prompt must fit in one batch.
            num_reqs = min(
                num_reqs,
                sched.max_num_batched_tokens // max_tokens_per_item,
            )
        items_by_decoder = num_reqs * max_items_per_prompt

        max_items_per_batch = max(1, min(items_by_encoder, items_by_decoder))
        return max_items_per_prompt, max_items_per_batch

    def get_modality_with_max_tokens(self) -> str:
        """Return the modality whose single item may occupy the most tokens."""
        toks = self.mm_max_toks_per_item
        return max(toks, key=toks.__getitem__)

    def get_encoder_budget(self) -> int:
        """Effective encoder budget: min of compute budget and cache size."""
        return min(self.encoder_compute_budget, self.encoder_cache_size)

    def reset_cache(self) -> None:
        """Clear the processor-only cache, if one was created."""
        cache = self.cache
        if cache is not None:
            cache.clear_cache()
|
||||
@@ -130,15 +130,13 @@ class MultiModalRegistry:
|
||||
if not model_config.is_multimodal_model:
|
||||
return False
|
||||
|
||||
info = self._create_processing_info(model_config, tokenizer=None)
|
||||
supported_modalities = info.get_supported_mm_limits()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
info = self._create_processing_info(model_config, tokenizer=None)
|
||||
|
||||
# Check if all supported modalities have limit == 0
|
||||
if all(
|
||||
mm_config.get_limit_per_prompt(modality) == 0
|
||||
for modality in supported_modalities
|
||||
for modality in info.supported_mm_limits
|
||||
):
|
||||
logger.info_once(
|
||||
"All limits of multimodal modalities supported by the model "
|
||||
@@ -148,70 +146,6 @@ class MultiModalRegistry:
|
||||
|
||||
return True
|
||||
|
||||
def get_max_tokens_per_item_by_modality(
    self,
    model_config: "ModelConfig",
    *,
    cache: BaseMultiModalProcessorCache | None = None,
    profiler_limits: Mapping[str, int] | None = None,
    observability_config: ObservabilityConfig | None = None,
) -> Mapping[str, int]:
    """
    Get the maximum number of tokens per data item from each modality based
    on underlying model configuration.

    Args:
        model_config: The model whose processor determines per-item limits.
        cache: Optional processor cache forwarded to ``create_processor``.
        profiler_limits: Per-modality item limits to profile with; defaults
            to the processor's allowed limits.
        observability_config: Optional observability settings forwarded to
            the processor.

    Returns:
        Mapping from modality name to the maximum token count of a single
        item of that modality. Empty for non-multimodal models.
    """
    if not model_config.is_multimodal_model:
        return {}

    processor = self.create_processor(
        model_config, observability_config, cache=cache
    )

    if profiler_limits is None:
        profiler_limits = processor.info.allowed_mm_limits

    # Profile one item per modality; modalities disabled via limit == 0
    # are skipped entirely.
    mm_counts = {
        modality: 1 for modality, limit in profiler_limits.items() if limit > 0
    }

    # Fast path: the processor can report the per-item limit directly.
    max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
        seq_len=model_config.max_model_len,
        mm_counts=mm_counts,
    )
    if max_tokens_per_item is not None:
        return {
            modality: max_tokens
            for modality, max_tokens in max_tokens_per_item.items()
            if mm_counts.get(modality, 0) > 0
        }

    # Slow path: run dummy inputs through the processor and count the
    # placeholder embeddings that come out.
    mm_inputs = self.get_dummy_mm_inputs(
        model_config,
        mm_counts=mm_counts,
        processor=processor,
    )

    return {
        modality: sum(item.get_num_embeds for item in placeholders)
        for modality, placeholders in mm_inputs["mm_placeholders"].items()
    }
|
||||
|
||||
def get_mm_limits_per_prompt(
    self,
    model_config: "ModelConfig",
    *,
    observability_config: ObservabilityConfig | None = None,
) -> Mapping[str, int]:
    """
    Get the maximum number of multi-modal input instances for each modality
    that are allowed per prompt for a model class.

    Returns an empty mapping for non-multimodal models.
    """
    if not model_config.is_multimodal_model:
        return {}

    info = self._create_processing_info(model_config, observability_config)
    return info.allowed_mm_limits
|
||||
|
||||
def register_processor(
|
||||
self,
|
||||
processor: MultiModalProcessorFactory[_I],
|
||||
@@ -303,10 +237,9 @@ class MultiModalRegistry:
|
||||
def get_dummy_mm_inputs(
|
||||
self,
|
||||
model_config: "ModelConfig",
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
mm_counts: Mapping[str, int],
|
||||
*,
|
||||
cache: BaseMultiModalProcessorCache | None = None,
|
||||
observability_config: ObservabilityConfig | None = None,
|
||||
processor: BaseMultiModalProcessor | None = None,
|
||||
) -> MultiModalInputs:
|
||||
"""
|
||||
@@ -317,11 +250,7 @@ class MultiModalRegistry:
|
||||
seq_len = model_config.max_model_len
|
||||
|
||||
if processor is None:
|
||||
processor = self.create_processor(
|
||||
model_config, observability_config, cache=cache
|
||||
)
|
||||
if mm_counts is None:
|
||||
mm_counts = processor.info.allowed_mm_limits
|
||||
processor = self.create_processor(model_config, cache=cache)
|
||||
|
||||
processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
|
||||
seq_len=seq_len,
|
||||
@@ -342,26 +271,6 @@ class MultiModalRegistry:
|
||||
|
||||
return mm_inputs
|
||||
|
||||
def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
    """
    Get the maximum length of the encoder input for encoder-decoder models.

    Returns 0 for decoder-only models, or when no modality reports a
    per-item token limit.
    """
    if not model_config.is_encoder_decoder:
        return 0
    max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
    if not max_tokens:
        # TODO - this function assumes encoder-decoder models are
        # multimodal. This will need to change when adding support for more
        # than whisper.
        return 0
    assert len(max_tokens) == 1, (
        "Encoder-decoder models are expected "
        "to implement the multimodal interface with at most one modality."
    )

    # Exactly one modality (asserted above); its limit is the encoder length.
    first_modality = next(iter(max_tokens))
    return max_tokens[first_modality]
|
||||
|
||||
def _get_cache_type(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
|
||||
@@ -6,11 +6,10 @@ from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.v1.request import Request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, SchedulerConfig
|
||||
from vllm.config import SchedulerConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -267,60 +266,16 @@ class EncoderCacheManager:
|
||||
return freed
|
||||
|
||||
|
||||
def compute_encoder_budget(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
    mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations.

    Returns:
        - Compute budget for encoder execution, measured in number of tokens
          from the input sequence.
        - Space budget for encoder cache size, measured in number of tokens
          from the input sequence.
    """
    if not mm_registry.supports_multimodal_inputs(model_config):
        # Text-only model: no multi-modal encoder to budget for.
        return compute_text_encoder_budget(scheduler_config)

    max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
        model_config
    )
    return compute_mm_encoder_budget(
        scheduler_config,
        max_tokens_by_modality,
    )
|
||||
|
||||
|
||||
def compute_text_encoder_budget(scheduler_config: "SchedulerConfig") -> tuple[int, int]:
    """Compute the encoder cache budget based on the model and scheduler
    configurations for a text-only model.

    Args:
        scheduler_config: Scheduler configuration.

    Returns:
        - Compute budget for encoder execution, in unit of number of tokens
          in the input sequence.
        - Space budget for encoder cache size, in unit of number of tokens
          in the input sequence.
    """
    # Text-only encoder-decoder models are currently unsupported, so both
    # the compute budget and the cache-size budget are zero.
    return 0, 0
|
||||
|
||||
|
||||
def compute_mm_encoder_budget(
|
||||
scheduler_config: "SchedulerConfig",
|
||||
max_tokens_by_modality: Mapping[str, int],
|
||||
mm_max_toks_per_item: Mapping[str, int],
|
||||
) -> tuple[int, int]:
|
||||
"""Compute the encoder cache budget based on the model and scheduler
|
||||
configurations for a multimodal model.
|
||||
|
||||
Args:
|
||||
scheduler_config: Scheduler configuration.
|
||||
max_tokens_by_modality: The maximum number of tokens for each
|
||||
mm_max_toks_per_item: The maximum number of tokens per item for each
|
||||
non-text modality.
|
||||
|
||||
Returns:
|
||||
@@ -330,7 +285,7 @@ def compute_mm_encoder_budget(
|
||||
from the input sequence.
|
||||
"""
|
||||
|
||||
if not max_tokens_by_modality:
|
||||
if not mm_max_toks_per_item:
|
||||
logger.warning(
|
||||
"All non-text modalities supported by the model have been "
|
||||
"explicitly disabled via limit_mm_per_prompt. Encoder cache will "
|
||||
@@ -338,7 +293,7 @@ def compute_mm_encoder_budget(
|
||||
)
|
||||
return 0, 0
|
||||
|
||||
max_tokens_per_mm_item = max(max_tokens_by_modality.values())
|
||||
max_tokens_per_mm_item = max(mm_max_toks_per_item.values())
|
||||
|
||||
if (
|
||||
scheduler_config.disable_chunked_mm_input
|
||||
|
||||
@@ -31,10 +31,10 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import (
|
||||
RoutedExpertsReader,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.budget import MultiModalBudget
|
||||
from vllm.v1.core.encoder_cache_manager import (
|
||||
EncoderCacheManager,
|
||||
EncoderDecoderCacheManager,
|
||||
compute_encoder_budget,
|
||||
)
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
|
||||
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
|
||||
@@ -174,22 +174,22 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
# Encoder-related.
|
||||
# Calculate encoder cache size if applicable
|
||||
# NOTE: For now we use the same budget for both compute and space.
|
||||
# This can be changed when we make encoder cache for embedding caching
|
||||
# across requests.
|
||||
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
mm_registry=mm_registry,
|
||||
self.supports_mm_inputs = mm_registry.supports_multimodal_inputs(
|
||||
vllm_config.model_config
|
||||
)
|
||||
self.mm_budget = mm_budget = (
|
||||
MultiModalBudget(vllm_config, mm_registry)
|
||||
if self.supports_mm_inputs
|
||||
else None
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Here, "encoder" includes the vision encoder (and
|
||||
# projector if needed) for MM models as well as encoder-decoder
|
||||
# transformers.
|
||||
self.max_num_encoder_input_tokens = encoder_compute_budget
|
||||
# NOTE: For the models without encoder (e.g., text-only models),
|
||||
# the encoder cache will not be initialized because cache size is 0
|
||||
# for these models.
|
||||
# NOTE: Text-only encoder-decoder models are implemented as
|
||||
# multi-modal models for convenience
|
||||
# Example: https://github.com/neuralmagic/bart-plugin
|
||||
self.max_num_encoder_input_tokens = (
|
||||
mm_budget.encoder_compute_budget if mm_budget else 0
|
||||
)
|
||||
encoder_cache_size = mm_budget.encoder_cache_size if mm_budget else 0
|
||||
self.encoder_cache_manager = (
|
||||
EncoderDecoderCacheManager(cache_size=encoder_cache_size)
|
||||
if self.is_encoder_decoder
|
||||
@@ -199,7 +199,9 @@ class Scheduler(SchedulerInterface):
|
||||
# Attn blocks, as for Whisper its input is always padded to the maximum length.
|
||||
# TODO (NickLucche): Generalize to models with variable-length encoder inputs.
|
||||
self._num_encoder_max_input_tokens = (
|
||||
MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(vllm_config.model_config)
|
||||
mm_budget.mm_max_toks_per_item[mm_budget.get_modality_with_max_tokens()]
|
||||
if mm_budget
|
||||
else 0
|
||||
)
|
||||
|
||||
speculative_config = vllm_config.speculative_config
|
||||
|
||||
@@ -19,6 +19,7 @@ from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.budget import MultiModalBudget
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFeatureSpec,
|
||||
@@ -34,7 +35,6 @@ from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.metrics.stats import MultiModalCacheStats
|
||||
from vllm.v1.structured_output.backend_guidance import (
|
||||
@@ -59,32 +59,30 @@ class InputProcessor:
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.model_config = model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.structured_outputs_config = vllm_config.structured_outputs_config
|
||||
self.observability_config = vllm_config.observability_config
|
||||
|
||||
self.generation_config_fields = self.model_config.try_get_generation_config()
|
||||
self.generation_config_fields = model_config.try_get_generation_config()
|
||||
|
||||
self.mm_registry = mm_registry
|
||||
self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)
|
||||
self.mm_encoder_cache_size = None
|
||||
if (
|
||||
self.mm_registry.supports_multimodal_inputs(self.model_config)
|
||||
and not self.model_config.skip_tokenizer_init
|
||||
):
|
||||
with set_default_torch_num_threads():
|
||||
max_tokens_by_modality = (
|
||||
mm_registry.get_max_tokens_per_item_by_modality(self.model_config)
|
||||
)
|
||||
|
||||
_, self.mm_encoder_cache_size = compute_mm_encoder_budget(
|
||||
self.vllm_config.scheduler_config, max_tokens_by_modality
|
||||
)
|
||||
self.mm_encoder_cache_size: int | None = None
|
||||
if (
|
||||
mm_registry.supports_multimodal_inputs(model_config)
|
||||
and not model_config.skip_tokenizer_init
|
||||
):
|
||||
mm_budget = MultiModalBudget(vllm_config, mm_registry)
|
||||
self.mm_encoder_cache_size = mm_budget.encoder_cache_size
|
||||
mm_budget.reset_cache() # Not used anymore
|
||||
|
||||
self.input_preprocessor = InputPreprocessor(
|
||||
self.model_config,
|
||||
vllm_config.observability_config,
|
||||
model_config,
|
||||
self.observability_config,
|
||||
mm_registry,
|
||||
mm_processor_cache=self.mm_processor_cache,
|
||||
)
|
||||
|
||||
@@ -82,6 +82,7 @@ from vllm.model_executor.models.interfaces_base import (
|
||||
is_text_generation_model,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.budget import MultiModalBudget
|
||||
from vllm.multimodal.inputs import (
|
||||
BatchedTensorInputs,
|
||||
MultiModalKwargsItem,
|
||||
@@ -180,7 +181,6 @@ from vllm.v1.worker.workspace import lock_workspace
|
||||
|
||||
from .utils import (
|
||||
AttentionGroup,
|
||||
MultiModalBudget,
|
||||
add_kv_sharing_layers_to_kv_cache_groups,
|
||||
bind_kv_cache,
|
||||
sanity_check_mm_encoder_outputs,
|
||||
@@ -5104,7 +5104,7 @@ class GPUModelRunner(
|
||||
# modality with the max possible input tokens even when
|
||||
# it supports multiple.
|
||||
dummy_modality = mm_budget.get_modality_with_max_tokens()
|
||||
max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[
|
||||
max_mm_items_per_batch = mm_budget.mm_max_items_per_batch[
|
||||
dummy_modality
|
||||
]
|
||||
|
||||
|
||||
@@ -11,125 +11,14 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.multimodal.registry import MultiModalRegistry
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib
|
||||
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder
|
||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalBudget:
    """Helper class to calculate budget information for multi-modal models."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        mm_registry: MultiModalRegistry,
    ) -> None:
        """Profile the model's multi-modal processor and precompute
        per-modality token and item budgets.

        Args:
            vllm_config: Top-level config; model and scheduler configs are
                read from it.
            mm_registry: Registry used to create the processor cache and
                query per-modality limits.
        """
        super().__init__()

        self.model_config = model_config = vllm_config.model_config
        self.scheduler_config = scheduler_config = vllm_config.scheduler_config
        self.mm_registry = mm_registry
        self.cache = cache = mm_registry.processor_only_cache_from_config(vllm_config)

        self.max_model_len = model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs

        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)

        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
            model_config,
            cache=cache,
            profiler_limits=self.mm_limits,
        )

        encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
            scheduler_config,
            max_tokens_by_modality,
        )

        # Assigned before the get_max_items() loop below, which reads them
        # via get_encoder_budget().
        self.encoder_compute_budget = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        max_items_per_prompt_by_modality = dict[str, int]()
        max_items_per_batch_by_modality = dict[str, int]()

        for modality, max_tokens in max_tokens_by_modality.items():
            (
                max_items_per_prompt,
                max_items_per_batch,
            ) = self.get_max_items(modality, max_tokens)

            max_items_per_prompt_by_modality[modality] = max_items_per_prompt
            max_items_per_batch_by_modality[modality] = max_items_per_batch

        self.max_tokens_by_modality = max_tokens_by_modality
        self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
        self.max_items_per_batch_by_modality = max_items_per_batch_by_modality

    def get_modality_with_max_tokens(self) -> str:
        """Return the modality whose single item may occupy the most tokens."""
        max_tokens_by_modality = self.max_tokens_by_modality
        modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])

        return modality

    def get_encoder_budget(self) -> int:
        """Effective encoder budget: min of compute budget and cache size."""
        return min(self.encoder_compute_budget, self.encoder_cache_size)

    def get_max_items(
        self,
        modality: str,
        max_tokens_per_item: int,
    ) -> tuple[int, int]:
        """Return (max items per prompt, max items per batch) for a modality.

        Both values are 0 when the modality consumes no tokens or the
        encoder budget is exhausted.
        """
        if max_tokens_per_item == 0:
            return 0, 0

        # Check how many items of this modality can be supported by
        # the encoder budget.
        encoder_budget = self.get_encoder_budget()

        # TODO: handle encoder-decoder models once we support them.
        if encoder_budget == 0:
            return 0, 0

        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item

        # Check how many items of this modality can be supported by
        # the decoder budget.
        mm_limit = self.mm_limits[modality]

        # Capped by both the per-prompt limit and how many items fit into
        # the context window; never less than 1.
        max_items_per_prompt = max(
            1,
            min(mm_limit, self.max_model_len // max_tokens_per_item),
        )

        scheduler_config = self.scheduler_config
        max_num_reqs = self.max_num_reqs

        if not scheduler_config.enable_chunked_prefill:
            # Without chunked prefill, a whole prompt must fit in one batch.
            max_num_reqs = min(
                max_num_reqs,
                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
            )

        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt

        max_items_per_batch = max(
            1,
            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
        )

        return max_items_per_prompt, max_items_per_batch

    def reset_cache(self) -> None:
        """Clear the processor-only cache, if one was created."""
        if self.cache is not None:
            self.cache.clear_cache()
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionGroup:
|
||||
backend: type[AttentionBackend]
|
||||
|
||||
Reference in New Issue
Block a user