[UX][Startup] Account for CUDA graphs during memory profiling (#30515)

This commit is contained in:
Matthew Bonanni
2026-03-07 16:49:23 -05:00
committed by GitHub
parent 85f50eb41f
commit ebb9cc5f2b
6 changed files with 360 additions and 61 deletions

View File

@@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import weakref
from collections import Counter
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any
from typing import Any, ClassVar
from unittest.mock import patch
import torch
@@ -162,6 +163,14 @@ class CUDAGraphWrapper:
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
"""
_all_instances: ClassVar[weakref.WeakSet["CUDAGraphWrapper"]] = weakref.WeakSet()
@classmethod
def clear_all_graphs(cls) -> None:
    """Drop the captured graphs held by every tracked wrapper instance.

    Calls ``clear_graphs()`` on each member of ``cls._all_instances``.
    """
    # Walk a snapshot so the registry may change size while we iterate.
    tracked = tuple(cls._all_instances)
    for wrapper in tracked:
        wrapper.clear_graphs()
def __init__(
self,
runnable: Callable[..., Any],
@@ -192,6 +201,8 @@ class CUDAGraphWrapper:
# cudagraphs for.
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}
CUDAGraphWrapper._all_instances.add(self)
def __getattr__(self, key: str) -> Any:
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
@@ -205,6 +216,13 @@ class CUDAGraphWrapper:
# in case we need to access the original runnable.
return self.runnable
@property
def cudagraph_wrapper(self) -> "CUDAGraphWrapper":
    """Return this object itself as the CUDA graph wrapper.

    NOTE(review): presumably mirrors the ``cudagraph_wrapper`` attribute of
    other wrapper types so callers can treat them uniformly -- confirm.
    """
    return self
def clear_graphs(self) -> None:
    """Forget every captured graph entry held by this wrapper.

    The ``concrete_cudagraph_entries`` dict is emptied in place, so other
    references to the same dict observe the removal as well.
    """
    entries = self.concrete_cudagraph_entries
    entries.clear()
def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor

View File

@@ -244,6 +244,7 @@ if TYPE_CHECKING:
VLLM_CUDA_COMPATIBILITY_PATH: str | None = None
VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False
VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False
def get_default_cache_root():
@@ -1628,6 +1629,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ELASTIC_EP_DRAIN_REQUESTS": lambda: bool(
int(os.getenv("VLLM_ELASTIC_EP_DRAIN_REQUESTS", "0"))
),
# If set to 1, enable CUDA graph memory estimation during memory profiling.
# This profiles CUDA graph memory usage to provide more accurate KV cache
# memory allocation. Disabled by default to preserve existing behavior.
"VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS": lambda: bool(
int(os.getenv("VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS", "0"))
),
}

View File

@@ -334,8 +334,11 @@ class CudagraphDispatcher:
for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
descs = list(self.cudagraph_keys[mode])
if descs:
# Sort by num_tokens descending (largest first)
descs.sort(key=lambda d: d.num_tokens, reverse=True)
# Sort by (num_tokens, num_active_loras) descending
descs.sort(
key=lambda d: (d.num_tokens, d.num_active_loras),
reverse=True,
)
result.append((mode, descs))
return result

View File

@@ -29,6 +29,7 @@ from vllm.config import (
CUDAGraphMode,
VllmConfig,
get_layers_from_vllm_config,
set_current_vllm_config,
update_config,
)
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
@@ -94,6 +95,7 @@ from vllm.multimodal.inputs import (
PlaceholderRange,
)
from vllm.multimodal.utils import group_and_batch_mm_kwargs
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
@@ -596,6 +598,17 @@ class GPUModelRunner(
self.async_output_copy_stream = torch.cuda.Stream()
self.prepare_inputs_event = torch.Event()
# self.cudagraph_batch_sizes is kept sorted in ascending order.
if (
self.compilation_config.cudagraph_capture_sizes
and self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
):
self.cudagraph_batch_sizes = sorted(
self.compilation_config.cudagraph_capture_sizes
)
else:
self.cudagraph_batch_sizes = []
# Cache the device properties.
self._init_device_properties()
@@ -4727,6 +4740,7 @@ class GPUModelRunner(
remove_lora: bool = True,
is_graph_capturing: bool = False,
num_active_loras: int = 0,
profile_seq_lens: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Run a dummy forward pass to warm up/profile run or capture the
@@ -4751,6 +4765,9 @@ class GPUModelRunner(
remove_lora: If False, dummy LoRAs are not destroyed after the run
num_active_loras: Number of distinct active LoRAs to capture for.
LoRA is activated when num_active_loras > 0.
profile_seq_lens: If provided, use this value for seq_lens instead
of max_query_len. Used to profile attention workspace that
scales with context length.
"""
mm_config = self.vllm_config.model_config.multimodal_config
if mm_config and mm_config.mm_encoder_only:
@@ -4881,11 +4898,13 @@ class GPUModelRunner(
# If force_attention is True, we always capture attention.
# Otherwise, it only happens for cudagraph_runtime_mode=FULL.
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
if create_mixed_batch:
if profile_seq_lens is not None:
seq_lens = profile_seq_lens # type: ignore[assignment]
elif create_mixed_batch:
# In the mixed batch mode (used for FI warmup), we use
# shorter sequence lengths to run faster.
# TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1]
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] # type: ignore[assignment]
else:
seq_lens = max_query_len # type: ignore[assignment]
self.seq_lens.np[:num_reqs] = seq_lens
@@ -5298,6 +5317,167 @@ class GPUModelRunner(
self.encoder_cache.clear()
gc.collect()
def _init_minimal_kv_cache_for_profiling(self) -> None:
    """Allocate a small KV cache so CUDA graphs can be captured for profiling.

    The memory budget handed to ``get_kv_cache_config_from_groups`` is
    ``max_cudagraph_capture_size`` blocks (or 1 when that is unset/zero),
    multiplied by the page size of the first KV cache group and by the
    largest number of layers found in any group. When there are no KV cache
    groups, a token budget of 1 byte is used instead.

    Side effects: calls ``initialize_kv_cache`` with the resulting config and
    stores its block count in ``cache_config.num_gpu_blocks``.
    """
    from vllm.v1.core.kv_cache_utils import (
        get_kv_cache_config_from_groups,
        get_kv_cache_groups,
    )

    spec = self.get_kv_cache_spec()
    groups = get_kv_cache_groups(self.vllm_config, spec)
    num_blocks = self.compilation_config.max_cudagraph_capture_size or 1

    if not groups:
        budget = 1  # Attention-free model
    else:
        bytes_per_page = groups[0].kv_cache_spec.page_size_bytes
        widest_group = max(len(group.layer_names) for group in groups)
        budget = num_blocks * bytes_per_page * widest_group

    profiling_config = get_kv_cache_config_from_groups(
        self.vllm_config, groups, available_memory=budget
    )

    self.initialize_kv_cache(profiling_config)
    self.cache_config.num_gpu_blocks = profiling_config.num_blocks

    logger.debug("Initialized minimal KV cache for CUDA graph profiling")
@staticmethod
@contextmanager
def _freeze_gc():
    """Context manager that limits garbage-collector work inside its body.

    A full collection always runs on entry. Unless
    ``envs.VLLM_ENABLE_CUDAGRAPH_GC`` is truthy, the surviving objects are
    then frozen for the duration of the ``with`` body; on exit (normal or
    via exception) they are unfrozen and another collection runs.
    """
    gc.collect()

    if envs.VLLM_ENABLE_CUDAGRAPH_GC:
        # GC stays fully active: nothing to freeze and nothing to undo.
        yield
        return

    gc.freeze()
    try:
        yield
    finally:
        gc.unfreeze()
        gc.collect()
def _cleanup_profiling_kv_cache(self) -> None:
    """Tear down the KV cache state that was set up for profiling.

    After synchronizing the device, this drops every KV cache reference the
    runner holds (``kv_caches``, the cross-layer cache/backend, attention
    groups, ``kv_cache_config``), resets ``cache_config.num_gpu_blocks`` to
    ``None``, empties the ``kv_cache`` list of each layer in the static
    forward context, and finally runs the garbage collector and releases
    cached device memory.
    """
    torch.accelerator.synchronize()

    caches = getattr(self, "kv_caches", None)
    if caches:
        # Null out each slot before clearing the list itself.
        for slot, _ in enumerate(caches):
            caches[slot] = None  # type: ignore
        caches.clear()

    if hasattr(self, "cross_layers_kv_cache"):
        self.cross_layers_kv_cache = None
        self.cross_layers_attn_backend = None

    if hasattr(self, "attn_groups"):
        self.attn_groups.clear()

    if hasattr(self, "kv_cache_config"):
        del self.kv_cache_config

    self.cache_config.num_gpu_blocks = None

    forward_layers = self.compilation_config.static_forward_context
    for module in forward_layers.values():
        if hasattr(module, "kv_cache"):
            module.kv_cache = []

    gc.collect()
    torch.accelerator.empty_cache()

    logger.debug("Cleaned up profiling KV cache and CUDA graphs")
@torch.inference_mode()
def profile_cudagraph_memory(self) -> int:
    """Estimate how many bytes of GPU memory CUDA graph capture will use.

    A minimal KV cache is created, then for every capture mode reported by
    the dispatcher the first two batch descriptors (the dispatcher returns
    them largest-first) are warmed up and captured into a temporary graph
    pool. Free device memory is sampled before and after each capture:

    * the first sample of a mode is its "first-capture" (shared) cost;
    * the second sample, floored at 1 MiB, is the per-graph cost and is
      multiplied by the number of remaining graphs of that mode.

    The result is the largest first-capture cost across modes plus the sum
    of all per-graph costs.

    All global state touched while profiling (capture-enabled flag, the
    graph pool of every ``CUDAGraphWrapper``, captured graphs, dummy LoRAs,
    the profiling KV cache and the captured-graph counter) is restored in a
    ``finally`` block, so a failure during capture does not leak it.

    Returns:
        Estimated CUDA graph memory in bytes; 0 when no graphs would be
        captured.
    """
    with set_current_vllm_config(self.vllm_config):
        self._init_minimal_kv_cache_for_profiling()

    saved_num_cudagraph_captured = compilation_counter.num_cudagraph_captured

    capture_descs = self.cudagraph_dispatcher.get_capture_descs()

    total_graphs = sum(len(descs) for _, descs in capture_descs)
    if total_graphs == 0:
        logger.debug("No CUDA graphs will be captured, skipping profiling")
        self._cleanup_profiling_kv_cache()
        return 0

    logger.info(
        "Profiling CUDA graph memory: %s",
        ", ".join(
            f"{mode.name}={len(descs)} (largest={descs[0].num_tokens})"
            for mode, descs in capture_descs
            if descs
        ),
    )

    shared_memory_estimate: dict[CUDAGraphMode, int] = {}
    per_graph_estimate: dict[CUDAGraphMode, int] = {}
    original_pools: dict[int, Any] = {}

    try:
        # Use a temporary pool for profiling to avoid fragmentation in the
        # main pool.
        profiling_pool = current_platform.graph_pool_handle()
        for instance in list(CUDAGraphWrapper._all_instances):
            original_pools[id(instance)] = instance.graph_pool
            instance.graph_pool = profiling_pool

        set_cudagraph_capturing_enabled(True)
        with self._freeze_gc(), graph_capture(device=self.device):
            torch.accelerator.synchronize()
            torch.accelerator.empty_cache()

            for mode, descs in capture_descs:
                # Two captures are enough: one for the shared cost, one for
                # the incremental per-graph cost.
                profile_descs = descs[:2]
                mem_samples: list[int] = []

                for i, desc in enumerate(profile_descs):
                    mem_before = torch.cuda.mem_get_info()[0]
                    self._warmup_and_capture(
                        desc,
                        cudagraph_runtime_mode=mode,
                        # Only the first FULL capture uses an enlarged
                        # seq_lens value.
                        profile_seq_lens=(
                            min(
                                self.max_model_len,
                                self.max_num_tokens // desc.num_tokens,
                            )
                            if mode == CUDAGraphMode.FULL and i == 0
                            else None
                        ),
                    )
                    torch.accelerator.synchronize()
                    free_after = torch.cuda.mem_get_info()[0]
                    # Free memory may have grown between the two readings;
                    # a negative "cost" would shrink the estimate and
                    # over-commit the KV cache, so clamp at zero.
                    mem_samples.append(max(mem_before - free_after, 0))

                first_capture = mem_samples[0]
                # Use at least 1 MiB per graph for driver overhead
                per_graph = max(
                    mem_samples[1] if len(mem_samples) > 1 else 0, 1 << 20
                )

                shared_memory_estimate[mode] = first_capture
                per_graph_estimate[mode] = per_graph * (len(descs) - 1)

                logger.debug(
                    "Estimated %s CUDA graph memory: "
                    "%.2f MiB first-capture + (%d-1) × %.2f MiB per-graph",
                    mode.name,
                    first_capture / (1 << 20),
                    len(descs),
                    per_graph / (1 << 20),
                )
    finally:
        # Undo every piece of global state touched above, even on failure.
        set_cudagraph_capturing_enabled(False)
        CUDAGraphWrapper.clear_all_graphs()
        for instance in list(CUDAGraphWrapper._all_instances):
            if id(instance) in original_pools:
                instance.graph_pool = original_pools[id(instance)]
        self.maybe_remove_all_loras(self.lora_config)
        self._cleanup_profiling_kv_cache()
        compilation_counter.num_cudagraph_captured = saved_num_cudagraph_captured

    # FULL and PIECEWISE graphs share the global pool at runtime and are
    # never replayed concurrently, so the pool overlays their memory.
    # Take the max to avoid double-counting the overlap.
    total_estimate = max(shared_memory_estimate.values()) + sum(
        per_graph_estimate.values()
    )
    logger.info(
        "Estimated CUDA graph memory: %.2f GiB total",
        total_estimate / (1 << 30),
    )

    return int(total_estimate)
@instrument(span_name="Capture model")
def capture_model(self) -> int:
if self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
@@ -5311,27 +5491,13 @@ class GPUModelRunner(
start_time = time.perf_counter()
@contextmanager
def freeze_gc():
# Optimize garbage collection during CUDA graph capture.
# Clean up, then freeze all remaining objects from being included
# in future collections.
gc.collect()
should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC
if should_freeze:
gc.freeze()
try:
yield
finally:
if should_freeze:
gc.unfreeze()
gc.collect()
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
set_cudagraph_capturing_enabled(True)
with freeze_gc(), graph_capture(device=self.device):
with self._freeze_gc(), graph_capture(device=self.device):
torch.accelerator.synchronize()
torch.accelerator.empty_cache()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
for (
@@ -5342,6 +5508,7 @@ class GPUModelRunner(
batch_descriptors=batch_descs,
cudagraph_runtime_mode=runtime_mode,
)
torch.accelerator.synchronize()
torch.accelerator.synchronize()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
@@ -5353,6 +5520,9 @@ class GPUModelRunner(
# after here.
set_cudagraph_capturing_enabled(False)
torch.accelerator.synchronize()
torch.accelerator.empty_cache()
# Lock workspace to prevent resizing during execution.
# Max workspace sizes should have been captured during warmup/profiling.
lock_workspace()
@@ -5369,6 +5539,40 @@ class GPUModelRunner(
)
return cuda_graph_size
def _warmup_and_capture(
    self,
    desc: BatchDescriptor,
    cudagraph_runtime_mode: CUDAGraphMode,
    profile_seq_lens: int | None = None,
    allow_microbatching: bool = False,
    num_warmups: int | None = None,
):
    """Run the warmup passes for one batch descriptor, then capture it.

    Args:
        desc: Batch descriptor supplying ``num_tokens``, ``uniform`` and
            ``num_active_loras`` for the dummy runs.
        cudagraph_runtime_mode: Mode used for the final capturing run.
        profile_seq_lens: Forwarded to the capturing run only.
        allow_microbatching: Forwarded to every dummy run.
        num_warmups: Number of warmup runs; falls back to
            ``compilation_config.cudagraph_num_of_warmups`` when ``None``.
    """
    warmup_iters = (
        self.compilation_config.cudagraph_num_of_warmups
        if num_warmups is None
        else num_warmups
    )

    # Arguments common to the warmup runs and the capturing run.
    shared_kwargs = dict(
        uniform_decode=desc.uniform,
        allow_microbatching=allow_microbatching,
        skip_eplb=True,
        remove_lora=False,
        num_active_loras=desc.num_active_loras,
    )

    # Warmups execute with mode NONE, so attention has to be requested
    # explicitly when the graph to capture is a FULL one.
    warm_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
    for _ in range(warmup_iters):
        self._dummy_run(
            desc.num_tokens,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            force_attention=warm_attention,
            **shared_kwargs,
        )

    # Capturing run.
    self._dummy_run(
        desc.num_tokens,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        is_graph_capturing=True,
        profile_seq_lens=profile_seq_lens,
        **shared_kwargs,
    )
def _capture_cudagraphs(
self,
batch_descriptors: list[BatchDescriptor],
@@ -5383,15 +5587,6 @@ class GPUModelRunner(
return
uniform_decode = batch_descriptors[0].uniform
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
dummy_run = functools.partial(
self._dummy_run,
uniform_decode=uniform_decode,
skip_eplb=True,
remove_lora=False,
force_attention=force_attention,
)
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
@@ -5406,9 +5601,6 @@ class GPUModelRunner(
# We skip EPLB here since we don't want to record dummy metrics
for batch_desc in batch_descriptors:
num_tokens = batch_desc.num_tokens
num_active_loras = batch_desc.num_active_loras
# We currently only capture ubatched graphs when it's a FULL
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
@@ -5419,33 +5611,16 @@ class GPUModelRunner(
and uniform_decode
and check_ubatch_thresholds(
config=self.vllm_config.parallel_config,
num_tokens=num_tokens,
num_tokens=batch_desc.num_tokens,
uniform_decode=uniform_decode,
)
)
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE` is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
dummy_run(
num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
allow_microbatching=allow_microbatching,
num_active_loras=num_active_loras,
)
# Capture run
dummy_run(
num_tokens,
self._warmup_and_capture(
batch_desc,
cudagraph_runtime_mode=cudagraph_runtime_mode,
allow_microbatching=allow_microbatching,
num_active_loras=num_active_loras,
is_graph_capturing=True,
)
torch.accelerator.synchronize()
self.maybe_remove_all_loras(self.lora_config)
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:

View File

@@ -112,16 +112,25 @@ class UBatchWrapper:
self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
self.cudagraph_wrapper = None
self.graph_pool = None
if runtime_mode is not CUDAGraphMode.NONE:
self.cudagraph_wrapper = CUDAGraphWrapper(
runnable, vllm_config, runtime_mode=runtime_mode
)
self.graph_pool = current_platform.get_global_graph_pool()
self.sm_control = self._create_sm_control_context(vllm_config)
self.device = device
@property
def graph_pool(self):
    """Graph pool of the inner CUDA graph wrapper, or ``None`` without one."""
    wrapper = self.cudagraph_wrapper
    return None if wrapper is None else wrapper.graph_pool
def clear_graphs(self) -> None:
    """Discard this wrapper's captured graphs and those of its inner wrapper.

    ``self.cudagraphs`` is emptied in place; when an inner
    ``cudagraph_wrapper`` exists, its ``clear_graphs()`` is called too.
    """
    self.cudagraphs.clear()

    wrapper = self.cudagraph_wrapper
    if wrapper is None:
        return
    wrapper.clear_graphs()
@staticmethod
def _create_sm_control_context(vllm_config: VllmConfig):
comm_sms: int = envs.VLLM_DBO_COMM_SMS

View File

@@ -44,6 +44,7 @@ from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.tracing import instrument
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -390,8 +391,36 @@ class Worker(WorkerBase):
) as profile_result:
self.model_runner.profile_run()
profile_torch_peak = current_platform.memory_stats(self.device).get(
"allocated_bytes.all.peak", 0
)
# Profile CUDA graph memory if graphs will be captured.
cudagraph_memory_estimate = 0
if not self.model_config.enforce_eager:
cudagraph_memory_estimate = self.model_runner.profile_cudagraph_memory()
# Use the pre-cudagraph torch peak to avoid double-counting.
profile_result.torch_peak_increase = (
profile_torch_peak - profile_result.before_profile.torch_peak
)
profile_result.non_kv_cache_memory = (
profile_result.non_torch_increase
+ profile_result.torch_peak_increase
+ profile_result.weights_memory
)
cudagraph_memory_estimate_applied = (
cudagraph_memory_estimate
if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS
else 0
)
self.non_torch_memory = profile_result.non_torch_increase
self.peak_activation_memory = profile_result.torch_peak_increase
self.peak_activation_memory = (
profile_result.torch_peak_increase + cudagraph_memory_estimate_applied
)
self.cudagraph_memory_estimate = cudagraph_memory_estimate
free_gpu_memory = profile_result.after_profile.free_memory
# NOTE(woosuk): Here we assume that the other processes using the same
@@ -406,7 +435,9 @@ class Worker(WorkerBase):
"isolate vLLM in its own container."
)
self.available_kv_cache_memory_bytes = (
self.requested_memory - profile_result.non_kv_cache_memory
self.requested_memory
- profile_result.non_kv_cache_memory
- cudagraph_memory_estimate_applied
)
unrequested_memory = self.init_snapshot.free_memory - self.requested_memory
@@ -428,6 +459,46 @@ class Worker(WorkerBase):
scope="local",
)
if cudagraph_memory_estimate > 0:
total_mem = self.init_snapshot.total_memory
current_util = self.cache_config.gpu_memory_utilization
cg_util_delta = cudagraph_memory_estimate / total_mem
if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS:
equiv_util = round(current_util - cg_util_delta, 4)
suggested_util = min(
round(current_util + cg_util_delta, 4),
1.0,
)
logger.info(
"CUDA graph memory profiling is enabled "
"(VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1). "
"This will become the default in v0.19. "
"The current --gpu-memory-utilization=%.4f is equivalent "
"to --gpu-memory-utilization=%.4f without CUDA graph "
"memory profiling. To maintain the same effective KV "
"cache size as before, increase "
"--gpu-memory-utilization to %.4f.",
current_util,
equiv_util,
suggested_util,
)
else:
suggested_util = min(
round(current_util + cg_util_delta, 4),
1.0,
)
logger.info(
"In v0.19, CUDA graph memory profiling will be enabled "
"by default (VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1), "
"which more accurately accounts for CUDA graph memory "
"during KV cache allocation. To try it now, set "
"VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 and increase "
"--gpu-memory-utilization from %.4f to %.4f to maintain "
"the same effective KV cache size.",
current_util,
suggested_util,
)
return int(self.available_kv_cache_memory_bytes)
def get_kv_connector_handshake_metadata(self) -> dict | None:
@@ -487,14 +558,14 @@ class Worker(WorkerBase):
@instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> float:
warmup_sizes = []
warmup_sizes: list[int] = []
if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
compile_sizes = self.vllm_config.compilation_config.compile_sizes
warmup_sizes = compile_sizes.copy() if compile_sizes is not None else []
warmup_sizes = compile_sizes.copy() if compile_sizes is not None else [] # type: ignore[assignment]
cg_capture_sizes: list[int] = []
if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
@@ -526,6 +597,22 @@ class Worker(WorkerBase):
if not self.model_config.enforce_eager:
cuda_graph_memory_bytes = self.model_runner.capture_model()
# Compare actual vs estimated CUDA graph memory (if we did profiling)
if (
hasattr(self, "cudagraph_memory_estimate")
and self.cudagraph_memory_estimate > 0
):
GiB = lambda b: round(b / GiB_bytes, 2)
diff = abs(cuda_graph_memory_bytes - self.cudagraph_memory_estimate)
logger.info(
"CUDA graph pool memory: %s GiB (actual), %s GiB (estimated), "
"difference: %s GiB (%.1f%%).",
GiB(cuda_graph_memory_bytes),
GiB(self.cudagraph_memory_estimate),
GiB(diff),
100 * diff / max(cuda_graph_memory_bytes, 1),
)
if self.cache_config.kv_cache_memory_bytes is None and hasattr(
self, "peak_activation_memory"
):