[doc] fix doc build error caused by msgspec (#7659)

This commit is contained in:
youkaichao
2024-08-19 17:50:59 -07:00
committed by GitHub
parent 67e02fa8a4
commit e54ebc2f8f
2 changed files with 41 additions and 9 deletions

View File

@@ -1,23 +1,54 @@
import torch
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Platform
try:
import libtpu
except ImportError:
libtpu = None
# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because
# they only indicate the build configuration, not the runtime environment.
# For example, people can install a cuda build of pytorch but run on tpu.
if libtpu is not None:
is_tpu = False
try:
import torch_xla.core.xla_model as xm
xm.xla_device(devkind="TPU")
is_tpu = True
except Exception:
pass
is_cuda = False
try:
import pynvml
pynvml.nvmlInit()
try:
if pynvml.nvmlDeviceGetCount() > 0:
is_cuda = True
finally:
pynvml.nvmlShutdown()
except Exception:
pass
is_rocm = False
try:
import amdsmi
amdsmi.amdsmi_init()
try:
if len(amdsmi.amdsmi_get_processor_handles()) > 0:
is_rocm = True
finally:
amdsmi.amdsmi_shut_down()
except Exception:
pass
if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from .tpu import TpuPlatform
current_platform = TpuPlatform()
elif torch.version.cuda is not None:
elif is_cuda:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif torch.version.hip is not None:
elif is_rocm:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
else: