[cuda] manually import the correct pynvml module (#12679)
fixes problems like https://github.com/vllm-project/vllm/pull/12635 and https://github.com/vllm-project/vllm/pull/12636 and https://github.com/vllm-project/vllm/pull/12565 --------- Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -33,7 +33,8 @@ def cuda_platform_plugin() -> Optional[str]:
|
|||||||
is_cuda = False
|
is_cuda = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import pynvml
|
from vllm.utils import import_pynvml
|
||||||
|
pynvml = import_pynvml()
|
||||||
pynvml.nvmlInit()
|
pynvml.nvmlInit()
|
||||||
try:
|
try:
|
||||||
if pynvml.nvmlDeviceGetCount() > 0:
|
if pynvml.nvmlDeviceGetCount() > 0:
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ from functools import lru_cache, wraps
|
|||||||
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
|
from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
|
||||||
Union)
|
Union)
|
||||||
|
|
||||||
import pynvml
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
@@ -16,6 +15,7 @@ from typing_extensions import ParamSpec
|
|||||||
import vllm._C # noqa
|
import vllm._C # noqa
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import import_pynvml
|
||||||
|
|
||||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||||
|
|
||||||
@@ -29,13 +29,7 @@ logger = init_logger(__name__)
|
|||||||
_P = ParamSpec("_P")
|
_P = ParamSpec("_P")
|
||||||
_R = TypeVar("_R")
|
_R = TypeVar("_R")
|
||||||
|
|
||||||
if pynvml.__file__.endswith("__init__.py"):
|
pynvml = import_pynvml()
|
||||||
logger.warning(
|
|
||||||
"You are using a deprecated `pynvml` package. Please install"
|
|
||||||
" `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
|
|
||||||
" When both of them are installed, `pynvml` will take precedence"
|
|
||||||
" and cause errors. See https://pypi.org/project/pynvml "
|
|
||||||
"for more information.")
|
|
||||||
|
|
||||||
# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
|
# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
|
||||||
# see https://github.com/huggingface/diffusers/issues/9704 for details
|
# see https://github.com/huggingface/diffusers/issues/9704 for details
|
||||||
|
|||||||
@@ -2208,3 +2208,55 @@ def run_method(obj: Any, method: Union[str, bytes, Callable], args: Tuple[Any],
|
|||||||
else:
|
else:
|
||||||
func = partial(method, obj) # type: ignore
|
func = partial(method, obj) # type: ignore
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def import_pynvml():
|
||||||
|
"""
|
||||||
|
Historical comments:
|
||||||
|
|
||||||
|
libnvml.so is the library behind nvidia-smi, and
|
||||||
|
pynvml is a Python wrapper around it. We use it to get GPU
|
||||||
|
status without initializing CUDA context in the current process.
|
||||||
|
Historically, there are two packages that provide pynvml:
|
||||||
|
- `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
|
||||||
|
wrapper. It is a dependency of vLLM, and is installed when users
|
||||||
|
install vLLM. It provides a Python module named `pynvml`.
|
||||||
|
- `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
|
||||||
|
Prior to version 12.0, it also provides a Python module `pynvml`,
|
||||||
|
and therefore conflicts with the official one. What's worse,
|
||||||
|
the module is a Python package, and has higher priority than
|
||||||
|
the official one which is a standalone Python file.
|
||||||
|
This causes errors when both of them are installed.
|
||||||
|
Starting from version 12.0, it migrates to a new module
|
||||||
|
named `pynvml_utils` to avoid the conflict.
|
||||||
|
|
||||||
|
TL;DR: if users have pynvml<12.0 installed, it will cause problems.
|
||||||
|
Otherwise, `import pynvml` will import the correct module.
|
||||||
|
We take the safest approach here, to manually import the correct
|
||||||
|
`pynvml.py` module from the `nvidia-ml-py` package.
|
||||||
|
"""
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import pynvml
|
||||||
|
return pynvml
|
||||||
|
if "pynvml" in sys.modules:
|
||||||
|
import pynvml
|
||||||
|
if pynvml.__file__.endswith("__init__.py"):
|
||||||
|
# this is pynvml < 12.0
|
||||||
|
raise RuntimeError(
|
||||||
|
"You are using a deprecated `pynvml` package. "
|
||||||
|
"Please uninstall `pynvml` or upgrade to at least"
|
||||||
|
" version 12.0. See https://pypi.org/project/pynvml "
|
||||||
|
"for more information.")
|
||||||
|
return sys.modules["pynvml"]
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
import site
|
||||||
|
for site_dir in site.getsitepackages():
|
||||||
|
pynvml_path = os.path.join(site_dir, "pynvml.py")
|
||||||
|
if os.path.exists(pynvml_path):
|
||||||
|
spec = importlib.util.spec_from_file_location(
|
||||||
|
"pynvml", pynvml_path)
|
||||||
|
pynvml = importlib.util.module_from_spec(spec)
|
||||||
|
sys.modules["pynvml"] = pynvml
|
||||||
|
spec.loader.exec_module(pynvml)
|
||||||
|
return pynvml
|
||||||
|
|||||||
Reference in New Issue
Block a user