From 21997f45b10c17f44276cf3872e5f85c61dc7dfd Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 1 Feb 2026 17:18:11 +0800 Subject: [PATCH] [Redo] #33110 with threading limit (#33502) Signed-off-by: DarkLight1337 Co-authored-by: YunzhuLu --- vllm/utils/torch_utils.py | 37 +++++++++++++++++++++----- vllm/v1/engine/input_processor.py | 44 ++++++++++++++++++++++++------- 2 files changed, 64 insertions(+), 17 deletions(-) diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index a210a9266..08321bf69 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -17,6 +17,7 @@ from packaging.version import Version from torch.library import Library, infer_schema import vllm.envs as envs +from vllm.logger import init_logger if TYPE_CHECKING: from vllm.config import ModelConfig @@ -25,9 +26,7 @@ else: ModelConfig = object IntermediateTensors = object -import logging - -logger = logging.getLogger(__name__) +logger = init_logger(__name__) STR_DTYPE_TO_TORCH_DTYPE = { @@ -104,12 +103,36 @@ def set_default_torch_dtype(dtype: torch.dtype): @contextlib.contextmanager -def set_default_torch_num_threads(num_threads: int): - """Sets the default number of threads for PyTorch to the given value.""" +def set_default_torch_num_threads(num_threads: int | None = None): + """ + Sets the default number of threads for PyTorch to the given value. + + `None` means using the value of the environment variable `OMP_NUM_THREADS` + (or `1` if that is not available). + """ + if num_threads is None: + num_threads = 1 + + try: + num_threads = int(os.environ["OMP_NUM_THREADS"]) + except KeyError: + logger.debug_once( + "OMP_NUM_THREADS is not set; defaulting Torch threads to %d.", + num_threads, + ) + except ValueError: + logger.warning_once( + "OMP_NUM_THREADS is invalid; defaulting Torch threads to %d.", + num_threads, + ) + old_num_threads = torch.get_num_threads() torch.set_num_threads(num_threads) - yield - torch.set_num_threads(old_num_threads) + + try: + yield + finally: + torch.set_num_threads(old_num_threads) @contextlib.contextmanager diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index 9541db18b..893acce5a 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os import time from collections.abc import Mapping from typing import Any, Literal, cast @@ -35,6 +34,7 @@ from vllm.tokenizers import TokenizerLike from vllm.tokenizers.mistral import MistralTokenizer from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid from vllm.utils.torch_utils import set_default_torch_num_threads +from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.engine import EngineCoreRequest from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.structured_output.backend_guidance import ( @@ -68,6 +68,19 @@ class InputProcessor: self.mm_registry = mm_registry self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config) + self.mm_encoder_cache_size = None + if ( + self.mm_registry.supports_multimodal_inputs(self.model_config) + and not self.model_config.skip_tokenizer_init + ): + with set_default_torch_num_threads(): + max_tokens_by_modality = ( + mm_registry.get_max_tokens_per_item_by_modality(self.model_config) + ) + + _, self.mm_encoder_cache_size = compute_mm_encoder_budget( + self.vllm_config.scheduler_config, max_tokens_by_modality + ) self.input_preprocessor = InputPreprocessor( self.model_config, @@ -534,15 +547,7 @@ class InputProcessor: # 1. Tokenize text prompt, with LoRA request if one exists. # 2. For multimodal models with a merged preprocessor, preprocess # multimodal data and expand prompt token ids accordingly. - num_threads = int(os.environ.get("OMP_NUM_THREADS", "1")) - if "OMP_NUM_THREADS" not in os.environ: - logger.debug_once( - "OMP_NUM_THREADS is not set; defaulting Torch threads to %d for " - "input preprocessing.", - num_threads, - ) - - with set_request_id(request_id), set_default_torch_num_threads(num_threads): + with set_request_id(request_id), set_default_torch_num_threads(): processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( prompt, tokenization_kwargs=tokenization_kwargs, @@ -743,6 +748,25 @@ class InputProcessor: f"model length of {max_prompt_len}. {suggestion}" ) + if ( + prompt_type == "decoder" + and prompt_inputs["type"] == "multimodal" + and self.mm_encoder_cache_size is not None + ): + decoder_mm_positions = prompt_inputs["mm_placeholders"] + for modality, mm_positions in decoder_mm_positions.items(): + for mm_position in mm_positions: + embed_length = mm_position.get_num_embeds + if embed_length > self.mm_encoder_cache_size: + raise ValueError( + f"The {prompt_type} prompt contains a(n) {modality} item " + f"with length {embed_length}, which exceeds the " + f"pre-allocated encoder cache size " + f"{self.mm_encoder_cache_size}. Please reduce the input " + f"size or increase the encoder cache size " + f"by setting --limit-mm-per-prompt at startup." + ) + def stat_mm_cache(self) -> MultiModalCacheStats | None: return self.input_preprocessor.stat_mm_cache()