diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f86abd712..8d36d6d52 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -589,14 +589,18 @@ class Platform: def __getattr__(self, key: str): device = getattr(torch, self.device_type, None) if device is not None and hasattr(device, key): - return getattr(device, key) - else: - logger.warning( - "Current platform %s does not have '%s' attribute.", - self.device_type, - key, - ) - return None + attr = getattr(device, key) + # NOTE: `hasattr(device, key)=True` can only avoid AttributeError, + # but the value of this attr could be `None`. + if attr is not None: + return attr + + logger.warning( + "Current platform %s does not have '%s' attribute.", + self.device_type, + key, + ) + return None def get_global_graph_pool(self) -> Any: """ diff --git a/vllm/utils/mem_utils.py b/vllm/utils/mem_utils.py index 12d1541ad..0b3971126 100644 --- a/vllm/utils/mem_utils.py +++ b/vllm/utils/mem_utils.py @@ -11,6 +11,8 @@ import psutil import torch import torch.types +from vllm.platforms import current_platform + from .mem_constants import GiB_bytes, MiB_bytes @@ -45,8 +47,6 @@ class DeviceMemoryProfiler: def current_memory_usage(self) -> float: # Return the memory usage in bytes. - from vllm.platforms import current_platform - gc.collect() return current_platform.get_current_memory_usage(self.device) @@ -80,8 +80,6 @@ class MemorySnapshot: def __post_init__(self) -> None: if self.device is None: - from vllm.platforms import current_platform - device_fn = current_platform.current_device assert device_fn is not None self.device_ = torch.device(device_fn()) @@ -92,8 +90,6 @@ class MemorySnapshot: self.measure() def measure(self) -> None: - from vllm.platforms import current_platform - device = self.device_ # we measure the torch peak memory usage via allocated_bytes, @@ -101,11 +97,11 @@ class MemorySnapshot: # After `torch.cuda.reset_peak_memory_stats()`, # `torch.cuda.memory_reserved()` will keep growing, and only shrink # when we call `torch.cuda.empty_cache()` or OOM happens. - self.torch_peak = torch.cuda.memory_stats(device).get( + self.torch_peak = current_platform.memory_stats(device).get( "allocated_bytes.all.peak", 0 ) - self.free_memory, self.total_memory = torch.cuda.mem_get_info(device) + self.free_memory, self.total_memory = current_platform.mem_get_info(device) shared_sysmem_device_mem_sms = ((8, 7), (11, 0), (12, 1)) # Orin, Thor, Spark if ( current_platform.is_cuda() @@ -130,7 +126,7 @@ class MemorySnapshot: # torch.cuda.memory_reserved() is how many bytes # PyTorch gets from cuda (by calling cudaMalloc, etc.) # this is used to measure the non-torch memory usage - self.torch_memory = torch.cuda.memory_reserved(device) + self.torch_memory = current_platform.memory_reserved(device) self.non_torch_memory = self.cuda_memory - self.torch_memory self.timestamp = time.time() @@ -159,7 +155,7 @@ class MemorySnapshot: f"torch_peak={format_gib(self.torch_peak)}GiB, " f"free_memory={format_gib(self.free_memory)}GiB, " f"total_memory={format_gib(self.total_memory)}GiB, " - f"cuda_memory={format_gib(self.cuda_memory)}GiB, " + f"{current_platform.device_name}_memory={format_gib(self.cuda_memory)}GiB, " f"torch_memory={format_gib(self.torch_memory)}GiB, " f"non_torch_memory={format_gib(self.non_torch_memory)}GiB, " f"timestamp={self.timestamp}, " @@ -254,8 +250,8 @@ def memory_profiling( until after profiling to get (c.). """ gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats(baseline_snapshot.device_) + current_platform.empty_cache() + current_platform.reset_peak_memory_stats(baseline_snapshot.device_) result = MemoryProfilingResult( before_create=baseline_snapshot, @@ -268,7 +264,7 @@ def memory_profiling( yield result gc.collect() - torch.cuda.empty_cache() + current_platform.empty_cache() result.after_profile.measure() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 562664524..013780479 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -312,9 +312,6 @@ class Worker(WorkerBase): logger.info(msg) return kv_cache_memory_bytes - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - # Execute a forward pass with dummy inputs to profile the memory usage # of the model. with memory_profiling( @@ -360,7 +357,6 @@ class Worker(WorkerBase): format_gib(self.available_kv_cache_memory_bytes), scope="local", ) - gc.collect() return int(self.available_kv_cache_memory_bytes)