[hardware] unify usage of is_tpu to current_platform.is_tpu() (#7102)

This commit is contained in:
youkaichao
2024-08-13 00:16:42 -07:00
committed by GitHub
parent 7025b11d94
commit 4d2dc5072b
8 changed files with 29 additions and 33 deletions

View File

@@ -1,22 +1,25 @@
from typing import Optional
import torch
from vllm.utils import is_tpu
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Optional[Platform]
current_platform: Platform
if torch.version.cuda is not None:
try:
import libtpu
except ImportError:
libtpu = None
if libtpu is not None:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
from .tpu import TpuPlatform
current_platform = TpuPlatform()
elif torch.version.cuda is not None:
from .cuda import CudaPlatform
current_platform = CudaPlatform()
elif torch.version.hip is not None:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_tpu():
from .tpu import TpuPlatform
current_platform = TpuPlatform()
else:
current_platform = UnspecifiedPlatform()