diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py index 89bcff3f8..3c112b6ac 100644 --- a/vllm/lora/model_manager.py +++ b/vllm/lora/model_manager.py @@ -10,7 +10,7 @@ import torch from torch import nn from vllm.config import VllmConfig -from vllm.config.lora import LoRAConfig, ModelConfig +from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import ( BaseLayerWithLoRA, @@ -35,9 +35,9 @@ from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.budget import MultiModalBudget from vllm.utils.cache import LRUCache from vllm.utils.platform_utils import is_pin_memory_available -from vllm.v1.worker.utils import MultiModalBudget logger = init_logger(__name__) @@ -142,7 +142,6 @@ class LoRAModelManager: vllm_config: VllmConfig, max_num_batched_tokens: int, ) -> None: - model_config: ModelConfig = vllm_config.model_config mm_registry = MULTIMODAL_REGISTRY self.supports_tower_connector_lora = False @@ -162,7 +161,6 @@ class LoRAModelManager: self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper if self.lora_config.enable_tower_connector_lora: - self.mm_processor_info = mm_registry.create_processor(model_config).info self.supports_tower_connector_lora = self.supports_mm and hasattr( self.model, "get_num_mm_encoder_tokens" ) @@ -176,7 +174,7 @@ class LoRAModelManager: ) mm_budget = MultiModalBudget(vllm_config, mm_registry) - limit_per_prompt = max(self.mm_processor_info.allowed_mm_limits.values()) + limit_per_prompt = max(mm_budget.mm_max_items_per_prompt.values()) num_encoder_tokens = self.model.get_num_mm_encoder_tokens( mm_budget.get_encoder_budget() ) diff --git a/vllm/multimodal/budget.py b/vllm/multimodal/budget.py new file mode 100644 index 000000000..0cd2419ca --- /dev/null +++ 
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


def get_mm_max_toks_per_item(
    model_config: ModelConfig,
    mm_registry: MultiModalRegistry,
    processor: BaseMultiModalProcessor,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    """
    Get the maximum number of tokens per data item from each modality based
    on the underlying model configuration.

    Prefers the processor's own static estimate; when the processor cannot
    provide one (returns ``None``), falls back to building dummy multi-modal
    inputs and counting the placeholder embeddings each item produces.
    """
    max_tokens_per_item = processor.info.get_mm_max_tokens_per_item(
        seq_len=model_config.max_model_len,
        mm_counts=mm_counts,
    )
    if max_tokens_per_item is not None:
        return max_tokens_per_item

    # Fallback path: profile dummy inputs and count the placeholder embeds
    # emitted for each modality.
    mm_inputs = mm_registry.get_dummy_mm_inputs(
        model_config,
        mm_counts=mm_counts,
        processor=processor,
    )

    return {
        modality: sum(item.get_num_embeds for item in placeholders)
        for modality, placeholders in mm_inputs["mm_placeholders"].items()
    }


class MultiModalBudget:
    """Helper class to calculate budget information for multi-modal models."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        mm_registry: MultiModalRegistry,
    ) -> None:
        self.model_config = model_config = vllm_config.model_config
        self.scheduler_config = scheduler_config = vllm_config.scheduler_config

        self.max_model_len = model_config.max_model_len
        self.max_num_reqs = scheduler_config.max_num_seqs

        cache = mm_registry.processor_only_cache_from_config(vllm_config)
        processor = mm_registry.create_processor(model_config, cache=cache)

        self.cache = cache
        self.mm_limits = mm_limits = processor.info.allowed_mm_limits

        # Only profile modalities the user has not disabled via
        # limit_mm_per_prompt (limit == 0 means disabled).
        active_modalities = {
            modality for modality, limit in mm_limits.items() if limit > 0
        }

        with set_default_torch_num_threads():  # Avoid hang during startup
            all_mm_max_toks_per_item = get_mm_max_toks_per_item(
                model_config,
                mm_registry,
                processor,
                mm_counts=dict.fromkeys(active_modalities, 1),
            )

        # Intersect with the active modalities instead of indexing directly:
        # a processor's estimate may omit a modality (the registry code this
        # replaces filtered the returned mapping the same way), and a direct
        # lookup would raise KeyError for a missing key.
        mm_max_toks_per_item = {
            modality: all_mm_max_toks_per_item[modality]
            for modality in active_modalities
            if modality in all_mm_max_toks_per_item
        }

        encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
            scheduler_config,
            mm_max_toks_per_item,
        )

        # NOTE: For now the same token budget bounds both encoder compute and
        # encoder cache space; they are kept as separate fields so callers can
        # diverge later.
        self.encoder_compute_budget = encoder_compute_budget
        self.encoder_cache_size = encoder_cache_size

        mm_max_items_per_prompt = dict[str, int]()
        mm_max_items_per_batch = dict[str, int]()

        for modality, max_toks_per_item in mm_max_toks_per_item.items():
            (
                mm_max_items_per_prompt[modality],
                mm_max_items_per_batch[modality],
            ) = self._get_max_items(modality, max_toks_per_item)

        self.mm_max_toks_per_item = mm_max_toks_per_item
        self.mm_max_items_per_prompt: Mapping[str, int] = mm_max_items_per_prompt
        self.mm_max_items_per_batch: Mapping[str, int] = mm_max_items_per_batch

    def _get_max_items(
        self,
        modality: str,
        max_tokens_per_item: int,
    ) -> tuple[int, int]:
        """Return ``(max items per prompt, max items per batch)`` for one
        modality, constrained by both the encoder and the decoder budgets.

        Returns ``(0, 0)`` when the item costs no tokens or when the encoder
        budget is zero.
        """
        if max_tokens_per_item == 0:
            return 0, 0

        # Check how many items of this modality can be supported by
        # the encoder budget.
        if (encoder_budget := self.get_encoder_budget()) == 0:
            return 0, 0

        max_encoder_items_per_batch = encoder_budget // max_tokens_per_item

        # Check how many items of this modality can be supported by
        # the decoder budget.
        mm_limit = self.mm_limits[modality]

        max_items_per_prompt = max(
            1,
            min(mm_limit, self.max_model_len // max_tokens_per_item),
        )

        scheduler_config = self.scheduler_config
        max_num_reqs = self.max_num_reqs

        if not scheduler_config.enable_chunked_prefill:
            # Without chunked prefill, all of a request's items must fit in a
            # single scheduling step, which further caps the request count.
            max_num_reqs = min(
                max_num_reqs,
                scheduler_config.max_num_batched_tokens // max_tokens_per_item,
            )

        max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt

        max_items_per_batch = max(
            1,
            min(max_encoder_items_per_batch, max_decoder_items_per_batch),
        )

        return max_items_per_prompt, max_items_per_batch

    def get_modality_with_max_tokens(self) -> str:
        """Return the modality whose single item costs the most encoder tokens."""
        mm_max_toks_per_item = self.mm_max_toks_per_item
        modality, _ = max(mm_max_toks_per_item.items(), key=lambda x: x[1])

        return modality

    def get_encoder_budget(self) -> int:
        """Effective encoder budget: the min of compute and cache-space budgets."""
        return min(self.encoder_compute_budget, self.encoder_cache_size)

    def reset_cache(self) -> None:
        """Clear the processor-only cache, if one was created."""
        if self.cache is not None:
            self.cache.clear_cache()
profiler_limits: Mapping[str, int] | None = None, - observability_config: ObservabilityConfig | None = None, - ) -> Mapping[str, int]: - """ - Get the maximum number of tokens per data item from each modality based - on underlying model configuration. - """ - if not model_config.is_multimodal_model: - return {} - - processor = self.create_processor( - model_config, observability_config, cache=cache - ) - - if profiler_limits is None: - profiler_limits = processor.info.allowed_mm_limits - - mm_counts = { - modality: 1 for modality, limit in profiler_limits.items() if limit > 0 - } - - max_tokens_per_item = processor.info.get_mm_max_tokens_per_item( - seq_len=model_config.max_model_len, - mm_counts=mm_counts, - ) - if max_tokens_per_item is not None: - return { - modality: max_tokens - for modality, max_tokens in max_tokens_per_item.items() - if mm_counts.get(modality, 0) > 0 - } - - mm_inputs = self.get_dummy_mm_inputs( - model_config, - mm_counts=mm_counts, - processor=processor, - ) - - return { - modality: sum(item.get_num_embeds for item in placeholders) - for modality, placeholders in mm_inputs["mm_placeholders"].items() - } - - def get_mm_limits_per_prompt( - self, - model_config: "ModelConfig", - *, - observability_config: ObservabilityConfig | None = None, - ) -> Mapping[str, int]: - """ - Get the maximum number of multi-modal input instances for each modality - that are allowed per prompt for a model class. 
- """ - if not model_config.is_multimodal_model: - return {} - - info = self._create_processing_info(model_config, observability_config) - return info.allowed_mm_limits - def register_processor( self, processor: MultiModalProcessorFactory[_I], @@ -303,10 +237,9 @@ class MultiModalRegistry: def get_dummy_mm_inputs( self, model_config: "ModelConfig", - mm_counts: Mapping[str, int] | None = None, + mm_counts: Mapping[str, int], *, cache: BaseMultiModalProcessorCache | None = None, - observability_config: ObservabilityConfig | None = None, processor: BaseMultiModalProcessor | None = None, ) -> MultiModalInputs: """ @@ -317,11 +250,7 @@ class MultiModalRegistry: seq_len = model_config.max_model_len if processor is None: - processor = self.create_processor( - model_config, observability_config, cache=cache - ) - if mm_counts is None: - mm_counts = processor.info.allowed_mm_limits + processor = self.create_processor(model_config, cache=cache) processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs( seq_len=seq_len, @@ -342,26 +271,6 @@ class MultiModalRegistry: return mm_inputs - def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int: - """ - Get the maximum length of the encoder input for encoder-decoder models. - """ - if not model_config.is_encoder_decoder: - return 0 - max_tokens = self.get_max_tokens_per_item_by_modality(model_config) - if not max_tokens: - # TODO - this function assumes encoder-decoder models are - # multimodal. This will need to change when adding support for more - # than whisper. - return 0 - assert len(max_tokens) == 1, ( - "Encoder-decoder models are expected " - "to implement the multimodal interface with at most one modality." 
- ) - - first_modality = next(iter(max_tokens)) - return max_tokens[first_modality] - def _get_cache_type( self, vllm_config: "VllmConfig", diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 56f40535e..6f1a2560d 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -6,11 +6,10 @@ from collections.abc import Mapping from typing import TYPE_CHECKING from vllm.logger import init_logger -from vllm.multimodal import MultiModalRegistry from vllm.v1.request import Request if TYPE_CHECKING: - from vllm.config import ModelConfig, SchedulerConfig + from vllm.config import SchedulerConfig logger = init_logger(__name__) @@ -267,60 +266,16 @@ class EncoderCacheManager: return freed -def compute_encoder_budget( - model_config: "ModelConfig", - scheduler_config: "SchedulerConfig", - mm_registry: MultiModalRegistry, -) -> tuple[int, int]: - """Compute the encoder cache budget based on the model and scheduler - configurations. - - Returns: - - Compute budget for encoder execution, measured in number of tokens - from the input sequence. - - Space budget for encoder cache size, measured in number of tokens - from the input sequence. - """ - if mm_registry.supports_multimodal_inputs(model_config): - max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality( - model_config - ) - - return compute_mm_encoder_budget( - scheduler_config, - max_tokens_by_modality, - ) - - return compute_text_encoder_budget(scheduler_config) - - -def compute_text_encoder_budget(scheduler_config: "SchedulerConfig") -> tuple[int, int]: - """Compute the encoder cache budget based on the model and scheduler - configurations for a text-only model. - - Args: - scheduler_config: Scheduler configuration. - - Returns: - - Compute budget for encoder execution, in unit of number of tokens - in the input sequence. - - Space budget for encoder cache size, in unit of number of tokens - in the input sequence. 
- """ - # Currently text-only encoder-decoder models are not supported - return 0, 0 - - def compute_mm_encoder_budget( scheduler_config: "SchedulerConfig", - max_tokens_by_modality: Mapping[str, int], + mm_max_toks_per_item: Mapping[str, int], ) -> tuple[int, int]: """Compute the encoder cache budget based on the model and scheduler configurations for a multimodal model. Args: scheduler_config: Scheduler configuration. - max_tokens_by_modality: The maximum number of tokens for each + mm_max_toks_per_item: The maximum number of tokens per item for each non-text modality. Returns: @@ -330,7 +285,7 @@ def compute_mm_encoder_budget( from the input sequence. """ - if not max_tokens_by_modality: + if not mm_max_toks_per_item: logger.warning( "All non-text modalities supported by the model have been " "explicitly disabled via limit_mm_per_prompt. Encoder cache will " @@ -338,7 +293,7 @@ def compute_mm_encoder_budget( ) return 0, 0 - max_tokens_per_mm_item = max(max_tokens_by_modality.values()) + max_tokens_per_mm_item = max(mm_max_toks_per_item.values()) if ( scheduler_config.disable_chunked_mm_input diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index b1667c075..84e6ae1f1 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -31,10 +31,10 @@ from vllm.model_executor.layers.fused_moe.routed_experts_capturer import ( RoutedExpertsReader, ) from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.budget import MultiModalBudget from vllm.v1.core.encoder_cache_manager import ( EncoderCacheManager, EncoderDecoderCacheManager, - compute_encoder_budget, ) from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector @@ -174,22 +174,22 @@ class Scheduler(SchedulerInterface): # Encoder-related. # Calculate encoder cache size if applicable - # NOTE: For now we use the same budget for both compute and space. 
- # This can be changed when we make encoder cache for embedding caching - # across requests. - encoder_compute_budget, encoder_cache_size = compute_encoder_budget( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - mm_registry=mm_registry, + self.supports_mm_inputs = mm_registry.supports_multimodal_inputs( + vllm_config.model_config + ) + self.mm_budget = mm_budget = ( + MultiModalBudget(vllm_config, mm_registry) + if self.supports_mm_inputs + else None ) - # NOTE(woosuk): Here, "encoder" includes the vision encoder (and - # projector if needed) for MM models as well as encoder-decoder - # transformers. - self.max_num_encoder_input_tokens = encoder_compute_budget - # NOTE: For the models without encoder (e.g., text-only models), - # the encoder cache will not be initialized because cache size is 0 - # for these models. + # NOTE: Text-only encoder-decoder models are implemented as + # multi-modal models for convenience + # Example: https://github.com/neuralmagic/bart-plugin + self.max_num_encoder_input_tokens = ( + mm_budget.encoder_compute_budget if mm_budget else 0 + ) + encoder_cache_size = mm_budget.encoder_cache_size if mm_budget else 0 self.encoder_cache_manager = ( EncoderDecoderCacheManager(cache_size=encoder_cache_size) if self.is_encoder_decoder @@ -199,7 +199,9 @@ class Scheduler(SchedulerInterface): # Attn blocks, as for Whisper its input is always padded to the maximum length. # TODO (NickLucche): Generalize to models with variable-length encoder inputs. 
self._num_encoder_max_input_tokens = ( - MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(vllm_config.model_config) + mm_budget.mm_max_toks_per_item[mm_budget.get_modality_with_max_tokens()] + if mm_budget + else 0 ) speculative_config = vllm_config.speculative_config diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index 893acce5a..b4de0e50c 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -19,6 +19,7 @@ from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.budget import MultiModalBudget from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFeatureSpec, @@ -34,7 +35,6 @@ from vllm.tokenizers import TokenizerLike from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid from vllm.utils.torch_utils import set_default_torch_num_threads -from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.engine import EngineCoreRequest from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.structured_output.backend_guidance import ( @@ -59,32 +59,30 @@ class InputProcessor: mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, ) -> None: self.vllm_config = vllm_config - self.model_config = vllm_config.model_config + self.model_config = model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config self.lora_config = vllm_config.lora_config + self.scheduler_config = vllm_config.scheduler_config self.structured_outputs_config = vllm_config.structured_outputs_config + self.observability_config = vllm_config.observability_config - self.generation_config_fields = self.model_config.try_get_generation_config() + self.generation_config_fields = model_config.try_get_generation_config() self.mm_registry = 
mm_registry self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config) - self.mm_encoder_cache_size = None - if ( - self.mm_registry.supports_multimodal_inputs(self.model_config) - and not self.model_config.skip_tokenizer_init - ): - with set_default_torch_num_threads(): - max_tokens_by_modality = ( - mm_registry.get_max_tokens_per_item_by_modality(self.model_config) - ) - _, self.mm_encoder_cache_size = compute_mm_encoder_budget( - self.vllm_config.scheduler_config, max_tokens_by_modality - ) + self.mm_encoder_cache_size: int | None = None + if ( + mm_registry.supports_multimodal_inputs(model_config) + and not model_config.skip_tokenizer_init + ): + mm_budget = MultiModalBudget(vllm_config, mm_registry) + self.mm_encoder_cache_size = mm_budget.encoder_cache_size + mm_budget.reset_cache() # Not used anymore self.input_preprocessor = InputPreprocessor( - self.model_config, - vllm_config.observability_config, + model_config, + self.observability_config, mm_registry, mm_processor_cache=self.mm_processor_cache, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 061ac8680..2a1809976 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -82,6 +82,7 @@ from vllm.model_executor.models.interfaces_base import ( is_text_generation_model, ) from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.budget import MultiModalBudget from vllm.multimodal.inputs import ( BatchedTensorInputs, MultiModalKwargsItem, @@ -180,7 +181,6 @@ from vllm.v1.worker.workspace import lock_workspace from .utils import ( AttentionGroup, - MultiModalBudget, add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, sanity_check_mm_encoder_outputs, @@ -5104,7 +5104,7 @@ class GPUModelRunner( # modality with the max possible input tokens even when # it supports multiple. 
dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ + max_mm_items_per_batch = mm_budget.mm_max_items_per_batch[ dummy_modality ] diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 4a543c7c2..f13c75a7a 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -11,125 +11,14 @@ from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index -from vllm.multimodal.registry import MultiModalRegistry from vllm.platforms import current_platform from vllm.utils.mem_utils import MemorySnapshot, format_gib from vllm.v1.attention.backend import AttentionBackend, AttentionMetadataBuilder -from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec logger = init_logger(__name__) -class MultiModalBudget: - """Helper class to calculate budget information for multi-modal models.""" - - def __init__( - self, - vllm_config: VllmConfig, - mm_registry: MultiModalRegistry, - ) -> None: - super().__init__() - - self.model_config = model_config = vllm_config.model_config - self.scheduler_config = scheduler_config = vllm_config.scheduler_config - self.mm_registry = mm_registry - self.cache = cache = mm_registry.processor_only_cache_from_config(vllm_config) - - self.max_model_len = model_config.max_model_len - self.max_num_reqs = scheduler_config.max_num_seqs - - self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config) - - max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality( - model_config, - cache=cache, - profiler_limits=self.mm_limits, - ) - - encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( - scheduler_config, - max_tokens_by_modality, - ) - - self.encoder_compute_budget = 
encoder_compute_budget - self.encoder_cache_size = encoder_cache_size - - max_items_per_prompt_by_modality = dict[str, int]() - max_items_per_batch_by_modality = dict[str, int]() - - for modality, max_tokens in max_tokens_by_modality.items(): - ( - max_items_per_prompt, - max_items_per_batch, - ) = self.get_max_items(modality, max_tokens) - - max_items_per_prompt_by_modality[modality] = max_items_per_prompt - max_items_per_batch_by_modality[modality] = max_items_per_batch - - self.max_tokens_by_modality = max_tokens_by_modality - self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality - self.max_items_per_batch_by_modality = max_items_per_batch_by_modality - - def get_modality_with_max_tokens(self) -> str: - max_tokens_by_modality = self.max_tokens_by_modality - modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1]) - - return modality - - def get_encoder_budget(self) -> int: - return min(self.encoder_compute_budget, self.encoder_cache_size) - - def get_max_items( - self, - modality: str, - max_tokens_per_item: int, - ) -> tuple[int, int]: - if max_tokens_per_item == 0: - return 0, 0 - - # Check how many items of this modality can be supported by - # the encoder budget. - encoder_budget = self.get_encoder_budget() - - # TODO: handle encoder-decoder models once we support them. - if encoder_budget == 0: - return 0, 0 - - max_encoder_items_per_batch = encoder_budget // max_tokens_per_item - - # Check how many items of this modality can be supported by - # the decoder budget. 
- mm_limit = self.mm_limits[modality] - - max_items_per_prompt = max( - 1, - min(mm_limit, self.max_model_len // max_tokens_per_item), - ) - - scheduler_config = self.scheduler_config - max_num_reqs = self.max_num_reqs - - if not scheduler_config.enable_chunked_prefill: - max_num_reqs = min( - max_num_reqs, - scheduler_config.max_num_batched_tokens // max_tokens_per_item, - ) - - max_decoder_items_per_batch = max_num_reqs * max_items_per_prompt - - max_items_per_batch = max( - 1, - min(max_encoder_items_per_batch, max_decoder_items_per_batch), - ) - - return max_items_per_prompt, max_items_per_batch - - def reset_cache(self) -> None: - if self.cache is not None: - self.cache.clear_cache() - - @dataclass class AttentionGroup: backend: type[AttentionBackend]