[cudagraphs] Refactor cudagraph capture loop (#32946)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -173,6 +173,68 @@ class TestCudagraphDispatcher:
|
||||
else:
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
|
||||
@pytest.mark.parametrize(
    "cudagraph_mode_str,compilation_mode,expected_modes",
    [
        # FULL mode: only FULL keys, no PIECEWISE
        ("FULL", CompilationMode.NONE, [CUDAGraphMode.FULL]),
        # PIECEWISE mode: only PIECEWISE keys
        ("PIECEWISE", CompilationMode.VLLM_COMPILE, [CUDAGraphMode.PIECEWISE]),
        # FULL_DECODE_ONLY: only FULL keys for uniform decode
        ("FULL_DECODE_ONLY", CompilationMode.NONE, [CUDAGraphMode.FULL]),
        # NONE mode: no keys
        ("NONE", CompilationMode.NONE, []),
    ],
)
def test_get_capture_descs(
    self, cudagraph_mode_str, compilation_mode, expected_modes
):
    """Check that get_capture_descs groups, orders, and sorts descriptors.

    For each cudagraph mode the dispatcher must return only the expected
    runtime-mode groups, each group sorted by descending token count and
    internally uniform in its `uniform` flag.
    """
    comp_config = CompilationConfig(
        cudagraph_mode=cudagraph_mode_str,
        mode=compilation_mode,
        cudagraph_capture_sizes=[1, 4, 8, 16],
    )

    vllm_config = _create_vllm_config(comp_config, max_num_seqs=16)
    dispatcher = CudagraphDispatcher(vllm_config)
    dispatcher.initialize_cudagraph_keys(
        cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
    )

    groups = dispatcher.get_capture_descs()

    # The runtime modes must come back exactly as expected, in order.
    assert [mode for mode, _ in groups] == expected_modes

    for mode, descs in groups:
        # Empty groups must never be emitted.
        assert descs, "Each group should have at least one descriptor"

        # Token counts must be non-increasing (largest captured first).
        sizes = [desc.num_tokens for desc in descs]
        assert all(a >= b for a, b in zip(sizes, sizes[1:])), (
            f"Descriptors for {mode} should be sorted largest-first"
        )

        # Every descriptor in a group must share the same uniform flag.
        assert len({desc.uniform for desc in descs}) == 1, (
            "All descriptors in a group should have the same uniform value"
        )
|
||||
|
||||
def test_get_capture_descs_empty_when_not_initialized(self):
    """A dispatcher whose keys were never initialized yields no capture descs."""
    compilation_config = CompilationConfig(
        cudagraph_mode="FULL",
        mode=CompilationMode.NONE,
        cudagraph_capture_sizes=[1, 8],
    )
    vllm_config = _create_vllm_config(compilation_config, max_num_seqs=8)
    uninitialized = CudagraphDispatcher(vllm_config)

    # Deliberately skip initialize_cudagraph_keys().
    assert uninitialized.get_capture_descs() == []
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
|
||||
class TestCUDAGraphWrapper:
|
||||
|
||||
@@ -231,3 +231,26 @@ class CudagraphDispatcher:
|
||||
|
||||
# finally, just return no cudagraphs and a trivial batch descriptor
|
||||
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
|
||||
|
||||
def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]:
    """
    Returns capture descriptors for cudagraph capturing.

    Returns:
        List of (runtime_mode, batch_descriptors) tuples, ordered PIECEWISE
        first then FULL. Batch descriptors are sorted largest-first for
        memory efficiency.
    """
    # Nothing to capture before key initialization or when cudagraphs are off.
    if not self.keys_initialized or self.cudagraph_mode == CUDAGraphMode.NONE:
        return []

    grouped: list[tuple[CUDAGraphMode, list[BatchDescriptor]]] = []
    # Emit PIECEWISE before FULL so piecewise graphs are captured first.
    for runtime_mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL):
        # Largest batches first: capturing big graphs before small ones
        # reduces peak memory during capture.
        descs = sorted(
            self.cudagraph_keys[runtime_mode],
            key=lambda d: d.num_tokens,
            reverse=True,
        )
        # Skip modes that have no registered keys.
        if descs:
            grouped.append((runtime_mode, descs))

    return grouped
|
||||
|
||||
@@ -10,7 +10,6 @@ from collections.abc import Iterator, Sequence
|
||||
from contextlib import contextmanager
|
||||
from copy import copy, deepcopy
|
||||
from functools import reduce
|
||||
from itertools import product
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
|
||||
|
||||
import numpy as np
|
||||
@@ -4839,50 +4838,14 @@ class GPUModelRunner(
|
||||
set_cudagraph_capturing_enabled(True)
|
||||
with freeze_gc(), graph_capture(device=self.device):
|
||||
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
assert cudagraph_mode is not None
|
||||
|
||||
if self.lora_config:
|
||||
if self.compilation_config.cudagraph_specialize_lora:
|
||||
lora_cases = [True, False]
|
||||
else:
|
||||
lora_cases = [True]
|
||||
else:
|
||||
lora_cases = [False]
|
||||
|
||||
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
|
||||
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
|
||||
# make sure we capture the largest batch size first
|
||||
compilation_cases = list(
|
||||
product(reversed(self.cudagraph_batch_sizes), lora_cases)
|
||||
)
|
||||
for (
|
||||
runtime_mode,
|
||||
batch_descs,
|
||||
) in self.cudagraph_dispatcher.get_capture_descs():
|
||||
self._capture_cudagraphs(
|
||||
compilation_cases,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
uniform_decode=False,
|
||||
)
|
||||
|
||||
# Capture full cudagraph for uniform decode batches if we
|
||||
# don't already have full mixed prefill-decode cudagraphs.
|
||||
if (
|
||||
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
|
||||
and cudagraph_mode.separate_routine()
|
||||
):
|
||||
max_num_tokens = (
|
||||
self.scheduler_config.max_num_seqs * self.uniform_decode_query_len
|
||||
)
|
||||
decode_cudagraph_batch_sizes = [
|
||||
x
|
||||
for x in self.cudagraph_batch_sizes
|
||||
if max_num_tokens >= x >= self.uniform_decode_query_len
|
||||
]
|
||||
compilation_cases_decode = list(
|
||||
product(reversed(decode_cudagraph_batch_sizes), lora_cases)
|
||||
)
|
||||
self._capture_cudagraphs(
|
||||
compilation_cases=compilation_cases_decode,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.FULL,
|
||||
uniform_decode=True,
|
||||
batch_descriptors=batch_descs,
|
||||
cudagraph_runtime_mode=runtime_mode,
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
@@ -4913,19 +4876,32 @@ class GPUModelRunner(
|
||||
|
||||
def _capture_cudagraphs(
|
||||
self,
|
||||
compilation_cases: list[tuple[int, bool]],
|
||||
batch_descriptors: list[BatchDescriptor],
|
||||
cudagraph_runtime_mode: CUDAGraphMode,
|
||||
uniform_decode: bool,
|
||||
):
|
||||
assert (
|
||||
cudagraph_runtime_mode != CUDAGraphMode.NONE
|
||||
and cudagraph_runtime_mode.valid_runtime_modes()
|
||||
), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"
|
||||
|
||||
if not batch_descriptors:
|
||||
return
|
||||
|
||||
uniform_decode = batch_descriptors[0].uniform
|
||||
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
|
||||
|
||||
dummy_run = functools.partial(
|
||||
self._dummy_run,
|
||||
uniform_decode=uniform_decode,
|
||||
skip_eplb=True,
|
||||
remove_lora=False,
|
||||
force_attention=force_attention,
|
||||
)
|
||||
|
||||
# Only rank 0 should print progress bar during capture
|
||||
if is_global_first_rank():
|
||||
compilation_cases = tqdm(
|
||||
compilation_cases,
|
||||
batch_descriptors = tqdm(
|
||||
batch_descriptors,
|
||||
disable=not self.load_config.use_tqdm_on_load,
|
||||
desc="Capturing CUDA graphs ({}, {})".format(
|
||||
"decode" if uniform_decode else "mixed prefill-decode",
|
||||
@@ -4934,7 +4910,10 @@ class GPUModelRunner(
|
||||
)
|
||||
|
||||
# We skip EPLB here since we don't want to record dummy metrics
|
||||
for num_tokens, activate_lora in compilation_cases:
|
||||
for batch_desc in batch_descriptors:
|
||||
num_tokens = batch_desc.num_tokens
|
||||
activate_lora = batch_desc.has_lora
|
||||
|
||||
# We currently only capture ubatched graphs when its a FULL
|
||||
# cudagraph, a uniform decode batch, and the number of tokens
|
||||
# is above the threshold. Otherwise we just capture a non-ubatched
|
||||
@@ -4952,28 +4931,22 @@ class GPUModelRunner(
|
||||
|
||||
for _ in range(self.compilation_config.cudagraph_num_of_warmups):
|
||||
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
|
||||
# But be careful, warm up with `NONE`is orthogonal to
|
||||
# But be careful, warm up with `NONE` is orthogonal to
|
||||
# if we want to warm up attention or not. This is
|
||||
# different from the case where `FULL` implies capture
|
||||
# attention while `PIECEWISE` implies no attention.
|
||||
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
|
||||
self._dummy_run(
|
||||
dummy_run(
|
||||
num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
force_attention=force_attention,
|
||||
uniform_decode=uniform_decode,
|
||||
allow_microbatching=allow_microbatching,
|
||||
skip_eplb=True,
|
||||
remove_lora=False,
|
||||
activate_lora=activate_lora,
|
||||
)
|
||||
self._dummy_run(
|
||||
|
||||
# Capture run
|
||||
dummy_run(
|
||||
num_tokens,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
uniform_decode=uniform_decode,
|
||||
allow_microbatching=allow_microbatching,
|
||||
skip_eplb=True,
|
||||
remove_lora=False,
|
||||
activate_lora=activate_lora,
|
||||
is_graph_capturing=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user