[hardware] unify usage of is_tpu to current_platform.is_tpu() (#7102)
This commit is contained in:
@@ -1,22 +1,25 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.utils import is_tpu
|
||||
|
||||
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
|
||||
|
||||
current_platform: Optional[Platform]
|
||||
current_platform: Platform
|
||||
|
||||
if torch.version.cuda is not None:
|
||||
try:
|
||||
import libtpu
|
||||
except ImportError:
|
||||
libtpu = None
|
||||
|
||||
if libtpu is not None:
|
||||
# people might install pytorch built with cuda but run on tpu
|
||||
# so we need to check tpu first
|
||||
from .tpu import TpuPlatform
|
||||
current_platform = TpuPlatform()
|
||||
elif torch.version.cuda is not None:
|
||||
from .cuda import CudaPlatform
|
||||
current_platform = CudaPlatform()
|
||||
elif torch.version.hip is not None:
|
||||
from .rocm import RocmPlatform
|
||||
current_platform = RocmPlatform()
|
||||
elif is_tpu():
|
||||
from .tpu import TpuPlatform
|
||||
current_platform = TpuPlatform()
|
||||
else:
|
||||
current_platform = UnspecifiedPlatform()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user