Allow users to specify kv cache memory size (#21489)
Signed-off-by: Boyuan Feng <boyuan@meta.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -231,18 +231,40 @@ class Worker(WorkerBase):
        You may limit the usage of GPU memory
        by adjusting the `gpu_memory_utilization` parameter.
        """
+       GiB = lambda b: b / GiB_bytes
+
+       if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
+           # still need a profile run which compiles the model for
+           # max_num_batched_tokens
+           self.model_runner.profile_run()
+
+           msg = (
+               f"Initial free memory {GiB(self.init_snapshot.free_memory)} "
+               f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f}GiB memory for "
+               "KV Cache as specified by kv_cache_memory_bytes config and "
+               "skipped memory profiling. This does not respect the "
+               "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
+               "config when you want manual control of KV cache memory "
+               "size. If OOM'ed, check the difference of initial free "
+               "memory between the current run and the previous run "
+               "where kv_cache_memory_bytes is suggested and update it "
+               "correspondingly.")
+           logger.info(msg)
+           return kv_cache_memory_bytes
+
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
-       GiB = lambda b: b / GiB_bytes

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        with memory_profiling(
                self.init_snapshot,
-               weights_memory=int(
-                   self.model_runner.model_memory_usage)) as profile_result:
+               weights_memory=int(self.model_runner.model_memory_usage),
+       ) as profile_result:
            self.model_runner.profile_run()

+       self.non_torch_memory = profile_result.non_torch_increase
+       self.peak_activation_memory = profile_result.torch_peak_increase
+
        free_gpu_memory = profile_result.after_profile.free_memory
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
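The branch above is the core of the change: when `kv_cache_memory_bytes` is set, the worker keeps only the compile-oriented profile run and returns the user-supplied budget directly. A minimal usage sketch follows; the `--kv-cache-memory` flag spelling is taken from the suggestion message later in this diff, while the offline `kv_cache_memory_bytes` keyword and the model name are assumptions for illustration.

# Hedged usage sketch: pin the KV cache to a fixed 16 GiB budget instead of
# letting gpu_memory_utilization-based profiling decide.
#
# Server form (flag spelling quoted from the suggestion log message below):
#   vllm serve meta-llama/Llama-3.1-8B-Instruct --kv-cache-memory=17179869184
#
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    # Assumed engine-arg spelling; it is expected to land in
    # cache_config.kv_cache_memory_bytes, which the worker checks above.
    kv_cache_memory_bytes=16 * 1024**3,  # bytes
)

Because the early return skips profiling, `gpu_memory_utilization` no longer bounds the KV cache in this mode; the log message above is the only guard against over-committing device memory.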
@@ -254,7 +276,7 @@ class Worker(WorkerBase):
            "release GPU memory while vLLM is profiling during initialization. "
            "To fix this, ensure consistent GPU memory allocation or "
            "isolate vLLM in its own container.")
-       available_kv_cache_memory = self.requested_memory \
+       self.available_kv_cache_memory_bytes = self.requested_memory \
            - profile_result.non_kv_cache_memory

        unrequested_memory = self.init_snapshot.free_memory \
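With profiling enabled, the KV cache gets whatever is left of the requested budget after the dummy forward pass, i.e. `requested_memory - profile_result.non_kv_cache_memory`. A back-of-the-envelope sketch of that subtraction with made-up numbers; it assumes `requested_memory` is `gpu_memory_utilization` times total device memory, as the log message in the last hunk suggests.

# Illustrative numbers only: how available_kv_cache_memory_bytes falls out of
# the profiling result on a hypothetical 80 GiB device.
GiB_bytes = 1 << 30

total_memory = 80 * GiB_bytes
gpu_memory_utilization = 0.9
# Assumption: requested_memory is the utilization fraction of total device memory.
requested_memory = total_memory * gpu_memory_utilization       # 72.0 GiB budget

# profile_result.non_kv_cache_memory bundles weights, peak activation and
# non-torch allocations observed during the dummy forward pass (made-up value).
non_kv_cache_memory = 18 * GiB_bytes

available_kv_cache_memory_bytes = requested_memory - non_kv_cache_memory
print(available_kv_cache_memory_bytes / GiB_bytes)              # 54.0 GiB for KV cache

The diff then stores this value as `self.available_kv_cache_memory_bytes` so the warm-up phase below can report it and suggest an explicit budget.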
@@ -274,10 +296,10 @@ class Worker(WorkerBase):
        )
        logger.debug(profile_result)
        logger.info("Available KV cache memory: %.2f GiB",
-                   GiB(available_kv_cache_memory))
+                   GiB(self.available_kv_cache_memory_bytes))
        gc.collect()

-       return int(available_kv_cache_memory)
+       return int(self.available_kv_cache_memory_bytes)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return self.model_runner.get_kv_cache_spec()
@@ -317,8 +339,56 @@ class Worker(WorkerBase):
        # cuda graph capture.
        kernel_warmup(self)

+       cuda_graph_memory_bytes = 0
        if not self.model_config.enforce_eager:
-           self.model_runner.capture_model()
+           cuda_graph_memory_bytes = self.model_runner.capture_model()

+       if (self.cache_config.kv_cache_memory_bytes is None
+               and hasattr(self, "peak_activation_memory")):
+           # Suggest an optimal kv cache memory size when we relied on
+           # memory_profiling to estimate the kv cache memory size, which
+           # provides peak_activation_memory and a few other memory
+           # consumption figures. `memory_profiling` does not consider
+           # CUDAGraph memory size and may not utilize all gpu memory.
+           # Users may want fine-grained control to specify kv cache
+           # memory size.
+           GiB = lambda b: round(b / GiB_bytes, 2)
+
+           # It is empirically observed that memory profiling may
+           # slightly underestimate the memory consumption,
+           # so leave a small buffer (=150MiB) to avoid OOM.
+           redundancy_buffer_memory = 150 * (1 << 20)
+           non_kv_cache_memory = (self.model_runner.model_memory_usage +
+                                  self.peak_activation_memory +
+                                  self.non_torch_memory +
+                                  cuda_graph_memory_bytes)
+           kv_cache_memory_bytes_to_gpu_limit = (
+               self.init_snapshot.free_memory - non_kv_cache_memory -
+               redundancy_buffer_memory)
+           kv_cache_memory_bytes_to_requested_limit = (
+               int(self.requested_memory) - non_kv_cache_memory -
+               redundancy_buffer_memory)
+
+           msg = (
+               f"Free memory on device "
+               f"({GiB(self.init_snapshot.free_memory)}/"
+               f"{GiB(self.init_snapshot.total_memory)} GiB) on startup. "
+               f"Desired GPU memory utilization is "
+               f"({self.cache_config.gpu_memory_utilization}, "
+               f"{GiB(self.requested_memory)} GiB). "
+               f"Actual usage is {GiB(self.model_runner.model_memory_usage)} "
+               f"GiB for weight, {GiB(self.peak_activation_memory)} GiB "
+               f"for peak activation, {GiB(self.non_torch_memory)} GiB "
+               f"for non-torch memory, and {GiB(cuda_graph_memory_bytes)} "
+               f"GiB for CUDAGraph memory. Replace gpu_memory_utilization "
+               f"config with `--kv-cache-memory="
+               f"{kv_cache_memory_bytes_to_requested_limit}` to fit into "
+               f"requested memory, or `--kv-cache-memory="
+               f"{kv_cache_memory_bytes_to_gpu_limit}` to fully "
+               f"utilize gpu memory. Current kv cache memory in use is "
+               f"{int(self.available_kv_cache_memory_bytes)} bytes.")
+
+           logger.info(msg)
+
        # Warm up sampler and preallocate memory buffer for logits and other
        # sampling related tensors of max possible shape to avoid memory
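The new suggestion block recomputes the non-KV-cache footprint once CUDA graph capture has run, keeps a 150 MiB buffer, and derives the two `--kv-cache-memory` candidates quoted in the log message. A worked sketch of that arithmetic with made-up numbers:

# Illustrative numbers only: the two --kv-cache-memory values suggested by the
# log message above, for a hypothetical 80 GiB device with 79 GiB free at startup.
GiB_bytes = 1 << 30

init_free_memory = 79 * GiB_bytes
requested_memory = int(0.9 * 80 * GiB_bytes)    # gpu_memory_utilization * total (assumed)

model_weights = 15 * GiB_bytes                  # made-up profiling results
peak_activation = 2 * GiB_bytes
non_torch = 1 * GiB_bytes
cuda_graph_memory_bytes = 1 * GiB_bytes

redundancy_buffer_memory = 150 * (1 << 20)      # 150 MiB safety buffer, as in the diff

non_kv_cache_memory = (model_weights + peak_activation + non_torch +
                       cuda_graph_memory_bytes)

# Fit inside the gpu_memory_utilization budget: ~52.85 GiB
kv_cache_memory_bytes_to_requested_limit = (
    requested_memory - non_kv_cache_memory - redundancy_buffer_memory)

# Or consume everything that was free at startup: ~59.85 GiB
kv_cache_memory_bytes_to_gpu_limit = (
    init_free_memory - non_kv_cache_memory - redundancy_buffer_memory)

print(kv_cache_memory_bytes_to_requested_limit, kv_cache_memory_bytes_to_gpu_limit)

Either value can be passed back on the next run as `--kv-cache-memory=<bytes>`, which takes the early-return path added in the first hunk; as that path's log message warns, the number only stays valid while the free memory at startup is comparable to the run that suggested it.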