[CI/Build] Add test decorator for minimum GPU memory (#8925)

This commit is contained in:
Cyrus Leung
2024-09-29 10:50:51 +08:00
committed by GitHub
parent d081da0064
commit 26a68d5d7e
14 changed files with 117 additions and 73 deletions

View File

@@ -1,3 +1,4 @@
import psutil
import torch
from .interface import Platform, PlatformEnum
@@ -10,6 +11,10 @@ class CpuPlatform(Platform):
def get_device_name(cls, device_id: int = 0) -> str:
    """Return the fixed device name for the CPU backend.

    The ``device_id`` argument is accepted for interface compatibility
    but ignored: every CPU "device" reports the same name.
    """
    return "cpu"
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the total system RAM in bytes.

    For the CPU platform the "device memory" is host memory, so
    ``device_id`` is ignored and psutil's system-wide total is used.
    """
    mem_stats = psutil.virtual_memory()
    return mem_stats.total
@classmethod
def inference_mode(cls):
    """Return the context manager used for inference on CPU.

    Plain ``torch.no_grad()`` is used (rather than a device-specific
    inference mode) to disable gradient tracking.
    """
    return torch.no_grad()

View File

@@ -59,6 +59,13 @@ def get_physical_device_name(device_id: int = 0) -> str:
return pynvml.nvmlDeviceGetName(handle)
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_total_memory(device_id: int = 0) -> int:
    """Return the total memory, in bytes, of the physical GPU at
    ``device_id`` as reported by NVML.
    """
    dev_handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
    mem_info = pynvml.nvmlDeviceGetMemoryInfo(dev_handle)
    return int(mem_info.total)
@with_nvml_context
def warn_if_different_devices():
device_ids: int = pynvml.nvmlDeviceGetCount()
@@ -107,6 +114,11 @@ class CudaPlatform(Platform):
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_name(physical_device_id)
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the total memory in bytes of the CUDA device.

    The logical ``device_id`` (which may be remapped, e.g. via
    CUDA_VISIBLE_DEVICES) is first translated to a physical index
    before the NVML-backed lookup.
    """
    phys_id = device_id_to_physical_device_id(device_id)
    return get_physical_device_total_memory(phys_id)
@classmethod
@with_nvml_context
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:

View File

@@ -85,6 +85,12 @@ class Platform:
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
    """Get the name of a device.

    Abstract on the base class: concrete platform subclasses must
    override this with a backend-specific implementation.
    """
    raise NotImplementedError
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Get the total memory of a device in bytes.

    Abstract on the base class: concrete platform subclasses must
    override this with a backend-specific implementation.
    """
    raise NotImplementedError
@classmethod

View File

@@ -29,3 +29,8 @@ class RocmPlatform(Platform):
@lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str:
    """Return the name of the ROCm device as reported by the
    torch CUDA/HIP runtime.
    """
    return torch.cuda.get_device_name(device_id)
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the total memory, in bytes, of the given ROCm device
    via the torch CUDA/HIP device properties.
    """
    return torch.cuda.get_device_properties(device_id).total_memory

View File

@@ -10,6 +10,10 @@ class TpuPlatform(Platform):
def get_device_name(cls, device_id: int = 0) -> str:
    # Device naming is not supported for the TPU backend.
    raise NotImplementedError
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    # Querying total device memory is not supported for the TPU backend.
    raise NotImplementedError
@classmethod
def inference_mode(cls):
    """Return the inference context manager (gradient tracking
    disabled via plain ``torch.no_grad()``).
    """
    return torch.no_grad()

View File

@@ -8,13 +8,15 @@ class XPUPlatform(Platform):
@staticmethod
def get_device_capability(device_id: int = 0) -> DeviceCapability:
    """Return the capability of the given XPU device.

    The XPU runtime reports a dotted version string (e.g. ``"1.2"`` or
    ``"1.2.3"``); only the leading major and minor components are kept,
    any further components are discarded.
    """
    # NOTE: the text here contained both the old and the new diff
    # versions of this method back-to-back, leaving unreachable dead
    # code after the first return; only the new implementation is kept.
    major, minor, *_ = torch.xpu.get_device_capability(
        device_id)['version'].split('.')
    return DeviceCapability(major=int(major), minor=int(minor))
@staticmethod
def get_device_name(device_id: int = 0) -> str:
    """Return the name of the XPU device reported by the torch XPU
    runtime.
    """
    return torch.xpu.get_device_name(device_id)
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return the total memory, in bytes, of the given XPU device
    via the torch XPU device properties.
    """
    return torch.xpu.get_device_properties(device_id).total_memory