[Chore] Cleanup mem_utils.py (#31793)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -22,7 +22,7 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
|
||||
max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
|
||||
# value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
|
||||
# will fail
|
||||
assert max_shared_mem > 0, "max_shared_mem can not be zero"
|
||||
assert max_shared_mem > 0, "max_shared_mem cannot be zero"
|
||||
return int(max_shared_mem)
|
||||
|
||||
|
||||
@@ -154,12 +154,16 @@ class MemoryProfilingResult:
|
||||
non_kv_cache_memory: int = 0
|
||||
torch_peak_increase: int = 0
|
||||
non_torch_increase: int = 0
|
||||
weights_memory: float = 0
|
||||
weights_memory: int = 0
|
||||
before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
|
||||
before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
|
||||
after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
|
||||
profile_time: float = 0.0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
device = self.before_create.device_
|
||||
|
||||
self.before_profile = MemorySnapshot(device=device, auto_measure=False)
|
||||
self.after_profile = MemorySnapshot(device=device, auto_measure=False)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Memory profiling takes {self.profile_time:.2f} seconds. "
|
||||
@@ -175,9 +179,12 @@ class MemoryProfilingResult:
|
||||
|
||||
@contextlib.contextmanager
|
||||
def memory_profiling(
|
||||
baseline_snapshot: MemorySnapshot, weights_memory: int
|
||||
baseline_snapshot: MemorySnapshot,
|
||||
weights_memory: int = 0,
|
||||
) -> Generator[MemoryProfilingResult, None, None]:
|
||||
"""Memory profiling context manager.
|
||||
"""
|
||||
Memory profiling context manager.
|
||||
|
||||
baseline_snapshot: the memory snapshot before the current vLLM instance.
|
||||
weights_memory: memory used by PyTorch when loading the model weights.
|
||||
Note that, before loading the model weights, we also initialize the device
|
||||
@@ -217,21 +224,24 @@ def memory_profiling(
|
||||
b. 2 GiB reserved for the peak activation tensors (category 2)
|
||||
c. 1 GiB used by non-torch components (category 3)
|
||||
|
||||
The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
|
||||
The memory used for loading weights (a.) is directly given from the
|
||||
argument `weights_memory`.
|
||||
|
||||
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
|
||||
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]`
|
||||
during profiling gives (b.).
|
||||
|
||||
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
|
||||
""" # noqa
|
||||
The increase of `non_torch_memory` from creating the current vLLM instance
|
||||
until after profiling to get (c.).
|
||||
"""
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_peak_memory_stats(baseline_snapshot.device_)
|
||||
|
||||
result = MemoryProfilingResult()
|
||||
|
||||
result.before_create = baseline_snapshot
|
||||
# the part of memory used for holding the model weights
|
||||
result.weights_memory = weights_memory
|
||||
result = MemoryProfilingResult(
|
||||
before_create=baseline_snapshot,
|
||||
# the part of memory used for holding the model weights
|
||||
weights_memory=weights_memory,
|
||||
)
|
||||
|
||||
result.before_profile.measure()
|
||||
|
||||
@@ -252,4 +262,4 @@ def memory_profiling(
|
||||
peak_activation_memory = result.torch_peak_increase
|
||||
result.non_kv_cache_memory = (
|
||||
non_torch_memory + peak_activation_memory + result.weights_memory
|
||||
) # noqa
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user