[Hardware][ROCM] using current_platform.is_rocm (#9642)

Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
wangshuai09
2024-10-28 12:07:00 +08:00
committed by GitHub
parent 34a9941620
commit 4e2d95e372
32 changed files with 165 additions and 151 deletions

View File

@@ -314,10 +314,6 @@ class PyObjectCache:
self._index = 0
def is_hip() -> bool:
return torch.version.hip is not None
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
@@ -1098,7 +1094,7 @@ def _cuda_device_count_stateless(
if not torch.cuda._is_compiled():
return 0
if is_hip():
if current_platform.is_rocm():
# ROCm uses amdsmi instead of nvml for stateless device count
# This requires a sufficiently modern version of Torch 2.4.0
raw_count = torch.cuda._device_count_amdsmi() if (hasattr(