[cudagraphs] Refactor cudagraph capture loop (#32946)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-01-23 13:22:20 -07:00
committed by GitHub
parent 8518b30447
commit 3a41459501
3 changed files with 117 additions and 59 deletions

View File

@@ -231,3 +231,26 @@ class CudagraphDispatcher:
# finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]:
    """
    Return capture descriptors for cudagraph capturing.

    Returns:
        List of (runtime_mode, batch_descriptors) tuples, with any
        PIECEWISE entry first, then FULL. Within each mode the
        descriptors are ordered by descending num_tokens so the largest
        batch is captured first (better for memory efficiency).
    """
    # Nothing to capture before the dispatch keys exist, or when
    # cudagraphs are disabled entirely.
    if not self.keys_initialized or self.cudagraph_mode == CUDAGraphMode.NONE:
        return []

    capture_order = (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL)
    result: list[tuple[CUDAGraphMode, list[BatchDescriptor]]] = []
    for mode in capture_order:
        # Largest num_tokens first; modes with no registered keys are
        # omitted from the result.
        sorted_descs = sorted(
            self.cudagraph_keys[mode],
            key=lambda d: d.num_tokens,
            reverse=True,
        )
        if sorted_descs:
            result.append((mode, sorted_descs))
    return result

View File

@@ -10,7 +10,6 @@ from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from copy import copy, deepcopy
from functools import reduce
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
import numpy as np
@@ -4839,50 +4838,14 @@ class GPUModelRunner(
set_cudagraph_capturing_enabled(True)
with freeze_gc(), graph_capture(device=self.device):
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None
if self.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
lora_cases = [True, False]
else:
lora_cases = [True]
else:
lora_cases = [False]
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
# make sure we capture the largest batch size first
compilation_cases = list(
product(reversed(self.cudagraph_batch_sizes), lora_cases)
)
for (
runtime_mode,
batch_descs,
) in self.cudagraph_dispatcher.get_capture_descs():
self._capture_cudagraphs(
compilation_cases,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=False,
)
# Capture full cudagraph for uniform decode batches if we
# don't already have full mixed prefill-decode cudagraphs.
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and cudagraph_mode.separate_routine()
):
max_num_tokens = (
self.scheduler_config.max_num_seqs * self.uniform_decode_query_len
)
decode_cudagraph_batch_sizes = [
x
for x in self.cudagraph_batch_sizes
if max_num_tokens >= x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
product(reversed(decode_cudagraph_batch_sizes), lora_cases)
)
self._capture_cudagraphs(
compilation_cases=compilation_cases_decode,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
uniform_decode=True,
batch_descriptors=batch_descs,
cudagraph_runtime_mode=runtime_mode,
)
torch.cuda.synchronize()
@@ -4913,19 +4876,32 @@ class GPUModelRunner(
def _capture_cudagraphs(
self,
compilation_cases: list[tuple[int, bool]],
batch_descriptors: list[BatchDescriptor],
cudagraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool,
):
assert (
cudagraph_runtime_mode != CUDAGraphMode.NONE
and cudagraph_runtime_mode.valid_runtime_modes()
), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"
if not batch_descriptors:
return
uniform_decode = batch_descriptors[0].uniform
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
dummy_run = functools.partial(
self._dummy_run,
uniform_decode=uniform_decode,
skip_eplb=True,
remove_lora=False,
force_attention=force_attention,
)
# Only rank 0 should print progress bar during capture
if is_global_first_rank():
compilation_cases = tqdm(
compilation_cases,
batch_descriptors = tqdm(
batch_descriptors,
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing CUDA graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
@@ -4934,7 +4910,10 @@ class GPUModelRunner(
)
# We skip EPLB here since we don't want to record dummy metrics
for num_tokens, activate_lora in compilation_cases:
for batch_desc in batch_descriptors:
num_tokens = batch_desc.num_tokens
activate_lora = batch_desc.has_lora
# We currently only capture ubatched graphs when its a FULL
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
@@ -4952,28 +4931,22 @@ class GPUModelRunner(
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to
# But be careful, warm up with `NONE` is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
self._dummy_run(
dummy_run(
num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode,
allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
)
self._dummy_run(
# Capture run
dummy_run(
num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=uniform_decode,
allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
is_graph_capturing=True,
)