[Misc] Add a wrapper for torch.inference_mode (#6618)
This commit is contained in:
@@ -2,7 +2,9 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
from vllm.utils import is_tpu
|
||||
|
||||
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
|
||||
|
||||
current_platform: Optional[Platform]
|
||||
|
||||
@@ -12,7 +14,10 @@ if torch.version.cuda is not None:
|
||||
elif torch.version.hip is not None:
|
||||
from .rocm import RocmPlatform
|
||||
current_platform = RocmPlatform()
|
||||
elif is_tpu():
|
||||
from .tpu import TpuPlatform
|
||||
current_platform = TpuPlatform()
|
||||
else:
|
||||
current_platform = None
|
||||
current_platform = UnspecifiedPlatform()
|
||||
|
||||
__all__ = ['Platform', 'PlatformEnum', 'current_platform']
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import enum
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class PlatformEnum(enum.Enum):
    """Identifies the hardware backend the current process is running on."""

    # Values come from enum.auto(); only member identity matters, never
    # the numeric value, so member ORDER must not change.
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    UNSPECIFIED = enum.auto()
|
||||
|
||||
|
||||
class Platform:
|
||||
@@ -16,6 +20,23 @@ class Platform:
|
||||
def is_rocm(self) -> bool:
    """Return True when this platform is the ROCm (AMD GPU) backend."""
    return PlatformEnum.ROCM == self._enum
|
||||
|
||||
def is_tpu(self) -> bool:
    """Return True when this platform is the TPU backend."""
    return PlatformEnum.TPU == self._enum
|
||||
|
||||
@staticmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
    """Return the (major, minor) compute capability of `device_id`.

    The base class has no device to query, so concrete platform
    subclasses must override this.
    """
    raise NotImplementedError
|
||||
|
||||
@staticmethod
def inference_mode():
    """Device-specific wrapper around `torch.inference_mode`.

    Prefer this over calling `torch.inference_mode` directly: some
    hardware backends (e.g. TPU) do not support inference mode and
    override this method to fall back to `torch.no_grad` instead.
    """
    return torch.inference_mode(mode=True)
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
    # Fallback used when no supported accelerator (CUDA / ROCm / TPU)
    # is detected at import time; inherits default Platform behavior.
    _enum = PlatformEnum.UNSPECIFIED
|
||||
|
||||
17
vllm/platforms/tpu.py
Normal file
17
vllm/platforms/tpu.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
|
||||
|
||||
class TpuPlatform(Platform):
    """Platform implementation for Google TPU backends."""

    _enum = PlatformEnum.TPU

    @staticmethod
    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
        # TPUs expose no CUDA-style compute capability, so this is an
        # unconditional error rather than NotImplementedError.
        raise RuntimeError("TPU does not have device capability.")

    @staticmethod
    def inference_mode():
        """Fall back to `torch.no_grad`: `torch.inference_mode` is not
        supported on TPU."""
        return torch.no_grad()
|
||||
Reference in New Issue
Block a user