[Chore] Cleanup mem_utils.py (#31793)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -22,7 +22,7 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     max_shared_mem = ops.get_max_shared_memory_per_block_device_attribute(gpu)
     # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
     # will fail
-    assert max_shared_mem > 0, "max_shared_mem can not be zero"
+    assert max_shared_mem > 0, "max_shared_mem cannot be zero"
     return int(max_shared_mem)
 
 
@@ -154,12 +154,16 @@ class MemoryProfilingResult:
     non_kv_cache_memory: int = 0
     torch_peak_increase: int = 0
     non_torch_increase: int = 0
-    weights_memory: float = 0
+    weights_memory: int = 0
     before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
-    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
-    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
     profile_time: float = 0.0
 
+    def __post_init__(self) -> None:
+        device = self.before_create.device_
+
+        self.before_profile = MemorySnapshot(device=device, auto_measure=False)
+        self.after_profile = MemorySnapshot(device=device, auto_measure=False)
+
     def __repr__(self) -> str:
         return (
             f"Memory profiling takes {self.profile_time:.2f} seconds. "
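Reviewer note: the hunk above replaces two independent `field(default_factory=...)` defaults with snapshots built in `__post_init__`, so both follow-up snapshots inherit the device of `before_create` and skip eager measurement. A minimal sketch of that dataclass pattern; the names here are illustrative stand-ins, not vLLM's real `MemorySnapshot`:

```python
from dataclasses import dataclass, field


# Toy stand-in for the snapshot type; only the two attributes this hunk
# relies on are modeled, and both names are assumptions for illustration.
@dataclass
class Snapshot:
    device: int = 0            # device index the snapshot belongs to
    auto_measure: bool = True  # whether to measure eagerly on construction


@dataclass
class ProfilingResult:
    before_create: Snapshot = field(default_factory=Snapshot)

    def __post_init__(self) -> None:
        # A bare default_factory cannot see before_create, so it could not
        # reuse its device; constructing the snapshots here can.
        device = self.before_create.device
        self.before_profile = Snapshot(device=device, auto_measure=False)
        self.after_profile = Snapshot(device=device, auto_measure=False)


result = ProfilingResult(before_create=Snapshot(device=1))
assert result.before_profile.device == 1
assert not result.after_profile.auto_measure
```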
@@ -175,9 +179,12 @@ class MemoryProfilingResult:
 
 @contextlib.contextmanager
 def memory_profiling(
-    baseline_snapshot: MemorySnapshot, weights_memory: int
+    baseline_snapshot: MemorySnapshot,
+    weights_memory: int = 0,
 ) -> Generator[MemoryProfilingResult, None, None]:
-    """Memory profiling context manager.
+    """
+    Memory profiling context manager.
+
     baseline_snapshot: the memory snapshot before the current vLLM instance.
     weights_memory: memory used by PyTorch when loading the model weights.
     Note that, before loading the model weights, we also initialize the device
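With `weights_memory` now defaulting to 0, call sites that profile before any weights are loaded can drop the argument. A hedged usage sketch based only on what this diff shows; the import comment, `baseline`, and `run_workload` are placeholders:

```python
# Assumed import from the module this diff edits (exact path not shown here):
# from ... import MemorySnapshot, memory_profiling


def run_workload() -> None:
    ...  # placeholder for the forward pass being profiled


baseline = MemorySnapshot()  # snapshot taken before the vLLM instance exists

# Old-style call: weights_memory passed explicitly (2 GiB here, illustrative).
with memory_profiling(baseline, weights_memory=2 << 30) as result:
    run_workload()

# New in this diff: weights_memory defaults to 0, so profiling done before
# any weights exist no longer needs an explicit argument.
with memory_profiling(baseline) as result:
    run_workload()
```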
@@ -217,21 +224,24 @@ def memory_profiling(
     b. 2 GiB reserved for the peak activation tensors (category 2)
     c. 1 GiB used by non-torch components (category 3)
 
-    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
+    The memory used for loading weights (a.) is directly given from the
+    argument `weights_memory`.
 
-    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]`
+    during profiling gives (b.).
 
-    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
-    """ # noqa
+    The increase of `non_torch_memory` from creating the current vLLM instance
+    until after profiling to get (c.).
+    """
     gc.collect()
     torch.cuda.empty_cache()
-    torch.cuda.reset_peak_memory_stats()
+    torch.cuda.reset_peak_memory_stats(baseline_snapshot.device_)
 
-    result = MemoryProfilingResult()
-    result.before_create = baseline_snapshot
-    # the part of memory used for holding the model weights
-    result.weights_memory = weights_memory
+    result = MemoryProfilingResult(
+        before_create=baseline_snapshot,
+        # the part of memory used for holding the model weights
+        weights_memory=weights_memory,
+    )
 
     result.before_profile.measure()
 
@@ -252,4 +262,4 @@ def memory_profiling(
     peak_activation_memory = result.torch_peak_increase
     result.non_kv_cache_memory = (
         non_torch_memory + peak_activation_memory + result.weights_memory
-    ) # noqa
+    )
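The final hunk only drops a stale `# noqa`, but the line it touches is the sum the docstring walks through. Plugging in the docstring's own numbers (1 GiB weights, 2 GiB peak activations, 1 GiB non-torch), as illustrative arithmetic only:

```python
GiB = 1 << 30

weights_memory = 1 * GiB       # (a.) given by the weights_memory argument
torch_peak_increase = 2 * GiB  # (b.) growth of allocated_bytes.all.peak
non_torch_memory = 1 * GiB     # (c.) growth of non-torch memory

# Mirrors the assignment in the hunk above: everything the profiled
# instance uses that is not available to the KV cache.
non_kv_cache_memory = non_torch_memory + torch_peak_increase + weights_memory
assert non_kv_cache_memory == 4 * GiB
```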