[Platform] Move async output check to platform (#10768)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2024-12-10 01:24:46 +08:00
committed by GitHub
parent e691b26f6f
commit aea2fc38c3
10 changed files with 66 additions and 22 deletions

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
@@ -20,6 +20,10 @@ class HpuPlatform(Platform):
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
return _Backend.HPU_ATTN
@classmethod
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
return True
@staticmethod
def inference_mode():
return torch.no_grad()