[Hardware][ROCM] using current_platform.is_rocm (#9642)
Signed-off-by: wangshuai09 <391746016@qq.com>
This commit is contained in:
@@ -314,10 +314,6 @@ class PyObjectCache:
        self._index = 0
def is_hip() -> bool:
    """Report whether this torch build targets AMD ROCm (HIP).

    ``torch.version.hip`` is a version string on ROCm builds and
    ``None`` on CUDA/CPU builds, so a simple ``is not None`` test
    distinguishes the two.
    """
    hip_version = torch.version.hip
    return hip_version is not None
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
|
||||
"""Returns the maximum shared memory per thread block in bytes."""
|
||||
@@ -1098,7 +1094,7 @@ def _cuda_device_count_stateless(
|
||||
|
||||
if not torch.cuda._is_compiled():
|
||||
return 0
|
||||
if is_hip():
|
||||
if current_platform.is_rocm():
|
||||
# ROCm uses amdsmi instead of nvml for stateless device count
|
||||
# This requires a sufficiently modern version of Torch 2.4.0
|
||||
raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
|
||||
|
||||
Reference in New Issue
Block a user