[Misc] Make mem utils can be reused by other platforms (#32322)
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -589,8 +589,12 @@ class Platform:
|
|||||||
def __getattr__(self, key: str):
|
def __getattr__(self, key: str):
|
||||||
device = getattr(torch, self.device_type, None)
|
device = getattr(torch, self.device_type, None)
|
||||||
if device is not None and hasattr(device, key):
|
if device is not None and hasattr(device, key):
|
||||||
return getattr(device, key)
|
attr = getattr(device, key)
|
||||||
else:
|
# NOTE: `hasattr(device, key)=True` can only avoid AttributeError,
|
||||||
|
# but the value of this attr could be `None`.
|
||||||
|
if attr is not None:
|
||||||
|
return attr
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Current platform %s does not have '%s' attribute.",
|
"Current platform %s does not have '%s' attribute.",
|
||||||
self.device_type,
|
self.device_type,
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import psutil
|
|||||||
import torch
|
import torch
|
||||||
import torch.types
|
import torch.types
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from .mem_constants import GiB_bytes, MiB_bytes
|
from .mem_constants import GiB_bytes, MiB_bytes
|
||||||
|
|
||||||
|
|
||||||
@@ -45,8 +47,6 @@ class DeviceMemoryProfiler:
|
|||||||
|
|
||||||
def current_memory_usage(self) -> float:
|
def current_memory_usage(self) -> float:
|
||||||
# Return the memory usage in bytes.
|
# Return the memory usage in bytes.
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
return current_platform.get_current_memory_usage(self.device)
|
return current_platform.get_current_memory_usage(self.device)
|
||||||
|
|
||||||
@@ -80,8 +80,6 @@ class MemorySnapshot:
|
|||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.device is None:
|
if self.device is None:
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
device_fn = current_platform.current_device
|
device_fn = current_platform.current_device
|
||||||
assert device_fn is not None
|
assert device_fn is not None
|
||||||
self.device_ = torch.device(device_fn())
|
self.device_ = torch.device(device_fn())
|
||||||
@@ -92,8 +90,6 @@ class MemorySnapshot:
|
|||||||
self.measure()
|
self.measure()
|
||||||
|
|
||||||
def measure(self) -> None:
|
def measure(self) -> None:
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
device = self.device_
|
device = self.device_
|
||||||
|
|
||||||
# we measure the torch peak memory usage via allocated_bytes,
|
# we measure the torch peak memory usage via allocated_bytes,
|
||||||
@@ -101,11 +97,11 @@ class MemorySnapshot:
|
|||||||
# After `torch.cuda.reset_peak_memory_stats()`,
|
# After `torch.cuda.reset_peak_memory_stats()`,
|
||||||
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
|
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
|
||||||
# when we call `torch.cuda.empty_cache()` or OOM happens.
|
# when we call `torch.cuda.empty_cache()` or OOM happens.
|
||||||
self.torch_peak = torch.cuda.memory_stats(device).get(
|
self.torch_peak = current_platform.memory_stats(device).get(
|
||||||
"allocated_bytes.all.peak", 0
|
"allocated_bytes.all.peak", 0
|
||||||
)
|
)
|
||||||
|
|
||||||
self.free_memory, self.total_memory = torch.cuda.mem_get_info(device)
|
self.free_memory, self.total_memory = current_platform.mem_get_info(device)
|
||||||
shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark
|
shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark
|
||||||
if (
|
if (
|
||||||
current_platform.is_cuda()
|
current_platform.is_cuda()
|
||||||
@@ -130,7 +126,7 @@ class MemorySnapshot:
|
|||||||
# torch.cuda.memory_reserved() is how many bytes
|
# torch.cuda.memory_reserved() is how many bytes
|
||||||
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
|
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
|
||||||
# this is used to measure the non-torch memory usage
|
# this is used to measure the non-torch memory usage
|
||||||
self.torch_memory = torch.cuda.memory_reserved(device)
|
self.torch_memory = current_platform.memory_reserved(device)
|
||||||
|
|
||||||
self.non_torch_memory = self.cuda_memory - self.torch_memory
|
self.non_torch_memory = self.cuda_memory - self.torch_memory
|
||||||
self.timestamp = time.time()
|
self.timestamp = time.time()
|
||||||
@@ -159,7 +155,7 @@ class MemorySnapshot:
|
|||||||
f"torch_peak={format_gib(self.torch_peak)}GiB, "
|
f"torch_peak={format_gib(self.torch_peak)}GiB, "
|
||||||
f"free_memory={format_gib(self.free_memory)}GiB, "
|
f"free_memory={format_gib(self.free_memory)}GiB, "
|
||||||
f"total_memory={format_gib(self.total_memory)}GiB, "
|
f"total_memory={format_gib(self.total_memory)}GiB, "
|
||||||
f"cuda_memory={format_gib(self.cuda_memory)}GiB, "
|
f"{current_platform.device_name}_memory={format_gib(self.cuda_memory)}GiB, "
|
||||||
f"torch_memory={format_gib(self.torch_memory)}GiB, "
|
f"torch_memory={format_gib(self.torch_memory)}GiB, "
|
||||||
f"non_torch_memory={format_gib(self.non_torch_memory)}GiB, "
|
f"non_torch_memory={format_gib(self.non_torch_memory)}GiB, "
|
||||||
f"timestamp={self.timestamp}, "
|
f"timestamp={self.timestamp}, "
|
||||||
@@ -254,8 +250,8 @@ def memory_profiling(
|
|||||||
until after profiling to get (c.).
|
until after profiling to get (c.).
|
||||||
"""
|
"""
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
current_platform.empty_cache()
|
||||||
torch.cuda.reset_peak_memory_stats(baseline_snapshot.device_)
|
current_platform.reset_peak_memory_stats(baseline_snapshot.device_)
|
||||||
|
|
||||||
result = MemoryProfilingResult(
|
result = MemoryProfilingResult(
|
||||||
before_create=baseline_snapshot,
|
before_create=baseline_snapshot,
|
||||||
@@ -268,7 +264,7 @@ def memory_profiling(
|
|||||||
yield result
|
yield result
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
current_platform.empty_cache()
|
||||||
|
|
||||||
result.after_profile.measure()
|
result.after_profile.measure()
|
||||||
|
|
||||||
|
|||||||
@@ -312,9 +312,6 @@ class Worker(WorkerBase):
|
|||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
return kv_cache_memory_bytes
|
return kv_cache_memory_bytes
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
|
|
||||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||||
# of the model.
|
# of the model.
|
||||||
with memory_profiling(
|
with memory_profiling(
|
||||||
@@ -360,7 +357,6 @@ class Worker(WorkerBase):
|
|||||||
format_gib(self.available_kv_cache_memory_bytes),
|
format_gib(self.available_kv_cache_memory_bytes),
|
||||||
scope="local",
|
scope="local",
|
||||||
)
|
)
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
return int(self.available_kv_cache_memory_bytes)
|
return int(self.available_kv_cache_memory_bytes)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user