[Chore] Cleanup mem_utils.py (#31793)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-06 19:55:59 +08:00
Committed by: GitHub
parent 6ebb66ccea
commit 14df02b4e1

View File

@@ -22,7 +22,7 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
     # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
     # will fail
-    assert max_shared_mem > 0, "max_shared_mem can not be zero"
+    assert max_shared_mem > 0, "max_shared_mem cannot be zero"
     return int(max_shared_mem)
@@ -154,12 +154,16 @@ class MemoryProfilingResult:
     non_kv_cache_memory: int = 0
     torch_peak_increase: int = 0
     non_torch_increase: int = 0
-    weights_memory: float = 0
+    weights_memory: int = 0
     before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
-    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
-    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
     profile_time: float = 0.0

+    def __post_init__(self) -> None:
+        device = self.before_create.device_
+        self.before_profile = MemorySnapshot(device=device, auto_measure=False)
+        self.after_profile = MemorySnapshot(device=device, auto_measure=False)
+
     def __repr__(self) -> str:
         return (
             f"Memory profiling takes {self.profile_time:.2f} seconds. "
@@ -175,9 +179,12 @@ class MemoryProfilingResult:
 @contextlib.contextmanager
 def memory_profiling(
-    baseline_snapshot: MemorySnapshot, weights_memory: int
+    baseline_snapshot: MemorySnapshot,
+    weights_memory: int = 0,
 ) -> Generator[MemoryProfilingResult, None, None]:
-    """Memory profiling context manager.
+    """
+    Memory profiling context manager.

     baseline_snapshot: the memory snapshot before the current vLLM instance.
     weights_memory: memory used by PyTorch when loading the model weights.
         Note that, before loading the model weights, we also initialize the device
@@ -217,21 +224,24 @@ def memory_profiling(
     b. 2 GiB reserved for the peak activation tensors (category 2)
     c. 1 GiB used by non-torch components (category 3)

-    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
+    The memory used for loading weights (a.) is directly given from the
+    argument `weights_memory`.

-    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]`
+    during profiling gives (b.).

-    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
-    """  # noqa
+    The increase of `non_torch_memory` from creating the current vLLM instance
+    until after profiling to get (c.).
+    """
     gc.collect()
     torch.cuda.empty_cache()
-    torch.cuda.reset_peak_memory_stats()
+    torch.cuda.reset_peak_memory_stats(baseline_snapshot.device_)

-    result = MemoryProfilingResult()
-    result.before_create = baseline_snapshot
-    # the part of memory used for holding the model weights
-    result.weights_memory = weights_memory
+    result = MemoryProfilingResult(
+        before_create=baseline_snapshot,
+        # the part of memory used for holding the model weights
+        weights_memory=weights_memory,
+    )

     result.before_profile.measure()
@@ -252,4 +262,4 @@ def memory_profiling(
     peak_activation_memory = result.torch_peak_increase
     result.non_kv_cache_memory = (
         non_torch_memory + peak_activation_memory + result.weights_memory
-    )  # noqa
+    )