[cudagraphs] Refactor cudagraph capture loop (#32946)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-01-23 13:22:20 -07:00
committed by GitHub
parent 8518b30447
commit 3a41459501
3 changed files with 117 additions and 59 deletions

View File

@@ -173,6 +173,68 @@ class TestCudagraphDispatcher:
else:
assert rt_mode == CUDAGraphMode.NONE
@pytest.mark.parametrize(
"cudagraph_mode_str,compilation_mode,expected_modes",
[
# FULL mode: only FULL keys, no PIECEWISE
("FULL", CompilationMode.NONE, [CUDAGraphMode.FULL]),
# PIECEWISE mode: only PIECEWISE keys
("PIECEWISE", CompilationMode.VLLM_COMPILE, [CUDAGraphMode.PIECEWISE]),
# FULL_DECODE_ONLY: only FULL keys for uniform decode
("FULL_DECODE_ONLY", CompilationMode.NONE, [CUDAGraphMode.FULL]),
# NONE mode: no keys
("NONE", CompilationMode.NONE, []),
],
)
def test_get_capture_descs(
self, cudagraph_mode_str, compilation_mode, expected_modes
):
"""Test get_capture_descs returns correctly grouped and ordered descs."""
comp_config = CompilationConfig(
cudagraph_mode=cudagraph_mode_str,
mode=compilation_mode,
cudagraph_capture_sizes=[1, 4, 8, 16],
)
config = _create_vllm_config(comp_config, max_num_seqs=16)
dispatcher = CudagraphDispatcher(config)
dispatcher.initialize_cudagraph_keys(
cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
)
capture_descs = dispatcher.get_capture_descs()
# Verify we get the expected modes
actual_modes = [mode for mode, _ in capture_descs]
assert actual_modes == expected_modes
# Verify each group is sorted largest-first
for mode, descs in capture_descs:
assert len(descs) > 0, "Each group should have at least one descriptor"
num_tokens_list = [d.num_tokens for d in descs]
assert num_tokens_list == sorted(num_tokens_list, reverse=True), (
f"Descriptors for {mode} should be sorted largest-first"
)
# All descriptors in a group should have same uniform value
uniform_values = [d.uniform for d in descs]
assert len(set(uniform_values)) == 1, (
"All descriptors in a group should have the same uniform value"
)
def test_get_capture_descs_empty_when_not_initialized(self):
"""Test that get_capture_descs returns empty list when keys not initialized."""
comp_config = CompilationConfig(
cudagraph_mode="FULL",
mode=CompilationMode.NONE,
cudagraph_capture_sizes=[1, 8],
)
config = _create_vllm_config(comp_config, max_num_seqs=8)
dispatcher = CudagraphDispatcher(config)
# Don't initialize keys
assert dispatcher.get_capture_descs() == []
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:

View File

@@ -231,3 +231,26 @@ class CudagraphDispatcher:
# finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]:
"""
Returns capture descriptors for cudagraph capturing.
Returns:
List of (runtime_mode, batch_descriptors) tuples, ordered PIECEWISE
first then FULL. Batch descriptors are sorted largest-first for
memory efficiency.
"""
if not self.keys_initialized or self.cudagraph_mode == CUDAGraphMode.NONE:
return []
result = []
# Return in order: PIECEWISE first, then FULL
for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
descs = list(self.cudagraph_keys[mode])
if descs:
# Sort by num_tokens descending (largest first)
descs.sort(key=lambda d: d.num_tokens, reverse=True)
result.append((mode, descs))
return result

View File

@@ -10,7 +10,6 @@ from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from copy import copy, deepcopy
from functools import reduce
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
import numpy as np
@@ -4839,50 +4838,14 @@ class GPUModelRunner(
set_cudagraph_capturing_enabled(True)
with freeze_gc(), graph_capture(device=self.device):
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
if self.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
lora_cases = [True, False]
else:
lora_cases = [True]
else:
lora_cases = [False]
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
# make sure we capture the largest batch size first
compilation_cases = list(
product(reversed(self.cudagraph_batch_sizes), lora_cases)
)
for (
runtime_mode,
batch_descs,
) in self.cudagraph_dispatcher.get_capture_descs():
self._capture_cudagraphs(
compilation_cases,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=False,
)
# Capture full cudagraph for uniform decode batches if we
# don't already have full mixed prefill-decode cudagraphs.
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and cudagraph_mode.separate_routine()
):
max_num_tokens = (
self.scheduler_config.max_num_seqs * self.uniform_decode_query_len
)
decode_cudagraph_batch_sizes = [
x
for x in self.cudagraph_batch_sizes
if max_num_tokens >= x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
product(reversed(decode_cudagraph_batch_sizes), lora_cases)
)
self._capture_cudagraphs(
compilation_cases=compilation_cases_decode,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
uniform_decode=True,
batch_descriptors=batch_descs,
cudagraph_runtime_mode=runtime_mode,
)
torch.cuda.synchronize()
@@ -4913,19 +4876,32 @@ class GPUModelRunner(
def _capture_cudagraphs(
self,
compilation_cases: list[tuple[int, bool]],
batch_descriptors: list[BatchDescriptor],
cudagraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool,
):
assert (
cudagraph_runtime_mode != CUDAGraphMode.NONE
and cudagraph_runtime_mode.valid_runtime_modes()
), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"
if not batch_descriptors:
return
uniform_decode = batch_descriptors[0].uniform
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
dummy_run = functools.partial(
self._dummy_run,
uniform_decode=uniform_decode,
skip_eplb=True,
remove_lora=False,
force_attention=force_attention,
)
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
compilation_cases = tqdm(
compilation_cases,
batch_descriptors = tqdm(
batch_descriptors,
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing CUDA graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
@@ -4934,7 +4910,10 @@ class GPUModelRunner(
)
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens, activate_lora in compilation_cases:
for batch_desc in batch_descriptors:
num_tokens = batch_desc.num_tokens
activate_lora = batch_desc.has_lora
# We currently only capture ubatched graphs when its a FULL
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
@@ -4952,28 +4931,22 @@ class GPUModelRunner(
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to
# But be careful, warm up with `NONE` is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
self._dummy_run(
dummy_run(
num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode,
allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
)
self._dummy_run(
# Capture run
dummy_run(
num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=uniform_decode,
allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
is_graph_capturing=True,
)