[Misc] Make mem utils can be reused by other platforms (#32322)

Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
Shanshan Shen
2026-01-14 19:46:01 +08:00
committed by GitHub
parent 3f28174c6a
commit ce0946249d
3 changed files with 21 additions and 25 deletions

View File

@@ -589,8 +589,12 @@ class Platform:
def __getattr__(self, key: str): def __getattr__(self, key: str):
device = getattr(torch, self.device_type, None) device = getattr(torch, self.device_type, None)
if device is not None and hasattr(device, key): if device is not None and hasattr(device, key):
return getattr(device, key) attr = getattr(device, key)
else: # NOTE: `hasattr(device, key)=True` can only avoid AttributeError,
# but the value of this attr could be `None`.
if attr is not None:
return attr
logger.warning( logger.warning(
"Current platform %s does not have '%s' attribute.", "Current platform %s does not have '%s' attribute.",
self.device_type, self.device_type,

View File

@@ -11,6 +11,8 @@ import psutil
import torch import torch
import torch.types import torch.types
from vllm.platforms import current_platform
from .mem_constants import GiB_bytes, MiB_bytes from .mem_constants import GiB_bytes, MiB_bytes
@@ -45,8 +47,6 @@ class DeviceMemoryProfiler:
def current_memory_usage(self) -> float: def current_memory_usage(self) -> float:
# Return the memory usage in bytes. # Return the memory usage in bytes.
from vllm.platforms import current_platform
gc.collect() gc.collect()
return current_platform.get_current_memory_usage(self.device) return current_platform.get_current_memory_usage(self.device)
@@ -80,8 +80,6 @@ class MemorySnapshot:
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.device is None: if self.device is None:
from vllm.platforms import current_platform
device_fn = current_platform.current_device device_fn = current_platform.current_device
assert device_fn is not None assert device_fn is not None
self.device_ = torch.device(device_fn()) self.device_ = torch.device(device_fn())
@@ -92,8 +90,6 @@ class MemorySnapshot:
self.measure() self.measure()
def measure(self) -> None: def measure(self) -> None:
from vllm.platforms import current_platform
device = self.device_ device = self.device_
# we measure the torch peak memory usage via allocated_bytes, # we measure the torch peak memory usage via allocated_bytes,
@@ -101,11 +97,11 @@ class MemorySnapshot:
# After `torch.cuda.reset_peak_memory_stats()`, # After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink # `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens. # when we call `torch.cuda.empty_cache()` or OOM happens.
self.torch_peak = torch.cuda.memory_stats(device).get( self.torch_peak = current_platform.memory_stats(device).get(
"allocated_bytes.all.peak", 0 "allocated_bytes.all.peak", 0
) )
self.free_memory, self.total_memory = torch.cuda.mem_get_info(device) self.free_memory, self.total_memory = current_platform.mem_get_info(device)
shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark
if ( if (
current_platform.is_cuda() current_platform.is_cuda()
@@ -130,7 +126,7 @@ class MemorySnapshot:
# torch.cuda.memory_reserved() is how many bytes # torch.cuda.memory_reserved() is how many bytes
# PyTorch gets from cuda (by calling cudaMalloc, etc.) # PyTorch gets from cuda (by calling cudaMalloc, etc.)
# this is used to measure the non-torch memory usage # this is used to measure the non-torch memory usage
self.torch_memory = torch.cuda.memory_reserved(device) self.torch_memory = current_platform.memory_reserved(device)
self.non_torch_memory = self.cuda_memory - self.torch_memory self.non_torch_memory = self.cuda_memory - self.torch_memory
self.timestamp = time.time() self.timestamp = time.time()
@@ -159,7 +155,7 @@ class MemorySnapshot:
f"torch_peak={format_gib(self.torch_peak)}GiB, " f"torch_peak={format_gib(self.torch_peak)}GiB, "
f"free_memory={format_gib(self.free_memory)}GiB, " f"free_memory={format_gib(self.free_memory)}GiB, "
f"total_memory={format_gib(self.total_memory)}GiB, " f"total_memory={format_gib(self.total_memory)}GiB, "
f"cuda_memory={format_gib(self.cuda_memory)}GiB, " f"{current_platform.device_name}_memory={format_gib(self.cuda_memory)}GiB, "
f"torch_memory={format_gib(self.torch_memory)}GiB, " f"torch_memory={format_gib(self.torch_memory)}GiB, "
f"non_torch_memory={format_gib(self.non_torch_memory)}GiB, " f"non_torch_memory={format_gib(self.non_torch_memory)}GiB, "
f"timestamp={self.timestamp}, " f"timestamp={self.timestamp}, "
@@ -254,8 +250,8 @@ def memory_profiling(
until after profiling to get (c.). until after profiling to get (c.).
""" """
gc.collect() gc.collect()
torch.cuda.empty_cache() current_platform.empty_cache()
torch.cuda.reset_peak_memory_stats(baseline_snapshot.device_) current_platform.reset_peak_memory_stats(baseline_snapshot.device_)
result = MemoryProfilingResult( result = MemoryProfilingResult(
before_create=baseline_snapshot, before_create=baseline_snapshot,
@@ -268,7 +264,7 @@ def memory_profiling(
yield result yield result
gc.collect() gc.collect()
torch.cuda.empty_cache() current_platform.empty_cache()
result.after_profile.measure() result.after_profile.measure()

View File

@@ -312,9 +312,6 @@ class Worker(WorkerBase):
logger.info(msg) logger.info(msg)
return kv_cache_memory_bytes return kv_cache_memory_bytes
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Execute a forward pass with dummy inputs to profile the memory usage # Execute a forward pass with dummy inputs to profile the memory usage
# of the model. # of the model.
with memory_profiling( with memory_profiling(
@@ -360,7 +357,6 @@ class Worker(WorkerBase):
format_gib(self.available_kv_cache_memory_bytes), format_gib(self.available_kv_cache_memory_bytes),
scope="local", scope="local",
) )
gc.collect()
return int(self.available_kv_cache_memory_bytes) return int(self.available_kv_cache_memory_bytes)