[Bugfix] Fix multi nodes TP+PP for XPU (#8884)
Signed-off-by: YiSheng5 <syhm@mail.ustc.edu.cn> Signed-off-by: yan ma <yan.ma@intel.com> Co-authored-by: YiSheng5 <syhm@mail.ustc.edu.cn>
This commit is contained in:
@@ -20,3 +20,7 @@ class XPUPlatform(Platform):
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
device_props = torch.xpu.get_device_properties(device_id)
|
||||
return device_props.total_memory
|
||||
|
||||
@staticmethod
|
||||
def inference_mode():
|
||||
return torch.no_grad()
|
||||
|
||||
Reference in New Issue
Block a user