[Misc] Avoid direct access of global mm_registry in compute_encoder_budget (#15621)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-28 01:52:00 +08:00
committed by GitHub
parent 66aa4c0bf4
commit 13ac9cab21
4 changed files with 19 additions and 7 deletions

View File

@@ -3,7 +3,7 @@
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import MultiModalRegistry
from vllm.v1.request import Request
if TYPE_CHECKING:
@@ -67,6 +67,7 @@ class EncoderCacheManager:
def compute_encoder_budget(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations.
@@ -74,6 +75,7 @@ def compute_encoder_budget(
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
@@ -89,7 +91,11 @@ def compute_encoder_budget(
(
encoder_compute_budget,
encoder_cache_size,
) = _compute_encoder_budget_multimodal(model_config, scheduler_config)
) = _compute_encoder_budget_multimodal(
model_config,
scheduler_config,
mm_registry,
)
return encoder_compute_budget, encoder_cache_size
@@ -97,6 +103,7 @@ def compute_encoder_budget(
def _compute_encoder_budget_multimodal(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model.
@@ -104,6 +111,7 @@ def _compute_encoder_budget_multimodal(
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
@@ -112,8 +120,8 @@ def _compute_encoder_budget_multimodal(
in the input sequence.
"""
max_tokens_by_modality_dict = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality( # noqa: E501
model_config)
max_tokens_by_modality_dict = mm_registry \
.get_max_tokens_per_item_by_nonzero_modality(model_config)
if not max_tokens_by_modality_dict:
logger.warning(