[Hardware] Replace memory related torch.cuda APIs (#37031)
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
This commit is contained in:
@@ -64,7 +64,7 @@ class BaseModelLoader(ABC):
|
||||
# Log peak GPU memory after loading weights. This is needed
|
||||
# to have test coverage on peak memory for online quantization.
|
||||
if current_platform.is_cuda():
|
||||
peak_memory = torch.cuda.max_memory_allocated()
|
||||
peak_memory = torch.accelerator.max_memory_allocated()
|
||||
logger.debug_once(
|
||||
"Peak GPU memory after loading weights: %s GiB",
|
||||
format_gib(peak_memory),
|
||||
|
||||
@@ -93,11 +93,11 @@ class MemorySnapshot:
|
||||
device = self.device_
|
||||
|
||||
# we measure the torch peak memory usage via allocated_bytes,
|
||||
# rather than `torch.cuda.memory_reserved()` .
|
||||
# After `torch.cuda.reset_peak_memory_stats()`,
|
||||
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
|
||||
# rather than `torch.accelerator.memory_reserved()` .
|
||||
# After `torch.accelerator.reset_peak_memory_stats()`,
|
||||
# `torch.accelerator.memory_reserved()` will keep growing, and only shrink
|
||||
# when we call `torch.accelerator.empty_cache()` or OOM happens.
|
||||
self.torch_peak = current_platform.memory_stats(device).get(
|
||||
self.torch_peak = torch.accelerator.memory_stats(device).get(
|
||||
"allocated_bytes.all.peak", 0
|
||||
)
|
||||
|
||||
@@ -123,10 +123,10 @@ class MemorySnapshot:
|
||||
|
||||
self.cuda_memory = self.total_memory - self.free_memory
|
||||
|
||||
# torch.cuda.memory_reserved() is how many bytes
|
||||
# torch.accelerator.memory_reserved() is how many bytes
|
||||
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
|
||||
# this is used to measure the non-torch memory usage
|
||||
self.torch_memory = current_platform.memory_reserved(device)
|
||||
self.torch_memory = torch.accelerator.memory_reserved(device)
|
||||
|
||||
self.non_torch_memory = self.cuda_memory - self.torch_memory
|
||||
self.timestamp = time.time()
|
||||
@@ -243,7 +243,7 @@ def memory_profiling(
|
||||
The memory used for loading weights (a.) is directly given from the
|
||||
argument `weights_memory`.
|
||||
|
||||
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]`
|
||||
The increase of `torch.accelerator.memory_stats()["allocated_bytes.all.peak"]`
|
||||
during profiling gives (b.).
|
||||
|
||||
The increase of `non_torch_memory` from creating the current vLLM instance
|
||||
@@ -251,7 +251,7 @@ def memory_profiling(
|
||||
"""
|
||||
gc.collect()
|
||||
torch.accelerator.empty_cache()
|
||||
current_platform.reset_peak_memory_stats(baseline_snapshot.device_)
|
||||
torch.accelerator.reset_peak_memory_stats(baseline_snapshot.device_)
|
||||
|
||||
result = MemoryProfilingResult(
|
||||
before_create=baseline_snapshot,
|
||||
|
||||
@@ -387,7 +387,7 @@ class Worker(WorkerBase):
|
||||
) as profile_result:
|
||||
self.model_runner.profile_run()
|
||||
|
||||
profile_torch_peak = current_platform.memory_stats(self.device).get(
|
||||
profile_torch_peak = torch.accelerator.memory_stats(self.device).get(
|
||||
"allocated_bytes.all.peak", 0
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user