diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 6e67456bf..49225fc2e 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -561,7 +561,8 @@ class RocmPlatform(Platform): cls, device: torch.types.Device | None = None ) -> float: torch.cuda.reset_peak_memory_stats(device) - return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(device)[0] + free_mem, total_mem = torch.cuda.mem_get_info(device) + return total_mem - free_mem @classmethod def get_device_communicator_cls(cls) -> str: