[Platform][Refactor] Extract func get_default_attn_backend to Platform (#10358)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2024-11-19 11:22:26 +08:00
committed by GitHub
parent 7eb719df13
commit 8c1fb50705
14 changed files with 99 additions and 69 deletions

View File

@@ -1,11 +1,15 @@
import torch
from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend
class HpuPlatform(Platform):
_enum = PlatformEnum.HPU
@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
return _Backend.HPU_ATTN
@staticmethod
def inference_mode():
return torch.no_grad()