[Perf] avoid duplicate mem_get_info() call in get_current_memory_usage (#33064)

Signed-off-by: Paco Xu <paco.xu@daocloud.io>
This commit is contained in:
Paco Xu
2026-01-27 11:45:45 +08:00
committed by GitHub
parent 0b53bec60b
commit 157caf511b

View File

@@ -561,7 +561,8 @@ class RocmPlatform(Platform):
cls, device: torch.types.Device | None = None
) -> float:
torch.cuda.reset_peak_memory_stats(device)
-    return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(device)[0]
+    free_mem, total_mem = torch.cuda.mem_get_info(device)
+    return total_mem - free_mem
@classmethod
def get_device_communicator_cls(cls) -> str: