[Misc] Add a wrapper for torch.inference_mode (#6618)
This commit is contained in:
@@ -2,7 +2,9 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
from vllm.utils import is_tpu
|
||||
|
||||
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
|
||||
|
||||
current_platform: Optional[Platform]
|
||||
|
||||
@@ -12,7 +14,10 @@ if torch.version.cuda is not None:
|
||||
elif torch.version.hip is not None:
|
||||
from .rocm import RocmPlatform
|
||||
current_platform = RocmPlatform()
|
||||
elif is_tpu():
|
||||
from .tpu import TpuPlatform
|
||||
current_platform = TpuPlatform()
|
||||
else:
|
||||
current_platform = None
|
||||
current_platform = UnspecifiedPlatform()
|
||||
|
||||
__all__ = ['Platform', 'PlatformEnum', 'current_platform']
|
||||
|
||||
Reference in New Issue
Block a user