[Chore] Cleanup mem_utils.py (#31793)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-06 19:55:59 +08:00
committed by GitHub
parent 6ebb66ccea
commit 14df02b4e1

View File

@@ -22,7 +22,7 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
# value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
# will fail
assert max_shared_mem > 0, "max_shared_mem can not be zero"
assert max_shared_mem > 0, "max_shared_mem cannot be zero"
return int(max_shared_mem)
@@ -154,12 +154,16 @@ class MemoryProfilingResult:
non_kv_cache_memory: int = 0
torch_peak_increase: int = 0
non_torch_increase: int = 0
weights_memory: float = 0
weights_memory: int = 0
before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
profile_time: float = 0.0
def __post_init__(self) -> None:
device = self.before_create.device_
self.before_profile = MemorySnapshot(device=device, auto_measure=False)
self.after_profile = MemorySnapshot(device=device, auto_measure=False)
def __repr__(self) -> str:
return (
f"Memory profiling takes {self.profile_time:.2f} seconds. "
@@ -175,9 +179,12 @@ class MemoryProfilingResult:
@contextlib.contextmanager
def memory_profiling(
baseline_snapshot: MemorySnapshot, weights_memory: int
baseline_snapshot: MemorySnapshot,
weights_memory: int = 0,
) -> Generator[MemoryProfilingResult, None, None]:
"""Memory profiling context manager.
"""
Memory profiling context manager.
baseline_snapshot: the memory snapshot before the current vLLM instance.
weights_memory: memory used by PyTorch when loading the model weights.
Note that, before loading the model weights, we also initialize the device
@@ -217,21 +224,24 @@ def memory_profiling(
b. 2 GiB reserved for the peak activation tensors (category 2)
c. 1 GiB used by non-torch components (category 3)
The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
The memory used for loading weights (a.) is directly given from the
argument `weights_memory`.
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]`
during profiling gives (b.).
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
""" # noqa
The increase of `non_torch_memory` from creating the current vLLM instance
until after profiling to get (c.).
"""
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.reset_peak_memory_stats(baseline_snapshot.device_)
result = MemoryProfilingResult()
result.before_create = baseline_snapshot
# the part of memory used for holding the model weights
result.weights_memory = weights_memory
result = MemoryProfilingResult(
before_create=baseline_snapshot,
# the part of memory used for holding the model weights
weights_memory=weights_memory,
)
result.before_profile.measure()
@@ -252,4 +262,4 @@ def memory_profiling(
peak_activation_memory = result.torch_peak_increase
result.non_kv_cache_memory = (
non_torch_memory + peak_activation_memory + result.weights_memory
) # noqa
)