[Platform] Move async output check to platform (#10768)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2024-12-10 01:24:46 +08:00
committed by GitHub
parent e691b26f6f
commit aea2fc38c3
10 changed files with 66 additions and 22 deletions

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
@@ -41,6 +41,10 @@ class XPUPlatform(Platform):
device_props = torch.xpu.get_device_properties(device_id)
return device_props.total_memory
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True
@staticmethod
def inference_mode():
return torch.no_grad()