[UX][Startup] Account for CUDA graphs during memory profiling (#30515)
This commit is contained in:
@@ -44,6 +44,7 @@ from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.tracing import instrument
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
@@ -390,8 +391,36 @@ class Worker(WorkerBase):
|
||||
) as profile_result:
|
||||
self.model_runner.profile_run()
|
||||
|
||||
profile_torch_peak = current_platform.memory_stats(self.device).get(
|
||||
"allocated_bytes.all.peak", 0
|
||||
)
|
||||
|
||||
# Profile CUDA graph memory if graphs will be captured.
|
||||
cudagraph_memory_estimate = 0
|
||||
if not self.model_config.enforce_eager:
|
||||
cudagraph_memory_estimate = self.model_runner.profile_cudagraph_memory()
|
||||
|
||||
# Use the pre-cudagraph torch peak to avoid double-counting.
|
||||
profile_result.torch_peak_increase = (
|
||||
profile_torch_peak - profile_result.before_profile.torch_peak
|
||||
)
|
||||
profile_result.non_kv_cache_memory = (
|
||||
profile_result.non_torch_increase
|
||||
+ profile_result.torch_peak_increase
|
||||
+ profile_result.weights_memory
|
||||
)
|
||||
|
||||
cudagraph_memory_estimate_applied = (
|
||||
cudagraph_memory_estimate
|
||||
if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS
|
||||
else 0
|
||||
)
|
||||
|
||||
self.non_torch_memory = profile_result.non_torch_increase
|
||||
self.peak_activation_memory = profile_result.torch_peak_increase
|
||||
self.peak_activation_memory = (
|
||||
profile_result.torch_peak_increase + cudagraph_memory_estimate_applied
|
||||
)
|
||||
self.cudagraph_memory_estimate = cudagraph_memory_estimate
|
||||
|
||||
free_gpu_memory = profile_result.after_profile.free_memory
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
@@ -406,7 +435,9 @@ class Worker(WorkerBase):
|
||||
"isolate vLLM in its own container."
|
||||
)
|
||||
self.available_kv_cache_memory_bytes = (
|
||||
self.requested_memory - profile_result.non_kv_cache_memory
|
||||
self.requested_memory
|
||||
- profile_result.non_kv_cache_memory
|
||||
- cudagraph_memory_estimate_applied
|
||||
)
|
||||
|
||||
unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
|
||||
@@ -428,6 +459,46 @@ class Worker(WorkerBase):
|
||||
scope="local",
|
||||
)
|
||||
|
||||
if cudagraph_memory_estimate > 0:
|
||||
total_mem = self.init_snapshot.total_memory
|
||||
current_util = self.cache_config.gpu_memory_utilization
|
||||
cg_util_delta = cudagraph_memory_estimate / total_mem
|
||||
if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS:
|
||||
equiv_util = round(current_util - cg_util_delta, 4)
|
||||
suggested_util = min(
|
||||
round(current_util + cg_util_delta, 4),
|
||||
1.0,
|
||||
)
|
||||
logger.info(
|
||||
"CUDA graph memory profiling is enabled "
|
||||
"(VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1). "
|
||||
"This will become the default in v0.19. "
|
||||
"The current --gpu-memory-utilization=%.4f is equivalent "
|
||||
"to --gpu-memory-utilization=%.4f without CUDA graph "
|
||||
"memory profiling. To maintain the same effective KV "
|
||||
"cache size as before, increase "
|
||||
"--gpu-memory-utilization to %.4f.",
|
||||
current_util,
|
||||
equiv_util,
|
||||
suggested_util,
|
||||
)
|
||||
else:
|
||||
suggested_util = min(
|
||||
round(current_util + cg_util_delta, 4),
|
||||
1.0,
|
||||
)
|
||||
logger.info(
|
||||
"In v0.19, CUDA graph memory profiling will be enabled "
|
||||
"by default (VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1), "
|
||||
"which more accurately accounts for CUDA graph memory "
|
||||
"during KV cache allocation. To try it now, set "
|
||||
"VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 and increase "
|
||||
"--gpu-memory-utilization from %.4f to %.4f to maintain "
|
||||
"the same effective KV cache size.",
|
||||
current_util,
|
||||
suggested_util,
|
||||
)
|
||||
|
||||
return int(self.available_kv_cache_memory_bytes)
|
||||
|
||||
def get_kv_connector_handshake_metadata(self) -> dict | None:
|
||||
@@ -487,14 +558,14 @@ class Worker(WorkerBase):
|
||||
|
||||
@instrument(span_name="Warmup (GPU)")
|
||||
def compile_or_warm_up_model(self) -> float:
|
||||
warmup_sizes = []
|
||||
warmup_sizes: list[int] = []
|
||||
|
||||
if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
||||
# warm up sizes that are not in cudagraph capture sizes,
|
||||
# but users still want to compile for better performance,
|
||||
# e.g. for the max-num-batched token size in chunked prefill.
|
||||
compile_sizes = self.vllm_config.compilation_config.compile_sizes
|
||||
warmup_sizes = compile_sizes.copy() if compile_sizes is not None else []
|
||||
warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] # type: ignore[assignment]
|
||||
cg_capture_sizes: list[int] = []
|
||||
|
||||
if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
@@ -526,6 +597,22 @@ class Worker(WorkerBase):
|
||||
if not self.model_config.enforce_eager:
|
||||
cuda_graph_memory_bytes = self.model_runner.capture_model()
|
||||
|
||||
# Compare actual vs estimated CUDA graph memory (if we did profiling)
|
||||
if (
|
||||
hasattr(self, "cudagraph_memory_estimate")
|
||||
and self.cudagraph_memory_estimate > 0
|
||||
):
|
||||
GiB = lambda b: round(b / GiB_bytes, 2)
|
||||
diff = abs(cuda_graph_memory_bytes - self.cudagraph_memory_estimate)
|
||||
logger.info(
|
||||
"CUDA graph pool memory: %s GiB (actual), %s GiB (estimated), "
|
||||
"difference: %s GiB (%.1f%%).",
|
||||
GiB(cuda_graph_memory_bytes),
|
||||
GiB(self.cudagraph_memory_estimate),
|
||||
GiB(diff),
|
||||
100 * diff / max(cuda_graph_memory_bytes, 1),
|
||||
)
|
||||
|
||||
if self.cache_config.kv_cache_memory_bytes is None and hasattr(
|
||||
self, "peak_activation_memory"
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user