[core] overhaul memory profiling and fix backward compatibility (#10511)
Signed-off-by: youkaichao <youkaichao@gmail.com>
vllm/utils.py
@@ -23,10 +23,12 @@ import weakref
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
 from collections import UserDict, defaultdict
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
 from functools import lru_cache, partial, wraps
 from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
-                    Dict, Generic, Hashable, List, Literal, Optional,
-                    OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
+                    Dict, Generator, Generic, Hashable, List, Literal,
+                    Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
+                    overload)
 from uuid import uuid4

 import numpy as np
@@ -1664,3 +1666,122 @@ def kill_process_tree(pid: int):
     # Finally kill the parent
     with contextlib.suppress(ProcessLookupError):
         os.kill(pid, signal.SIGKILL)
+
+
+@dataclass
+class MemorySnapshot:
+    """Memory snapshot."""
+    torch_peak_in_bytes: int = 0
+    torch_memory_in_bytes: int = 0
+    timestamp: float = 0.0
+
+    def measure(self):
+        self.torch_peak_in_bytes = torch.cuda.memory_stats(
+        )["allocated_bytes.all.peak"]
+        self.torch_memory_in_bytes = torch.cuda.memory_stats(
+        )["allocated_bytes.all.current"]
+        self.timestamp = time.time()
+
+    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
+        """support a - b"""
+        return MemorySnapshot(
+            torch_peak_in_bytes=self.torch_peak_in_bytes -
+            other.torch_peak_in_bytes,
+            torch_memory_in_bytes=self.torch_memory_in_bytes -
+            other.torch_memory_in_bytes,
+            timestamp=self.timestamp - other.timestamp)
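For illustration (not part of the commit): the `__sub__` overload makes a before/after delta a single expression. A minimal sketch with synthetic byte counts; on a real GPU each snapshot would instead be filled in by `.measure()`:

from vllm.utils import MemorySnapshot

GiB = 1 << 30
before = MemorySnapshot(torch_peak_in_bytes=2 * GiB,
                        torch_memory_in_bytes=2 * GiB,  # weights resident
                        timestamp=100.0)
after = MemorySnapshot(torch_peak_in_bytes=4 * GiB,  # activation peak
                       torch_memory_in_bytes=3 * GiB,
                       timestamp=112.5)

diff = after - before
print(diff.torch_peak_in_bytes)    # 2 GiB: peak growth over the interval
print(diff.torch_memory_in_bytes)  # 1 GiB: net growth of live allocations
print(diff.timestamp)              # 12.5: seconds elapsed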
+
+
+@dataclass
+class MemoryProfilingResult:
+    """Memory profiling result."""  # noqa
+    baseline_memory_in_bytes: int = 0
+    non_kv_cache_memory_in_bytes: int = 0
+    torch_peak_increase_in_bytes: int = 0
+    non_torch_increase_in_bytes: int = 0
+    weights_memory_in_bytes: float = 0
+    before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
+    after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
+    profile_time: float = 0.0
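For illustration (not part of the commit): the accounting identity that ties these fields together, using the GiB figures from the quantitative example in the docstring below:

GiB = 1 << 30
weights_memory_in_bytes = 2 * GiB        # (a.) model weights
torch_peak_increase_in_bytes = 2 * GiB   # (b.) peak activation tensors
non_torch_increase_in_bytes = 1 * GiB    # (c.) NCCL + attention buffers

non_kv_cache_memory_in_bytes = (weights_memory_in_bytes +
                                torch_peak_increase_in_bytes +
                                non_torch_increase_in_bytes)
assert non_kv_cache_memory_in_bytes == 5 * GiB  # matches the docstring total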
+
+
+@contextlib.contextmanager
+def memory_profiling(
+        baseline_memory_in_bytes: int, weights_memory_in_bytes: int
+) -> Generator[MemoryProfilingResult, None, None]:
+    """Memory profiling context manager.
+
+    baseline_memory_in_bytes: memory used by all components other than
+    the current vLLM instance. It includes memory used by other processes,
+    by another vLLM instance in the same process, etc. It is usually
+    measured before the current vLLM instance initializes the device,
+    and we assume it stays constant while the current vLLM instance is
+    profiled.
+    weights_memory_in_bytes: memory used by PyTorch when loading the model
+    weights. Note that, before loading the model weights, we also
+    initialize the device and the distributed environment, which may
+    consume some memory. That part is not included in
+    weights_memory_in_bytes because PyTorch does not control it.
+
+    The memory in one GPU can be classified into 3 categories:
+    1. memory used by anything other than the current vLLM instance.
+    2. memory used by torch in the current vLLM instance.
+    3. memory used in the current vLLM instance, but not by torch.
+
+    A quantitative example:
+
+    Before creating the current vLLM instance:
+        category 1: 1 GiB
+        category 2: 0 GiB
+        category 3: 0 GiB
+
+    After creating the current vLLM instance and loading the model
+    (i.e. before profiling):
+        category 1: 1 GiB
+        category 2: 2 GiB (model weights take 2 GiB)
+        category 3: 0.5 GiB (memory used by NCCL)
+
+    During profiling (peak):
+        category 1: 1 GiB
+        category 2: 4 GiB (peak activation tensors take 2 GiB)
+        category 3: 1 GiB (memory used by NCCL + buffers for some
+            attention backends)
+
+    After profiling:
+        category 1: 1 GiB
+        category 2: 3 GiB (after garbage-collecting activation tensors)
+        category 3: 1 GiB (memory used by NCCL + buffers for some
+            attention backends)
+
+    In this case, the non-KV-cache memory takes 5 GiB in total, including:
+    a. 2 GiB used by the model weights (category 2)
+    b. 2 GiB reserved for the peak activation tensors (category 2)
+    c. 1 GiB used by non-torch components (category 3)
+
+    The memory used for loading weights (a.) is directly given by the
+    argument `weights_memory_in_bytes`.
+
+    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]`
+    after profiling gives (b.).
+
+    (c.) is tricky. We measure the total memory used on this GPU
+    (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`), then
+    subtract the baseline memory, the memory used by the model weights,
+    and the diff of `torch.cuda.memory_stats()["allocated_bytes.all.current"]`.
+    """  # noqa
+    torch.cuda.reset_peak_memory_stats()
+
+    result = MemoryProfilingResult()
+
+    result.baseline_memory_in_bytes = baseline_memory_in_bytes
+    # the part of memory used for holding the model weights
+    result.weights_memory_in_bytes = weights_memory_in_bytes
+
+    result.before_profile.measure()
+
+    yield result
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    result.after_profile.measure()
+
+    diff = result.after_profile - result.before_profile
+    result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes
+    current_cuda_memory_bytes = torch.cuda.mem_get_info(
+    )[1] - torch.cuda.mem_get_info()[0]
+    result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes  # noqa
+    result.profile_time = diff.timestamp
+    result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes  # noqa
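For illustration (not part of the commit): a hedged usage sketch of the context manager. `profile_run` is a hypothetical stand-in for whatever worst-case forward pass the caller profiles, and the two byte counts are assumed to have been measured earlier, as the docstring describes:

import torch

from vllm.utils import memory_profiling

def profile_run():
    # stand-in workload: allocate and free a transient "activation" tensor
    x = torch.empty(256 << 20, dtype=torch.uint8, device="cuda")
    del x

baseline_bytes = 1 << 30  # hypothetical: everything else on this GPU
weights_bytes = 2 << 30   # hypothetical: PyTorch memory for model weights

with memory_profiling(baseline_memory_in_bytes=baseline_bytes,
                      weights_memory_in_bytes=weights_bytes) as result:
    profile_run()

# After the block exits, the accounting fields are populated; the memory
# left for the KV cache under a utilization cap (mirroring a
# gpu_memory_utilization-style knob, here 0.9) could then be derived as:
total_gpu_bytes = torch.cuda.mem_get_info()[1]
kv_cache_bytes = (int(0.9 * total_gpu_bytes)
                  - result.non_kv_cache_memory_in_bytes)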