[Redo] #33110 with threading limit (#33502)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: YunzhuLu <lucia.yunzhu@gmail.com>
This commit is contained in:
Cyrus Leung
2026-02-01 17:18:11 +08:00
committed by GitHub
parent 672023877b
commit 21997f45b1
2 changed files with 64 additions and 17 deletions

View File

@@ -17,6 +17,7 @@ from packaging.version import Version
from torch.library import Library, infer_schema
import vllm.envs as envs
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config import ModelConfig
@@ -25,9 +26,7 @@ else:
ModelConfig = object
IntermediateTensors = object
import logging
logger = logging.getLogger(__name__)
logger = init_logger(__name__)
STR_DTYPE_TO_TORCH_DTYPE = {
@@ -104,12 +103,36 @@ def set_default_torch_dtype(dtype: torch.dtype):
@contextlib.contextmanager
def set_default_torch_num_threads(num_threads: int):
"""Sets the default number of threads for PyTorch to the given value."""
def set_default_torch_num_threads(num_threads: int | None = None):
"""
Sets the default number of threads for PyTorch to the given value.
`None` means using the value of the environment variable `OMP_NUM_THREADS`
(or `1` if that is not available).
"""
if num_threads is None:
num_threads = 1
try:
num_threads = int(os.environ["OMP_NUM_THREADS"])
except KeyError:
logger.debug_once(
"OMP_NUM_THREADS is not set; defaulting Torch threads to %d.",
num_threads,
)
except ValueError:
logger.warning_once(
"OMP_NUM_THREADS is invalid; defaulting Torch threads to %d.",
num_threads,
)
old_num_threads = torch.get_num_threads()
torch.set_num_threads(num_threads)
yield
torch.set_num_threads(old_num_threads)
try:
yield
finally:
torch.set_num_threads(old_num_threads)
@contextlib.contextmanager

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import time
from collections.abc import Mapping
from typing import Any, Literal, cast
@@ -35,6 +34,7 @@ from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
from vllm.v1.structured_output.backend_guidance import (
@@ -68,6 +68,19 @@ class InputProcessor:
self.mm_registry = mm_registry
self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)
self.mm_encoder_cache_size = None
if (
self.mm_registry.supports_multimodal_inputs(self.model_config)
and not self.model_config.skip_tokenizer_init
):
with set_default_torch_num_threads():
max_tokens_by_modality = (
mm_registry.get_max_tokens_per_item_by_modality(self.model_config)
)
_, self.mm_encoder_cache_size = compute_mm_encoder_budget(
self.vllm_config.scheduler_config, max_tokens_by_modality
)
self.input_preprocessor = InputPreprocessor(
self.model_config,
@@ -534,15 +547,7 @@ class InputProcessor:
# 1. Tokenize text prompt, with LoRA request if one exists.
# 2. For multimodal models with a merged preprocessor, preprocess
# multimodal data and expand prompt token ids accordingly.
num_threads = int(os.environ.get("OMP_NUM_THREADS", "1"))
if "OMP_NUM_THREADS" not in os.environ:
logger.debug_once(
"OMP_NUM_THREADS is not set; defaulting Torch threads to %d for "
"input preprocessing.",
num_threads,
)
with set_request_id(request_id), set_default_torch_num_threads(num_threads):
with set_request_id(request_id), set_default_torch_num_threads():
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
prompt,
tokenization_kwargs=tokenization_kwargs,
@@ -743,6 +748,25 @@ class InputProcessor:
f"model length of {max_prompt_len}. {suggestion}"
)
if (
prompt_type == "decoder"
and prompt_inputs["type"] == "multimodal"
and self.mm_encoder_cache_size is not None
):
decoder_mm_positions = prompt_inputs["mm_placeholders"]
for modality, mm_positions in decoder_mm_positions.items():
for mm_position in mm_positions:
embed_length = mm_position.get_num_embeds
if embed_length > self.mm_encoder_cache_size:
raise ValueError(
f"The {prompt_type} prompt contains a(n) {modality} item "
f"with length {embed_length}, which exceeds the "
f"pre-allocated encoder cache size "
f"{self.mm_encoder_cache_size}. Please reduce the input "
f"size or increase the encoder cache size "
f"by setting --limit-mm-per-prompt at startup."
)
def stat_mm_cache(self) -> MultiModalCacheStats | None:
return self.input_preprocessor.stat_mm_cache()