[Bugfix][ROCm] Using device_type because on ROCm the API is still torch.cuda (#17601)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
Gregory Shtrasberg
2025-05-03 01:25:47 -04:00
committed by GitHub
parent c8386fa61d
commit a92842454c

View File

@@ -406,12 +406,12 @@ class Platform:
"""Raises if this request is unsupported on this platform"""
def __getattr__(self, key: str):
device = getattr(torch, self.device_name, None)
device = getattr(torch, self.device_type, None)
if device is not None and hasattr(device, key):
return getattr(device, key)
else:
logger.warning("Current platform %s does not have '%s'" \
" attribute.", self.device_name, key)
" attribute.", self.device_type, key)
return None
@classmethod