[misc] move functions to config.py (#10624)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-25 01:27:30 -08:00
committed by GitHub
parent 25d806e953
commit 05d1f8c9c6
11 changed files with 62 additions and 73 deletions

View File

@@ -1,15 +1,10 @@
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional
import torch
import vllm.envs as envs
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = logging.getLogger(__name__)
# make sure one process only loads plugins once
@@ -64,54 +59,3 @@ def load_general_plugins():
logger.info("plugin %s loaded.", plugin.name)
except Exception:
logger.exception("Failed to load plugin %s", plugin.name)
_current_vllm_config: Optional["VllmConfig"] = None
@contextmanager
def set_current_vllm_config(vllm_config: "VllmConfig"):
"""
Temporarily set the current VLLM config.
Used during model initialization.
We save the current VLLM config in a global variable,
so that all modules can access it, e.g. custom ops
can access the VLLM config to determine how to dispatch.
"""
global _current_vllm_config
old_vllm_config = _current_vllm_config
from vllm.compilation.counter import compilation_counter
from vllm.config import CompilationLevel
num_models_seen = compilation_counter.num_models_seen
try:
_current_vllm_config = vllm_config
yield
finally:
logger.debug("enabled custom ops: %s",
vllm_config.compilation_config.enabled_custom_ops)
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
and compilation_counter.num_models_seen == num_models_seen:
# If the model supports compilation,
# compilation_counter.num_models_seen should be increased
# by at least 1.
# If it is not increased, it means the model does not support
# compilation (does not have @support_torch_compile decorator).
logger.warning(
"`torch.compile` is turned on, but the model %s"
" does not support it. Please open an issue on GitHub"
"if you want it to be supported.",
vllm_config.model_config.model)
_current_vllm_config = old_vllm_config
def get_current_vllm_config() -> "VllmConfig":
if _current_vllm_config is None:
# in ci, usually when we test custom ops/modules directly,
# we don't set the vllm config. In that case, we set a default
# config.
logger.warning("Current VLLM config is not set.")
from vllm.config import VllmConfig
return VllmConfig()
return _current_vllm_config