[Hardware] Replace memory related torch.cuda APIs (#37031)
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
This commit is contained in:
@@ -418,8 +418,8 @@ def _run_single_benchmark(
|
||||
mem_stats = {}
|
||||
if config.profile_memory:
|
||||
mem_stats = {
|
||||
"allocated_mb": torch.cuda.memory_allocated(device) / 1024**2,
|
||||
"reserved_mb": torch.cuda.memory_reserved(device) / 1024**2,
|
||||
"allocated_mb": torch.accelerator.memory_allocated(device) / 1024**2,
|
||||
"reserved_mb": torch.accelerator.memory_reserved(device) / 1024**2,
|
||||
}
|
||||
|
||||
return times, mem_stats
|
||||
|
||||
@@ -95,13 +95,16 @@ def create_logits(
|
||||
def measure_memory() -> tuple[int, int]:
|
||||
"""Return (allocated, reserved) memory in bytes."""
|
||||
torch.accelerator.synchronize()
|
||||
return torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated()
|
||||
return (
|
||||
torch.accelerator.memory_allocated(),
|
||||
torch.accelerator.max_memory_allocated(),
|
||||
)
|
||||
|
||||
|
||||
def reset_memory_stats():
|
||||
"""Reset peak memory statistics."""
|
||||
reset_buffer_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.accelerator.reset_peak_memory_stats()
|
||||
torch.accelerator.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ def test_gc():
|
||||
# The memory allocated for model and KV cache should be released.
|
||||
# The memory allocated for PyTorch and others should be less than 50MB.
|
||||
# Usually, it's around 10MB.
|
||||
allocated = torch.cuda.memory_allocated()
|
||||
allocated = torch.accelerator.memory_allocated()
|
||||
assert allocated < 50 * 1024 * 1024
|
||||
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ def test_memory_profiling():
|
||||
def measure_current_non_torch():
|
||||
free, total = torch.cuda.mem_get_info()
|
||||
current_used = total - free
|
||||
current_torch = torch.cuda.memory_reserved()
|
||||
current_torch = torch.accelerator.memory_reserved()
|
||||
current_non_torch = current_used - current_torch
|
||||
return current_non_torch
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import regex as re
|
||||
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
|
||||
# --------------------------------------------------------------------------- #
|
||||
_TORCH_CUDA_PATTERNS = [
|
||||
r"\btorch\.cuda\.(empty_cache|synchronize|device_count|current_device|set_device|device\()\b",
|
||||
r"\btorch\.cuda\.(empty_cache|synchronize|device_count|current_device|memory_reserved|memory_allocated|max_memory_allocated|max_memory_reserved|reset_peak_memory_stats|memory_stats|set_device|device\()\b",
|
||||
r"\bwith\storch\.cuda\.device\b",
|
||||
]
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ class BaseModelLoader(ABC):
|
||||
# Log peak GPU memory after loading weights. This is needed
|
||||
# to have test coverage on peak memory for online quantization.
|
||||
if current_platform.is_cuda():
|
||||
peak_memory = torch.cuda.max_memory_allocated()
|
||||
peak_memory = torch.accelerator.max_memory_allocated()
|
||||
logger.debug_once(
|
||||
"Peak GPU memory after loading weights: %s GiB",
|
||||
format_gib(peak_memory),
|
||||
|
||||
@@ -93,11 +93,11 @@ class MemorySnapshot:
|
||||
device = self.device_
|
||||
|
||||
# we measure the torch peak memory usage via allocated_bytes,
|
||||
# rather than `torch.cuda.memory_reserved()` .
|
||||
# After `torch.cuda.reset_peak_memory_stats()`,
|
||||
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
|
||||
# rather than `torch.accelerator.memory_reserved()` .
|
||||
# After `torch.accelerator.reset_peak_memory_stats()`,
|
||||
# `torch.accelerator.memory_reserved()` will keep growing, and only shrink
|
||||
# when we call `torch.accelerator.empty_cache()` or OOM happens.
|
||||
self.torch_peak = current_platform.memory_stats(device).get(
|
||||
self.torch_peak = torch.accelerator.memory_stats(device).get(
|
||||
"allocated_bytes.all.peak", 0
|
||||
)
|
||||
|
||||
@@ -123,10 +123,10 @@ class MemorySnapshot:
|
||||
|
||||
self.cuda_memory = self.total_memory - self.free_memory
|
||||
|
||||
# torch.cuda.memory_reserved() is how many bytes
|
||||
# torch.accelerator.memory_reserved() is how many bytes
|
||||
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
|
||||
# this is used to measure the non-torch memory usage
|
||||
self.torch_memory = current_platform.memory_reserved(device)
|
||||
self.torch_memory = torch.accelerator.memory_reserved(device)
|
||||
|
||||
self.non_torch_memory = self.cuda_memory - self.torch_memory
|
||||
self.timestamp = time.time()
|
||||
@@ -243,7 +243,7 @@ def memory_profiling(
|
||||
The memory used for loading weights (a.) is directly given from the
|
||||
argument `weights_memory`.
|
||||
|
||||
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]`
|
||||
The increase of `torch.accelerator.memory_stats()["allocated_bytes.all.peak"]`
|
||||
during profiling gives (b.).
|
||||
|
||||
The increase of `non_torch_memory` from creating the current vLLM instance
|
||||
@@ -251,7 +251,7 @@ def memory_profiling(
|
||||
"""
|
||||
gc.collect()
|
||||
torch.accelerator.empty_cache()
|
||||
current_platform.reset_peak_memory_stats(baseline_snapshot.device_)
|
||||
torch.accelerator.reset_peak_memory_stats(baseline_snapshot.device_)
|
||||
|
||||
result = MemoryProfilingResult(
|
||||
before_create=baseline_snapshot,
|
||||
|
||||
@@ -387,7 +387,7 @@ class Worker(WorkerBase):
|
||||
) as profile_result:
|
||||
self.model_runner.profile_run()
|
||||
|
||||
profile_torch_peak = current_platform.memory_stats(self.device).get(
|
||||
profile_torch_peak = torch.accelerator.memory_stats(self.device).get(
|
||||
"allocated_bytes.all.peak", 0
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user