[CI/Build] Add test decorator for minimum GPU memory (#8925)
@@ -1,3 +1,4 @@
+import psutil
 import torch
 
 from .interface import Platform, PlatformEnum
@@ -10,6 +11,10 @@ class CpuPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         return "cpu"
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        return psutil.virtual_memory().total
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()

@@ -59,6 +59,13 @@ def get_physical_device_name(device_id: int = 0) -> str:
     return pynvml.nvmlDeviceGetName(handle)
 
 
+@lru_cache(maxsize=8)
+@with_nvml_context
+def get_physical_device_total_memory(device_id: int = 0) -> int:
+    handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
+    return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
+
+
 @with_nvml_context
 def warn_if_different_devices():
     device_ids: int = pynvml.nvmlDeviceGetCount()
@@ -107,6 +114,11 @@ class CudaPlatform(Platform):
         physical_device_id = device_id_to_physical_device_id(device_id)
         return get_physical_device_name(physical_device_id)
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        physical_device_id = device_id_to_physical_device_id(device_id)
+        return get_physical_device_total_memory(physical_device_id)
+
     @classmethod
     @with_nvml_context
     def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:

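For context, the new get_physical_device_total_memory helper wraps raw NVML calls. Outside vLLM, the same query can be reproduced directly with the pynvml package; a standalone sketch, assuming an NVIDIA driver and at least one visible GPU:

# Standalone sketch of the NVML query the helper above performs.
import pynvml

pynvml.nvmlInit()
try:
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    # nvmlDeviceGetMemoryInfo returns a struct with .total/.free/.used in bytes.
    total_bytes = int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
    print(f"GPU 0 total memory: {total_bytes / 2**30:.1f} GiB")
finally:
    pynvml.nvmlShutdown()
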
@@ -85,6 +85,12 @@ class Platform:
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
        """Get the name of a device."""
         raise NotImplementedError
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        """Get the total memory of a device in bytes."""
+        raise NotImplementedError
+
     @classmethod

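With the abstract hook above defined on Platform, callers can query device memory without branching on the backend. A minimal caller-side sketch, assuming the current_platform singleton that vLLM exposes from vllm.platforms:

# Device-agnostic memory query through the platform abstraction.
from vllm.platforms import current_platform

total_gib = current_platform.get_device_total_memory(device_id=0) / 2**30
print(f"Device 0 reports {total_gib:.1f} GiB of memory")
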
@@ -29,3 +29,8 @@ class RocmPlatform(Platform):
     @lru_cache(maxsize=8)
     def get_device_name(cls, device_id: int = 0) -> str:
         return torch.cuda.get_device_name(device_id)
+
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        device_props = torch.cuda.get_device_properties(device_id)
+        return device_props.total_memory

@@ -10,6 +10,10 @@ class TpuPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         raise NotImplementedError
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        raise NotImplementedError
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()

@@ -8,13 +8,15 @@ class XPUPlatform(Platform):
 
     @staticmethod
     def get_device_capability(device_id: int = 0) -> DeviceCapability:
-        return DeviceCapability(major=int(
-            torch.xpu.get_device_capability(device_id)['version'].split('.')
-            [0]),
-                                minor=int(
-                                    torch.xpu.get_device_capability(device_id)
-                                    ['version'].split('.')[1]))
+        major, minor, *_ = torch.xpu.get_device_capability(
+            device_id)['version'].split('.')
+        return DeviceCapability(major=int(major), minor=int(minor))
 
     @staticmethod
     def get_device_name(device_id: int = 0) -> str:
         return torch.xpu.get_device_name(device_id)
 
+    @classmethod
+    def get_device_total_memory(cls, device_id: int = 0) -> int:
+        device_props = torch.xpu.get_device_properties(device_id)
+        return device_props.total_memory
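The hunks above only add the per-platform memory hooks; the test decorator named in the commit title is not shown in this excerpt. A minimal sketch of such a decorator, assuming pytest and the current_platform singleton (require_gpu_memory is a hypothetical name, not necessarily the one used in this commit):

# Hypothetical sketch: skip a test when the device has too little memory.
import functools

import pytest

from vllm.platforms import current_platform


def require_gpu_memory(min_gb: float):
    """Skip the decorated test unless device 0 has at least min_gb GiB."""

    def decorator(test_fn):

        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            total_gb = current_platform.get_device_total_memory() / 2**30
            if total_gb < min_gb:
                pytest.skip(f"test requires at least {min_gb} GiB of device "
                            f"memory, found {total_gb:.1f} GiB")
            return test_fn(*args, **kwargs)

        return wrapper

    return decorator


# Usage: run only on devices with at least 40 GiB of memory.
# @require_gpu_memory(min_gb=40)
# def test_large_model():
#     ...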