[Misc] Upgrade to pytorch 2.5 (#9588)
Signed-off-by: Bill Nell <bill@neuralmagic.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from functools import lru_cache, wraps
|
||||
from typing import Callable, List, Tuple, TypeVar
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.logger import init_logger
|
||||
@@ -26,6 +27,10 @@ if pynvml.__file__.endswith("__init__.py"):
|
||||
" and cause errors. See https://pypi.org/project/pynvml "
|
||||
"for more information.")
|
||||
|
||||
# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
|
||||
# see https://github.com/huggingface/diffusers/issues/9704 for details
|
||||
torch.backends.cuda.enable_cudnn_sdp(False)
|
||||
|
||||
# NVML utils
|
||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||
# all the related functions work on real physical device ids.
|
||||
|
||||
Reference in New Issue
Block a user