[cudagraphs] Refactor cudagraph capture loop (#32946)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@@ -231,3 +231,26 @@ class CudagraphDispatcher:
         # finally, just return no cudagraphs and a trivial batch descriptor
         return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
+
+    def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]:
+        """
+        Returns capture descriptors for cudagraph capturing.
+
+        Returns:
+            List of (runtime_mode, batch_descriptors) tuples, ordered PIECEWISE
+            first then FULL. Batch descriptors are sorted largest-first for
+            memory efficiency.
+        """
+        if not self.keys_initialized or self.cudagraph_mode == CUDAGraphMode.NONE:
+            return []
+
+        result = []
+        # Return in order: PIECEWISE first, then FULL
+        for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
+            descs = list(self.cudagraph_keys[mode])
+            if descs:
+                # Sort by num_tokens descending (largest first)
+                descs.sort(key=lambda d: d.num_tokens, reverse=True)
+                result.append((mode, descs))
+
+        return result
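For illustration, a minimal standalone sketch of the ordering contract the new method promises (PIECEWISE before FULL, largest num_tokens first within each mode); Desc and the string-valued modes are stand-ins, not the real BatchDescriptor or CUDAGraphMode:

    from collections import namedtuple

    Desc = namedtuple("Desc", ["num_tokens"])

    # Hypothetical per-mode keys as the dispatcher might hold them.
    keys = {
        "PIECEWISE": [Desc(1), Desc(512), Desc(8)],
        "FULL": [Desc(16), Desc(256)],
    }

    result = []
    for mode in ["PIECEWISE", "FULL"]:  # PIECEWISE first, then FULL
        descs = sorted(keys[mode], key=lambda d: d.num_tokens, reverse=True)
        if descs:
            result.append((mode, descs))

    # result == [("PIECEWISE", [Desc(512), Desc(8), Desc(1)]),
    #            ("FULL", [Desc(256), Desc(16)])]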
@@ -10,7 +10,6 @@ from collections.abc import Iterator, Sequence
 from contextlib import contextmanager
 from copy import copy, deepcopy
 from functools import reduce
-from itertools import product
 from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
 
 import numpy as np
@@ -4839,50 +4838,14 @@ class GPUModelRunner(
         set_cudagraph_capturing_enabled(True)
         with freeze_gc(), graph_capture(device=self.device):
             start_free_gpu_memory = torch.cuda.mem_get_info()[0]
-            cudagraph_mode = self.compilation_config.cudagraph_mode
-            assert cudagraph_mode is not None
-
-            if self.lora_config:
-                if self.compilation_config.cudagraph_specialize_lora:
-                    lora_cases = [True, False]
-                else:
-                    lora_cases = [True]
-            else:
-                lora_cases = [False]
-
-            if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
-                cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
-                # make sure we capture the largest batch size first
-                compilation_cases = list(
-                    product(reversed(self.cudagraph_batch_sizes), lora_cases)
-                )
-                self._capture_cudagraphs(
-                    compilation_cases,
-                    cudagraph_runtime_mode=cudagraph_runtime_mode,
-                    uniform_decode=False,
-                )
-
-            # Capture full cudagraph for uniform decode batches if we
-            # don't already have full mixed prefill-decode cudagraphs.
-            if (
-                cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
-                and cudagraph_mode.separate_routine()
-            ):
-                max_num_tokens = (
-                    self.scheduler_config.max_num_seqs * self.uniform_decode_query_len
-                )
-                decode_cudagraph_batch_sizes = [
-                    x
-                    for x in self.cudagraph_batch_sizes
-                    if max_num_tokens >= x >= self.uniform_decode_query_len
-                ]
-                compilation_cases_decode = list(
-                    product(reversed(decode_cudagraph_batch_sizes), lora_cases)
-                )
-                self._capture_cudagraphs(
-                    compilation_cases=compilation_cases_decode,
-                    cudagraph_runtime_mode=CUDAGraphMode.FULL,
-                    uniform_decode=True,
-                )
+            for (
+                runtime_mode,
+                batch_descs,
+            ) in self.cudagraph_dispatcher.get_capture_descs():
+                self._capture_cudagraphs(
+                    batch_descriptors=batch_descs,
+                    cudagraph_runtime_mode=runtime_mode,
+                )
 
         torch.cuda.synchronize()
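After this hunk the capture driver no longer enumerates product(batch_sizes, lora_cases) itself; it simply walks the plan the dispatcher hands back. A toy sketch of that inversion of control, using stand-in classes rather than the real dispatcher and runner:

    from dataclasses import dataclass

    @dataclass
    class Desc:
        num_tokens: int
        has_lora: bool = False

    class Dispatcher:
        """Stand-in: the capture plan lives here, not in the runner."""
        def get_capture_descs(self):
            return [("PIECEWISE", [Desc(512), Desc(8)]),
                    ("FULL", [Desc(256, has_lora=True), Desc(16)])]

    def capture(batch_descriptors, cudagraph_runtime_mode):
        for d in batch_descriptors:
            print(f"capturing mode={cudagraph_runtime_mode} "
                  f"num_tokens={d.num_tokens} lora={d.has_lora}")

    for runtime_mode, batch_descs in Dispatcher().get_capture_descs():
        capture(batch_descs, cudagraph_runtime_mode=runtime_mode)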
@@ -4913,19 +4876,32 @@ class GPUModelRunner(
 
     def _capture_cudagraphs(
         self,
-        compilation_cases: list[tuple[int, bool]],
+        batch_descriptors: list[BatchDescriptor],
         cudagraph_runtime_mode: CUDAGraphMode,
-        uniform_decode: bool,
     ):
         assert (
             cudagraph_runtime_mode != CUDAGraphMode.NONE
             and cudagraph_runtime_mode.valid_runtime_modes()
         ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"
 
+        if not batch_descriptors:
+            return
+
+        uniform_decode = batch_descriptors[0].uniform
+        force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
+
+        dummy_run = functools.partial(
+            self._dummy_run,
+            uniform_decode=uniform_decode,
+            skip_eplb=True,
+            remove_lora=False,
+            force_attention=force_attention,
+        )
+
         # Only rank 0 should print progress bar during capture
         if is_global_first_rank():
-            compilation_cases = tqdm(
-                compilation_cases,
+            batch_descriptors = tqdm(
+                batch_descriptors,
                 disable=not self.load_config.use_tqdm_on_load,
                 desc="Capturing CUDA graphs ({}, {})".format(
                     "decode" if uniform_decode else "mixed prefill-decode",
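Pre-binding the per-capture-invariant keyword arguments with functools.partial is what lets the warmup and capture call sites in the next hunks shrink. A small sketch of the pattern, with a hypothetical dummy_run signature:

    import functools

    def dummy_run(num_tokens, *, uniform_decode, skip_eplb, remove_lora,
                  force_attention, cudagraph_runtime_mode="NONE"):
        print(num_tokens, uniform_decode, force_attention, cudagraph_runtime_mode)

    # Bind what never changes across this capture batch; each call then
    # passes only what varies (num_tokens, runtime mode, ...).
    run = functools.partial(dummy_run, uniform_decode=True, skip_eplb=True,
                            remove_lora=False, force_attention=True)
    run(256)                                 # warmup-style call
    run(256, cudagraph_runtime_mode="FULL")  # capture-style call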
@@ -4934,7 +4910,10 @@ class GPUModelRunner(
             )
 
         # We skip EPLB here since we don't want to record dummy metrics
-        for num_tokens, activate_lora in compilation_cases:
+        for batch_desc in batch_descriptors:
+            num_tokens = batch_desc.num_tokens
+            activate_lora = batch_desc.has_lora
+
             # We currently only capture ubatched graphs when its a FULL
             # cudagraph, a uniform decode batch, and the number of tokens
             # is above the threshold. Otherwise we just capture a non-ubatched
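Wrapping the list in tqdm only on rank 0 works because tqdm returns a drop-in iterable, so the loop body above is identical with and without the progress bar. A minimal sketch:

    from tqdm import tqdm

    batch_descriptors = [512, 256, 16]  # stand-in descriptors
    is_first_rank = True                # stand-in for is_global_first_rank()
    if is_first_rank:
        batch_descriptors = tqdm(batch_descriptors,
                                 desc="Capturing CUDA graphs")
    for desc in batch_descriptors:      # same loop either way
        pass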
@@ -4952,28 +4931,22 @@ class GPUModelRunner(
 
             for _ in range(self.compilation_config.cudagraph_num_of_warmups):
                 # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
-                # But be careful, warm up with `NONE`is orthogonal to
+                # But be careful, warm up with `NONE` is orthogonal to
                 # if we want to warm up attention or not. This is
                 # different from the case where `FULL` implies capture
                 # attention while `PIECEWISE` implies no attention.
-                force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
-                self._dummy_run(
+                dummy_run(
                     num_tokens,
                     cudagraph_runtime_mode=CUDAGraphMode.NONE,
-                    force_attention=force_attention,
-                    uniform_decode=uniform_decode,
                     allow_microbatching=allow_microbatching,
-                    skip_eplb=True,
-                    remove_lora=False,
                     activate_lora=activate_lora,
                 )
-            self._dummy_run(
+
+            # Capture run
+            dummy_run(
                 num_tokens,
                 cudagraph_runtime_mode=cudagraph_runtime_mode,
-                uniform_decode=uniform_decode,
                 allow_microbatching=allow_microbatching,
-                skip_eplb=True,
-                remove_lora=False,
                 activate_lora=activate_lora,
                 is_graph_capturing=True,
             )
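The warmup-then-capture structure above follows the standard PyTorch CUDA graph recipe: run the workload a few times outside the graph so allocator state and lazy initialization settle, then record once and replay. A generic sketch with torch.cuda.CUDAGraph on a toy kernel, not vLLM's actual capture path:

    import torch

    assert torch.cuda.is_available()
    static_x = torch.randn(512, 512, device="cuda")
    g = torch.cuda.CUDAGraph()

    # Warmup runs on a side stream, outside the graph.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            y = static_x @ static_x
    torch.cuda.current_stream().wait_stream(s)

    # Capture run: records the kernels once; replay re-launches them cheaply.
    with torch.cuda.graph(g):
        static_y = static_x @ static_x

    # Replay with fresh data written into the static input buffer.
    static_x.copy_(torch.randn(512, 512, device="cuda"))
    g.replay()  # static_y now holds the result for the new input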