diff --git a/vllm/utils/mem_utils.py b/vllm/utils/mem_utils.py index bf6d78465..dd91400f2 100644 --- a/vllm/utils/mem_utils.py +++ b/vllm/utils/mem_utils.py @@ -22,7 +22,7 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int: max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu) # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py # will fail - assert max_shared_mem > 0, "max_shared_mem can not be zero" + assert max_shared_mem > 0, "max_shared_mem cannot be zero" return int(max_shared_mem) @@ -154,12 +154,16 @@ class MemoryProfilingResult: non_kv_cache_memory: int = 0 torch_peak_increase: int = 0 non_torch_increase: int = 0 - weights_memory: float = 0 + weights_memory: int = 0 before_create: MemorySnapshot = field(default_factory=MemorySnapshot) - before_profile: MemorySnapshot = field(default_factory=MemorySnapshot) - after_profile: MemorySnapshot = field(default_factory=MemorySnapshot) profile_time: float = 0.0 + def __post_init__(self) -> None: + device = self.before_create.device_ + + self.before_profile = MemorySnapshot(device=device, auto_measure=False) + self.after_profile = MemorySnapshot(device=device, auto_measure=False) + def __repr__(self) -> str: return ( f"Memory profiling takes {self.profile_time:.2f} seconds. " @@ -175,9 +179,12 @@ class MemoryProfilingResult: @contextlib.contextmanager def memory_profiling( - baseline_snapshot: MemorySnapshot, weights_memory: int + baseline_snapshot: MemorySnapshot, + weights_memory: int = 0, ) -> Generator[MemoryProfilingResult, None, None]: - """Memory profiling context manager. + """ + Memory profiling context manager. + baseline_snapshot: the memory snapshot before the current vLLM instance. weights_memory: memory used by PyTorch when loading the model weights. Note that, before loading the model weights, we also initialize the device @@ -217,21 +224,24 @@ def memory_profiling( b. 2 GiB reserved for the peak activation tensors (category 2) c. 1 GiB used by non-torch components (category 3) - The memory used for loading weights (a.) is directly given from the argument `weights_memory`. + The memory used for loading weights (a.) is directly given from the + argument `weights_memory`. - The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.). + The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` + during profiling gives (b.). - The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). - """ # noqa + The increase of `non_torch_memory` from creating the current vLLM instance + until after profiling to get (c.). + """ gc.collect() torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + torch.cuda.reset_peak_memory_stats(baseline_snapshot.device_) - result = MemoryProfilingResult() - - result.before_create = baseline_snapshot - # the part of memory used for holding the model weights - result.weights_memory = weights_memory + result = MemoryProfilingResult( + before_create=baseline_snapshot, + # the part of memory used for holding the model weights + weights_memory=weights_memory, + ) result.before_profile.measure() @@ -252,4 +262,4 @@ def memory_profiling( peak_activation_memory = result.torch_peak_increase result.non_kv_cache_memory = ( non_torch_memory + peak_activation_memory + result.weights_memory - ) # noqa + )