[Bugfix] Fix multi nodes TP+PP for XPU (#8884)

Signed-off-by: YiSheng5 <syhm@mail.ustc.edu.cn>
Signed-off-by: yan ma <yan.ma@intel.com>
Co-authored-by: YiSheng5 <syhm@mail.ustc.edu.cn>
This commit is contained in:
Yan Ma
2024-10-30 12:34:45 +08:00
committed by GitHub
parent 62fac4b9aa
commit 04a3ae0aca
7 changed files with 63 additions and 11 deletions

View File

@@ -45,6 +45,9 @@ except Exception:
is_xpu = False
try:
# IPEX is installed if the machine has XPUs.
import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
if hasattr(torch, 'xpu') and torch.xpu.is_available():
is_xpu = True

View File

@@ -20,3 +20,7 @@ class XPUPlatform(Platform):
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the ``total_memory`` reported for an XPU device.

    Args:
        device_id: Device index forwarded to
            ``torch.xpu.get_device_properties``; defaults to 0.

    Returns:
        The ``total_memory`` attribute of the queried device properties
        (presumably a byte count -- confirm against the PyTorch XPU docs).
    """
    return torch.xpu.get_device_properties(device_id).total_memory
@staticmethod
def inference_mode():
return torch.no_grad()