[Misc] Add a wrapper for torch.inference_mode (#6618)

This commit is contained in:
Woosuk Kwon
2024-07-21 18:43:11 -07:00
committed by GitHub
parent c9eef37f32
commit 42de2cefcb
5 changed files with 49 additions and 4 deletions

View File

@@ -2,7 +2,9 @@ from typing import Optional
 import torch
-from .interface import Platform, PlatformEnum
+from vllm.utils import is_tpu
+
+from .interface import Platform, PlatformEnum, UnspecifiedPlatform
 current_platform: Optional[Platform]
@@ -12,7 +14,10 @@ if torch.version.cuda is not None:
 elif torch.version.hip is not None:
     from .rocm import RocmPlatform
     current_platform = RocmPlatform()
+elif is_tpu():
+    from .tpu import TpuPlatform
+    current_platform = TpuPlatform()
 else:
-    current_platform = None
+    current_platform = UnspecifiedPlatform()
 __all__ = ['Platform', 'PlatformEnum', 'current_platform']