[Misc] Upgrade to pytorch 2.5 (#9588)

Signed-off-by: Bill Nell <bill@neuralmagic.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
bnellnm
2024-10-27 05:44:24 -04:00
committed by GitHub
parent 8549c82660
commit 3cb07a36a2
8 changed files with 47 additions and 24 deletions

View File

@@ -7,6 +7,7 @@ from functools import lru_cache, wraps
from typing import Callable, List, Tuple, TypeVar
import pynvml
import torch
from typing_extensions import ParamSpec
from vllm.logger import init_logger
@@ -26,6 +27,10 @@ if pynvml.__file__.endswith("__init__.py"):
" and cause errors. See https://pypi.org/project/pynvml "
"for more information.")
# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
# see https://github.com/huggingface/diffusers/issues/9704 for details
torch.backends.cuda.enable_cudnn_sdp(False)
# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
# all the related functions work on real physical device ids.