[Misc] Make mem utils can be reused by other platforms (#32322)
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -589,8 +589,12 @@ class Platform:
|
||||
def __getattr__(self, key: str):
|
||||
device = getattr(torch, self.device_type, None)
|
||||
if device is not None and hasattr(device, key):
|
||||
return getattr(device, key)
|
||||
else:
|
||||
attr = getattr(device, key)
|
||||
# NOTE: `hasattr(device, key)=True` can only avoid AttributeError,
|
||||
# but the value of this attr could be `None`.
|
||||
if attr is not None:
|
||||
return attr
|
||||
|
||||
logger.warning(
|
||||
"Current platform %s does not have '%s' attribute.",
|
||||
self.device_type,
|
||||
|
||||
@@ -11,6 +11,8 @@ import psutil
|
||||
import torch
|
||||
import torch.types
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .mem_constants import GiB_bytes, MiB_bytes
|
||||
|
||||
|
||||
@@ -45,8 +47,6 @@ class DeviceMemoryProfiler:
|
||||
|
||||
def current_memory_usage(self) -> float:
|
||||
# Return the memory usage in bytes.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
gc.collect()
|
||||
return current_platform.get_current_memory_usage(self.device)
|
||||
|
||||
@@ -80,8 +80,6 @@ class MemorySnapshot:
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.device is None:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
device_fn = current_platform.current_device
|
||||
assert device_fn is not None
|
||||
self.device_ = torch.device(device_fn())
|
||||
@@ -92,8 +90,6 @@ class MemorySnapshot:
|
||||
self.measure()
|
||||
|
||||
def measure(self) -> None:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
device = self.device_
|
||||
|
||||
# we measure the torch peak memory usage via allocated_bytes,
|
||||
@@ -101,11 +97,11 @@ class MemorySnapshot:
|
||||
# After `torch.cuda.reset_peak_memory_stats()`,
|
||||
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
|
||||
# when we call `torch.cuda.empty_cache()` or OOM happens.
|
||||
self.torch_peak = torch.cuda.memory_stats(device).get(
|
||||
self.torch_peak = current_platform.memory_stats(device).get(
|
||||
"allocated_bytes.all.peak", 0
|
||||
)
|
||||
|
||||
self.free_memory, self.total_memory = torch.cuda.mem_get_info(device)
|
||||
self.free_memory, self.total_memory = current_platform.mem_get_info(device)
|
||||
shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
@@ -130,7 +126,7 @@ class MemorySnapshot:
|
||||
# torch.cuda.memory_reserved() is how many bytes
|
||||
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
|
||||
# this is used to measure the non-torch memory usage
|
||||
self.torch_memory = torch.cuda.memory_reserved(device)
|
||||
self.torch_memory = current_platform.memory_reserved(device)
|
||||
|
||||
self.non_torch_memory = self.cuda_memory - self.torch_memory
|
||||
self.timestamp = time.time()
|
||||
@@ -159,7 +155,7 @@ class MemorySnapshot:
|
||||
f"torch_peak={format_gib(self.torch_peak)}GiB, "
|
||||
f"free_memory={format_gib(self.free_memory)}GiB, "
|
||||
f"total_memory={format_gib(self.total_memory)}GiB, "
|
||||
f"cuda_memory={format_gib(self.cuda_memory)}GiB, "
|
||||
f"{current_platform.device_name}_memory={format_gib(self.cuda_memory)}GiB, "
|
||||
f"torch_memory={format_gib(self.torch_memory)}GiB, "
|
||||
f"non_torch_memory={format_gib(self.non_torch_memory)}GiB, "
|
||||
f"timestamp={self.timestamp}, "
|
||||
@@ -254,8 +250,8 @@ def memory_profiling(
|
||||
until after profiling to get (c.).
|
||||
"""
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats(baseline_snapshot.device_)
|
||||
current_platform.empty_cache()
|
||||
current_platform.reset_peak_memory_stats(baseline_snapshot.device_)
|
||||
|
||||
result = MemoryProfilingResult(
|
||||
before_create=baseline_snapshot,
|
||||
@@ -268,7 +264,7 @@ def memory_profiling(
|
||||
yield result
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
current_platform.empty_cache()
|
||||
|
||||
result.after_profile.measure()
|
||||
|
||||
|
||||
@@ -312,9 +312,6 @@ class Worker(WorkerBase):
|
||||
logger.info(msg)
|
||||
return kv_cache_memory_bytes
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
with memory_profiling(
|
||||
@@ -360,7 +357,6 @@ class Worker(WorkerBase):
|
||||
format_gib(self.available_kv_cache_memory_bytes),
|
||||
scope="local",
|
||||
)
|
||||
gc.collect()
|
||||
|
||||
return int(self.available_kv_cache_memory_bytes)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user