[MRV2] Extensible CG dispatch rework (#35959)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2026-03-09 16:58:45 -04:00
committed by GitHub
parent 4e571ce643
commit 483463f735
9 changed files with 561 additions and 652 deletions

View File

@@ -97,6 +97,9 @@ class CUDAGraphMode(enum.Enum):
def __str__(self) -> str:
return self.name
def __bool__(self) -> bool:
return self != CUDAGraphMode.NONE
@config
class PassConfig:

View File

@@ -104,19 +104,24 @@ class BlockTables:
self.num_blocks.copy_to_uva()
def gather_block_tables(
self, idx_mapping: torch.Tensor
self,
idx_mapping: torch.Tensor,
num_reqs_padded: int,
) -> tuple[torch.Tensor, ...]:
num_reqs = idx_mapping.shape[0]
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
# Launch kernel with num_reqs_padded to fuse zeroing of padded rows.
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs_padded)](
idx_mapping,
self.block_table_ptrs,
self.input_block_table_ptrs,
self.block_table_strides,
self.num_blocks.gpu,
self.num_blocks.gpu.stride(0),
num_reqs,
self.input_block_tables[0].shape[1], # max_num_blocks
BLOCK_SIZE=1024, # type: ignore
)
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
return tuple(bt[:num_reqs_padded] for bt in self.input_block_tables)
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
# NOTE(woosuk): The output may be used for CUDA graph capture.
@@ -130,6 +135,7 @@ class BlockTables:
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
num_tokens_padded: int,
) -> torch.Tensor:
num_reqs = idx_mapping.shape[0]
num_tokens = positions.shape[0]
@@ -151,7 +157,7 @@ class BlockTables:
PAD_ID=PAD_SLOT_ID,
TRITON_BLOCK_SIZE=1024, # type: ignore
)
return self.slot_mappings[:, :num_tokens]
return self.slot_mappings[:, :num_tokens_padded]
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
# Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
@@ -173,21 +179,31 @@ def _gather_block_tables_kernel(
block_table_strides, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
num_reqs, # actual number of requests (for padding)
max_num_blocks, # stride for zeroing padded rows
BLOCK_SIZE: tl.constexpr,
):
# kv cache group id
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
stride = tl.load(block_table_strides + group_id)
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
if batch_idx >= num_reqs:
# Zero out padded rows.
for i in tl.range(0, max_num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
tl.store(dst_row_ptr + offset, 0, mask=offset < max_num_blocks)
return
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
stride = tl.load(block_table_strides + group_id)
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
src_row_ptr = src_block_table_ptr + req_idx * stride
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
for i in tl.range(0, num_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
import torch
@@ -11,235 +13,260 @@ from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader
from vllm.utils.math_utils import cdiv
from vllm.platforms import current_platform
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup
logger = init_logger(__name__)
@dataclass(frozen=True)
class BatchExecutionDescriptor:
    """Describes the shape of the batch and the CG mode used to run it.

    Frozen (hence hashable) so instances can be used as keys into the
    captured-graph table; capture-time and runtime shapes are matched by
    equality/compatibility against these descriptors.
    """

    # CUDA graph mode this batch runs under (NONE / PIECEWISE / FULL).
    cg_mode: CUDAGraphMode
    # Padded token count of the batch.
    num_tokens: int
    # Padded request count. None means no request padding is needed
    # (PIECEWISE graphs).
    num_reqs: int | None
    # Tokens per request when every request has the same count; None for
    # non-uniform batches (PIECEWISE descriptors match any value).
    uniform_token_count: int | None = None
def _is_compatible(
    desc: BatchExecutionDescriptor,
    num_reqs: int,
    num_tokens: int,
    uniform_token_count: int | None,
) -> bool:
    """Return True if captured descriptor `desc` can serve the given batch.

    Rules:
    - desc.uniform_token_count=None (PIECEWISE) matches any
      uniform_token_count; otherwise the counts must be equal.
    - desc.num_reqs=None means no request padding is needed (PIECEWISE);
      otherwise the captured request count must cover num_reqs.
    - The captured token count must cover num_tokens.
    """
    # Guard clauses: bail out on the first dimension that cannot fit.
    if (
        desc.uniform_token_count is not None
        and desc.uniform_token_count != uniform_token_count
    ):
        return False
    if desc.num_reqs is not None and desc.num_reqs < num_reqs:
        return False
    return desc.num_tokens >= num_tokens
def get_uniform_token_count(
num_reqs: int,
num_tokens: int,
max_query_len: int,
) -> int | None:
"""
Return the uniform token count if batch is uniform, else None.
A batch is uniform if all requests have the same number of tokens.
"""
if (max_query_len == num_tokens // num_reqs) and (
num_tokens == max_query_len * num_reqs
):
return max_query_len
return None
class CudaGraphManager:
def __init__(
self,
vllm_config: VllmConfig,
use_aux_hidden_state_outputs: bool,
device: torch.device,
cudagraph_mode: CUDAGraphMode,
decode_query_len: int,
):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
self.device = device
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.uniform_decode_query_len = 1
spec_config = vllm_config.speculative_config
if spec_config is not None:
self.uniform_decode_query_len += spec_config.num_speculative_tokens
self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.cudagraph_mode = cudagraph_mode
self.decode_query_len = decode_query_len
self.dp_size = vllm_config.parallel_config.data_parallel_size
use_uniform_decode_cudagraph = (
self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.cudagraph_mode.separate_routine()
)
self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
self.uniform_decode_query_len,
use_uniform_decode_cudagraph,
)
self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = None
if self.cudagraph_mode != CUDAGraphMode.NONE:
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None
self.aux_hidden_states: list[torch.Tensor] = []
self._graphs_captured = False
self._candidates: list[list[BatchExecutionDescriptor]] = []
self._capture_descs: dict[CUDAGraphMode, list[BatchExecutionDescriptor]] = {}
self._init_candidates()
def _init_candidates(self) -> None:
    """Build priority-ordered candidate lists for each token count.

    Populates:
    - self._candidates: indexed by num_tokens; entry t holds the descriptors
      of the smallest capture size >= t (decode-specific first, then mixed).
    - self._capture_descs: per-mode descriptor lists, sorted largest-first
      for capture.
    """
    capture_sizes = self.compilation_config.cudagraph_capture_sizes
    # Nothing to do when cudagraphs are disabled or no sizes configured.
    if not (self.cudagraph_mode and capture_sizes):
        return
    capture_sizes = sorted(capture_sizes)
    max_decode_tokens = self.max_num_reqs * self.decode_query_len
    decode_mode = self.cudagraph_mode.decode_mode()
    mixed_mode = self.cudagraph_mode.mixed_mode()
    separate_decode_routine = self.cudagraph_mode.separate_routine()
    descs_by_token_count = defaultdict(list)
    descs_by_mode = defaultdict(list)
    for num_tokens in capture_sizes:
        # Capture uniform decode specific graphs if required
        # (i.e. separate decode routine)
        if (
            separate_decode_routine
            and decode_mode
            and self.decode_query_len <= num_tokens <= max_decode_tokens
        ):
            desc = BatchExecutionDescriptor(
                cg_mode=decode_mode,
                num_tokens=num_tokens,
                # Uniform decode: each request contributes decode_query_len
                # tokens, so the request count is fixed by num_tokens.
                num_reqs=num_tokens // self.decode_query_len,
                uniform_token_count=self.decode_query_len,
            )
            descs_by_mode[decode_mode].append(desc)
            descs_by_token_count[num_tokens].append(desc)
        if mixed_mode:
            # for PIECEWISE graphs there is no limit on requests when replaying
            # i.e. no request padding is needed
            # so we leave it as None
            num_reqs = (
                min(num_tokens, self.max_num_reqs)
                if mixed_mode == CUDAGraphMode.FULL
                else None
            )
            desc = BatchExecutionDescriptor(
                cg_mode=mixed_mode,
                num_tokens=num_tokens,
                num_reqs=num_reqs,
            )
            descs_by_mode[mixed_mode].append(desc)
            descs_by_token_count[num_tokens].append(desc)
    if not descs_by_token_count:
        return
    # Fill self._candidates so every token count t in [0, largest size]
    # maps to the descriptors of the smallest capture size >= t.
    sorted_padded = sorted(descs_by_token_count.keys())
    self._candidates = [[] for _ in range(sorted_padded[-1] + 1)]
    current_range_start = 0
    for cg_size in sorted_padded:
        for i in range(current_range_start, cg_size + 1):
            self._candidates[i] = descs_by_token_count[cg_size]
        current_range_start = cg_size + 1
    # Capture order within a mode: largest graphs first.
    for mode, descs in descs_by_mode.items():
        descs.sort(key=lambda d: d.num_tokens, reverse=True)
        self._capture_descs[mode] = descs
def needs_capture(self) -> bool:
    # True when at least one padded cudagraph size is configured for capture.
    return len(self.cudagraph_sizes) > 0
def get_cudagraph_size(
    self, num_tokens: int, uniform_decode: bool = False
) -> int | None:
    """Return the padded cudagraph size covering `num_tokens`, or None.

    Uniform-decode batches consult the decode-specific table when one was
    built; otherwise the general (mixed) table is used.
    """
    table = (
        self.uniform_decode_cudagraph_sizes
        if uniform_decode and self.uniform_decode_cudagraph_sizes
        else self.cudagraph_sizes
    )
    return table.get(num_tokens)
def capture_graph(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
model: nn.Module,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
has_lora: bool = False,
uniform_decode: bool = False,
) -> None:
# select and check capture function
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
# prepare inputs
if uniform_decode:
num_reqs = min(
cdiv(num_tokens, self.uniform_decode_query_len),
self.max_num_reqs,
)
else:
num_reqs = min(num_tokens, self.max_num_reqs)
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
# NOTE: Values returned by `prepare_dummy_inputs` will override the
# default values above.
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
}
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
model_state,
input_buffers,
block_tables,
attn_groups,
kv_cache_config,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
):
model_output = model(**model_inputs)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Allocate output buffers if not already done.
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states]
capture_fn(
num_tokens=num_tokens,
num_reqs=num_reqs,
model=model,
model_inputs=model_inputs,
num_tokens_across_dp=num_tokens_across_dp,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
has_lora=has_lora,
)
def _capture_full_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
model_inputs: dict[str, torch.Tensor | None],
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
assert attn_metadata is not None
# Capture the graph.
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
with (
set_forward_context(
attn_metadata=attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
),
torch.cuda.graph(graph, self.pool),
):
model_output = model(**model_inputs)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader().join_after_forward()
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
# Copy outputs to the output buffers.
assert self.hidden_states is not None
self.hidden_states[:num_tokens] = hidden_states
if self.use_aux_hidden_state_outputs:
for i, aux_hidden in enumerate(aux_hidden_states):
self.aux_hidden_states[i][:num_tokens] = aux_hidden
self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
    self,
    num_tokens: int,
    num_reqs: int,
    model: nn.Module,
    model_inputs: dict[str, torch.Tensor | None],
    num_tokens_across_dp: torch.Tensor,
    attn_metadata: dict[str, Any] | None,
    slot_mappings: dict[str, torch.Tensor] | None,
    has_lora: bool = False,
) -> None:
    """Capture PIECEWISE cudagraphs by running the model once under the
    PIECEWISE runtime mode.

    No explicit torch.cuda.CUDAGraph is created here; the CUDAGraphWrapper
    inside torch.compile performs the capture. `num_reqs` and `attn_metadata`
    are unused but kept for signature parity with _capture_full_graph so the
    caller can dispatch to either uniformly.
    """
    # create batch descriptor for piecewise cudagraph dispatch key
    batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)
    # Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
    with set_forward_context(
        attn_metadata=None,  # piecewise no need attn_metadata
        vllm_config=self.vllm_config,
        num_tokens=num_tokens,
        cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
        num_tokens_across_dp=num_tokens_across_dp,
        batch_descriptor=batch_descriptor,
        slot_mapping=slot_mappings,
    ):
        model(**model_inputs)
return len(self._capture_descs) > 0
@torch.inference_mode()
def capture(
    self,
    create_forward_fn: Callable[
        [BatchExecutionDescriptor], Callable[[CUDAGraphMode], None]
    ],
    progress_bar_desc: str = "Capturing CUDA graphs",
) -> None:
    """Capture CUDA graphs.

    Args:
        create_forward_fn: Factory that prepares inputs (OUTSIDE graph) and
            returns a function that runs forward with a given CUDAGraphMode.
        progress_bar_desc: Label prefix for the per-mode progress bar.
    """
    with graph_capture(device=self.device):
        # Capture in order: PIECEWISE first, then FULL. PIECEWISE has larger
        # activations so FULL activations should fit in already allocated
        # buffers in the graph pool.
        for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
            if mode not in self._capture_descs:
                continue
            descs = self._capture_descs[mode]
            if is_global_first_rank():
                # Progress bar only on the first rank to avoid log spam.
                descs = tqdm(descs, desc=f"{progress_bar_desc} ({mode.name})")
            for desc in descs:
                # Prepare inputs and get forward function
                forward_fn = create_forward_fn(desc)
                # Warmup
                forward_fn(CUDAGraphMode.NONE)
                # Capture
                logger.debug(
                    "CG Capture: mode=%s, batch_desc=%s", desc.cg_mode.name, desc
                )
                if desc.cg_mode == CUDAGraphMode.PIECEWISE:
                    # PIECEWISE: CUDAGraphWrapper inside torch.compile
                    # performs the capture itself.
                    forward_fn(CUDAGraphMode.PIECEWISE)
                else:
                    assert desc not in self.graphs, (
                        f"Graph already captured for {desc}"
                    )
                    graph = torch.cuda.CUDAGraph()
                    # Sync offloader's copy stream before capture.
                    # Ensure any pre-capture prefetches from offloader are complete.
                    get_offloader().sync_prev_onload()
                    with torch.cuda.graph(graph, self.pool):
                        forward_fn(CUDAGraphMode.NONE)
                    # Join offloader's copy stream after forward to avoid
                    # unjoined stream error. The last layer's start_prefetch
                    # forks copy_stream, but wait_prefetch only happens in
                    # the next forward pass.
                    get_offloader().join_after_forward()
                    self.graphs[desc] = graph
    self._graphs_captured = True
def dispatch(
    self,
    num_reqs: int,
    num_tokens: int,
    uniform_token_count: int | None,
) -> BatchExecutionDescriptor:
    """Find matching cudagraph descriptor from priority-ordered candidates."""
    in_range = 0 < num_tokens < len(self._candidates)
    if self._graphs_captured and in_range:
        # Candidates for this token count are pre-sorted by priority;
        # take the first one that can serve this batch.
        for candidate in self._candidates[num_tokens]:
            if _is_compatible(candidate, num_reqs, num_tokens, uniform_token_count):
                return candidate
    # No captured graph fits: run eagerly at the unpadded shape.
    return BatchExecutionDescriptor(
        cg_mode=CUDAGraphMode.NONE, num_tokens=num_tokens, num_reqs=num_reqs
    )
def run_fullgraph(self, desc: BatchExecutionDescriptor):
    """Replay a captured FULL cudagraph."""
    assert desc.cg_mode == CUDAGraphMode.FULL, (
        f"Expected FULL mode, got {desc.cg_mode}"
    )
    graph = self.graphs.get(desc)
    assert graph is not None, f"No cudagraph for {desc}"
    # Sync offloader before replay - needed when transitioning from
    # eager/piecewise to full cudagraph (e.g., prefill → decode).
    # The previous eager iteration's start_prefetch may have queued
    # H2D copies on copy_stream that the graph's captured events
    # cannot see. Without this, replay could overwrite static buffers
    # while those copies are still in flight.
    get_offloader().sync_prev_onload()
    graph.replay()
class ModelCudaGraphManager(CudaGraphManager):
"""CudaGraphManager with model-specific capture and hidden state management."""
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
cudagraph_mode: CUDAGraphMode,
decode_query_len: int,
):
super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)
self.hidden_states: torch.Tensor | None = None
self.aux_hidden_states: list[torch.Tensor] = []
self.use_aux_hidden_state_outputs = False
def capture(
self,
model: nn.Module,
@@ -249,139 +276,81 @@ class CudaGraphManager:
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
has_lora: bool = False,
use_aux_hidden_state_outputs: bool = False,
progress_bar_desc: str = "Capturing CUDA graphs",
) -> None:
common_kwargs = dict(
device=self.device,
capture_fn=self.capture_graph,
model=model,
model_state=model_state,
input_buffers=input_buffers,
block_tables=block_tables,
attn_groups=attn_groups,
kv_cache_config=kv_cache_config,
has_lora=has_lora,
)
"""Capture CUDA graphs for model forward pass."""
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
# Phase 1: Capture for mixed prefill-decode batches if needed.
mixed_mode = self.cudagraph_mode.mixed_mode()
if mixed_mode != CUDAGraphMode.NONE:
capture_graphs(
cudagraph_sizes=self.cudagraph_sizes,
capture_cudagraph_mode=mixed_mode,
desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
uniform_decode=False,
**common_kwargs,
def create_forward_fn(
desc: BatchExecutionDescriptor,
) -> Callable[[CUDAGraphMode], None]:
num_tokens = desc.num_tokens
num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
num_tokens_across_dp = (
torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
if self.dp_size > 1
else None
)
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
model_state,
input_buffers,
block_tables,
attn_groups,
kv_cache_config,
)
# Phase 2: Capture FULL graphs for uniform decode batches if needed.
# This is only needed if we use a separate routine for decode batches
# and the decode_mode is FULL.
if self.uniform_decode_cudagraph_sizes:
capture_graphs(
cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
capture_cudagraph_mode=CUDAGraphMode.FULL,
desc="Capturing CUDA graphs (decode, FULL)",
uniform_decode=True,
**common_kwargs,
)
def forward_fn(cg_mode: CUDAGraphMode) -> None:
batch_descriptor = (
BatchDescriptor(num_tokens=num_tokens)
if cg_mode == CUDAGraphMode.PIECEWISE
else None
)
with set_forward_context(
attn_metadata if cg_mode != CUDAGraphMode.PIECEWISE else None,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=cg_mode,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
batch_descriptor=batch_descriptor,
):
model_inputs = {
"input_ids": input_buffers.input_ids[:num_tokens],
"positions": input_buffers.positions[:num_tokens],
}
model_output = model(**model_inputs)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = []
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
self.aux_hidden_states = [
torch.empty_like(x) for x in aux_hidden_states
]
self.hidden_states[:num_tokens] = hidden_states
for i, aux in enumerate(aux_hidden_states):
self.aux_hidden_states[i][:num_tokens] = aux
def get_cudagraph_runtime_mode(
self, num_reqs: int, num_tokens: int, max_query_len: int
) -> tuple[CUDAGraphMode, int | None]:
is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_tokens == max_query_len * num_reqs
)
return forward_fn
cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
elif is_uniform_decode:
cudagraph_mode = self.cudagraph_mode.decode_mode()
else:
cudagraph_mode = self.cudagraph_mode.mixed_mode()
if (
cudagraph_mode == CUDAGraphMode.FULL
and cudagraph_size is not None
and cudagraph_size not in self.graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode = CUDAGraphMode.NONE
cudagraph_size = None
return cudagraph_mode, cudagraph_size
super().capture(create_forward_fn, progress_bar_desc)
def run_fullgraph(
self, num_tokens: int
self, desc: BatchExecutionDescriptor
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
# Sync offloader before replay - needed when transitioning from
# eager/piecewise to full cudagraph (e.g., prefill → decode).
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader().sync_prev_onload()
self.graphs[num_tokens].replay()
"""Replay a captured FULL cudagraph and return hidden states."""
super().run_fullgraph(desc)
assert self.hidden_states is not None
hidden_states = self.hidden_states[:num_tokens]
hidden_states = self.hidden_states[: desc.num_tokens]
if not self.use_aux_hidden_state_outputs:
return hidden_states
return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states]
def get_cudagraph_sizes(
    capture_sizes: list[int] | None,
    max_num_reqs: int,
    max_num_tokens: int,
    cudagraph_mode: CUDAGraphMode,
    uniform_decode_query_len: int = 1,
    uniform_decode_cudagraph: bool = False,
) -> tuple[dict[int, int], dict[int, int]]:
    """Build lookup tables mapping a token count to its padded cudagraph size.

    Returns (cudagraph_sizes, uniform_decode_cudagraph_sizes). Each maps
    num_tokens -> smallest capture size >= num_tokens. The second dict is the
    subset usable by the separate uniform-decode routine; it is empty unless
    uniform_decode_cudagraph is True.
    """
    # Support both FULL and PIECEWISE cudagraph modes.
    # (Single emptiness check; the original duplicated it around sorted().)
    if cudagraph_mode == CUDAGraphMode.NONE or not capture_sizes:
        return {}, {}
    capture_sizes = sorted(capture_sizes)
    # Walk token counts and capture sizes in lockstep: O(max_size) total
    # instead of rescanning capture_sizes for every token count.
    cudagraph_sizes: dict[int, int] = {}
    idx = 0
    for num_tokens in range(1, capture_sizes[-1] + 1):
        while num_tokens > capture_sizes[idx]:
            idx += 1
        cudagraph_sizes[num_tokens] = capture_sizes[idx]
    uniform_decode_cudagraph_sizes: dict[int, int] = {}
    if uniform_decode_cudagraph:
        # Uniform decode batches never exceed max_num_reqs * query_len tokens.
        max_num_tokens = max_num_reqs * uniform_decode_query_len
        uniform_decode_cudagraph_sizes = {
            k: v
            for k, v in cudagraph_sizes.items()
            if uniform_decode_query_len <= v <= max_num_tokens
        }
    return cudagraph_sizes, uniform_decode_cudagraph_sizes
def capture_graphs(
    cudagraph_sizes: dict[int, int],
    device: torch.device,
    capture_fn: Callable,
    capture_cudagraph_mode: CUDAGraphMode,
    desc: str = "Capturing CUDA graphs",
    **capture_kwargs,
) -> None:
    """Capture one CUDA graph per distinct padded size in `cudagraph_sizes`.

    Extra keyword arguments are forwarded to `capture_fn` unchanged.
    """
    # Capture larger graphs first.
    unique_sizes = sorted(set(cudagraph_sizes.values()), reverse=True)
    progress = (
        tqdm(unique_sizes, desc=desc) if is_global_first_rank() else unique_sizes
    )
    with graph_capture(device=device):
        for num_tokens in progress:
            capture_fn(num_tokens, capture_cudagraph_mode, **capture_kwargs)
return hidden_states, [x[: desc.num_tokens] for x in self.aux_hidden_states]
def prepare_inputs_to_capture(

View File

@@ -1,9 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import torch
import torch.distributed as dist
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import get_dp_group
from vllm.v1.worker.gpu.cudagraph_utils import (
BatchExecutionDescriptor,
CudaGraphManager,
)
def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None:
@@ -12,66 +19,63 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
def get_batch_metadata_across_dp(
def sync_cudagraph_and_dp_padding(
cudagraph_manager: CudaGraphManager,
desired_batch_desc: BatchExecutionDescriptor,
num_tokens: int,
cudagraph_size: int,
cudagraph_runtime_mode: int,
num_reqs: int,
uniform_token_count: int | None,
dp_size: int,
dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert dp_size > 1
# Use CPU group to avoid CPU-GPU synchronization.
) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]:
"""
Coordinates the batch descriptor and DP padding across all ranks.
Returns (synced_batch_desc, num_tokens_across_dp).
"""
assert dp_size > 1, "DP size must be greater than 1"
group = get_dp_group().cpu_group
tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
tensor[0][dp_rank] = num_tokens
tensor[1][dp_rank] = cudagraph_size
tensor[2][dp_rank] = cudagraph_runtime_mode
tensor[1][dp_rank] = desired_batch_desc.cg_mode.value
tensor[2][dp_rank] = uniform_token_count or 0 # (0 means None)
dist.all_reduce(tensor, group=group)
return tensor[0], tensor[1], tensor[2]
num_tokens_across_dp = tensor[0]
cg_mode_across_dp = tensor[1]
uniform_token_counts_across_dp = tensor[2]
def get_cudagraph_and_dp_padding(
num_tokens: int,
cudagraph_size: int | None,
cudagraph_runtime_mode: int,
dp_size: int,
dp_rank: int,
) -> tuple[int, torch.Tensor | None, int]:
if dp_size == 1:
if cudagraph_size is not None:
return cudagraph_size, None, cudagraph_runtime_mode
else:
return num_tokens, None, cudagraph_runtime_mode
# Convert None to -1 for sync (indicates no cudagraph available)
if num_tokens == 0:
cudagraph_size = 0
elif cudagraph_size is None:
cudagraph_size = -1
num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = (
get_batch_metadata_across_dp(
num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank
)
)
if torch.all(num_tokens_across_dp == 0).item():
# All ranks have zero tokens to run.
return 0, None, 0
synced_desc = BatchExecutionDescriptor(
cg_mode=CUDAGraphMode.NONE, num_tokens=0, num_reqs=0
)
return synced_desc, None
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item())
# Check if all ranks have valid cudagraph_size.
all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
synced_cg_mode = CUDAGraphMode(int(cg_mode_across_dp.min().item()))
if synced_cudagraph_mode != 0 and all_have_cudagraph:
# All ranks use cudagraph. Pad to max cudagraph_size.
max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
num_tokens_across_dp[:] = max_cudagraph_size
return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
else:
# Fall back to eager mode (no cudagraph).
# Either some rank doesn't have cudagraph size or mode is NONE.
synced_cudagraph_mode = 0
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode
# If any rank wants to run eager, all ranks run eager
if synced_cg_mode == CUDAGraphMode.NONE:
return BatchExecutionDescriptor(
cg_mode=CUDAGraphMode.NONE,
num_tokens=num_tokens,
num_reqs=num_reqs,
), num_tokens_across_dp
synced_num_tokens = int(num_tokens_across_dp.max().item())
synced_uniform_token_count = uniform_token_counts_across_dp[0]
# If ranks disagree on the uniform token count, or it is 0 (meaning None), set it to None
if synced_uniform_token_count == 0 or not torch.all(
uniform_token_counts_across_dp == synced_uniform_token_count
):
synced_uniform_token_count = None
# Dispatch for the final synced values, use num_reqs instead of synced_num_reqs
# so we don't perform request padding for PIECEWISE graphs
synced_desc = cudagraph_manager.dispatch(
num_reqs, synced_num_tokens, synced_uniform_token_count
)
# Update num_tokens_across_dp to reflect padded size.
num_tokens_across_dp[:] = synced_desc.num_tokens
return synced_desc, num_tokens_across_dp

View File

@@ -37,6 +37,7 @@ class InputBatch:
# batch_idx -> req_id
req_ids: list[str]
num_reqs: int
num_reqs_after_padding: int
# batch_idx -> req_state_idx
idx_mapping: torch.Tensor
@@ -123,6 +124,7 @@ class InputBatch:
return cls(
req_ids=req_ids,
num_reqs=num_reqs,
num_reqs_after_padding=num_reqs,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping,
@@ -330,7 +332,8 @@ def combine_sampled_and_draft_tokens(
cu_num_logits: torch.Tensor,
num_logits: int,
) -> torch.Tensor:
num_reqs = seq_lens.shape[0]
# use idx_mapping.shape[0] for actual request count
num_reqs = idx_mapping.shape[0]
num_speculative_steps = draft_tokens.shape[-1]
logits_indices = torch.empty(

View File

@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -57,8 +56,12 @@ from vllm.v1.worker.gpu.attn_utils import (
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.cudagraph_utils import (
BatchExecutionDescriptor,
ModelCudaGraphManager,
get_uniform_token_count,
)
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
@@ -137,6 +140,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.is_first_pp_rank = True
self.is_last_pp_rank = True
# Data parallelism.
self.dp_size = self.parallel_config.data_parallel_size
self.dp_rank = self.parallel_config.data_parallel_rank
# Decode context parallelism.
self.dcp_size = self.parallel_config.decode_context_parallel_size
self.use_dcp = self.dcp_size > 1
@@ -193,10 +200,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(
self.decode_query_len = self.num_speculative_steps + 1
self.cudagraph_manager = ModelCudaGraphManager(
self.vllm_config,
self.use_aux_hidden_state_outputs,
self.device,
self.compilation_config.cudagraph_mode,
decode_query_len=self.decode_query_len,
)
# Structured outputs worker.
self.structured_outputs_worker = StructuredOutputsWorker(
@@ -331,17 +340,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
**kwargs,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
# Create a dummy scheduler output.
num_reqs = min(num_tokens, self.max_num_reqs)
if uniform_decode:
# Align tokens to uniform_decode_query_len for cudagraph
# compatibility across DP ranks.
query_len = self.cudagraph_manager.uniform_decode_query_len
num_reqs = min(cdiv(num_tokens, query_len), self.max_num_reqs)
num_tokens = num_reqs * query_len
num_tokens_per_request = [query_len] * num_reqs
else:
num_reqs = min(num_tokens, self.max_num_reqs)
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
num_tokens_per_request[-1] += num_tokens % num_reqs
# HACK(lucas): for now since the worker is shared between MRV1 and MRV2,
# and for spec-decode with MTP we want to make sure the dummy runs use
# 1+num_speculative_tokens we use max here, this will likely be eventually
# changed in the worker: https://github.com/vllm-project/vllm/pull/35243
num_tokens = max(num_tokens, self.decode_query_len)
num_reqs = num_tokens // self.decode_query_len
assert num_tokens % self.decode_query_len == 0
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
num_tokens_per_request[-1] += num_tokens % num_reqs
assert sum(num_tokens_per_request) == num_tokens
num_scheduled_tokens = {
f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
@@ -498,13 +508,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with self.maybe_setup_dummy_loras(self.lora_config):
self.cudagraph_manager.capture(
model=self.model,
model_state=self.model_state,
input_buffers=self.input_buffers,
block_tables=self.block_tables,
attn_groups=self.attn_groups,
kv_cache_config=self.kv_cache_config,
self.model,
self.model_state,
self.input_buffers,
self.block_tables,
self.attn_groups,
self.kv_cache_config,
has_lora=self.lora_config is not None,
use_aux_hidden_state_outputs=self.use_aux_hidden_state_outputs,
)
if self.speculator is not None:
self.speculator.capture_model()
@@ -592,9 +603,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
def prepare_inputs(
self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int
self, scheduler_output: SchedulerOutput, batch_desc: BatchExecutionDescriptor
) -> InputBatch:
num_tokens = scheduler_output.total_num_scheduled_tokens
num_tokens_after_padding = batch_desc.num_tokens
assert num_tokens > 0
num_tokens_per_req = scheduler_output.num_scheduled_tokens
num_reqs = len(num_tokens_per_req)
@@ -644,6 +656,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
# Get query_start_loc.
# num_reqs_padded is None for PIECEWISE graphs (no request padding needed)
num_reqs_padded = batch_desc.num_reqs or num_reqs
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
@@ -651,8 +665,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np[num_reqs + 1 :] = num_tokens
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs_padded + 1]
# Get prefill tokens if any.
if self.req_states.any_prefills(idx_mapping_np):
@@ -674,7 +688,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_buffers.positions,
self.input_buffers.seq_lens,
)
seq_lens = self.input_buffers.seq_lens[:num_reqs]
seq_lens = self.input_buffers.seq_lens[:num_reqs_padded]
dcp_local_seq_lens = None
if self.use_dcp:
@@ -687,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.dcp_rank,
self.cp_interleave,
)
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs_padded]
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
@@ -706,6 +720,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
num_reqs_after_padding=num_reqs_padded,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping,
@@ -729,13 +744,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def prepare_attn(
self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
# Block tables: num_kv_cache_groups x [num_reqs_padded, max_num_blocks].
block_tables = self.block_tables.gather_block_tables(
input_batch.idx_mapping,
num_reqs_padded=input_batch.num_reqs_after_padding,
)
# Slot mappings: [num_kv_cache_groups, num_tokens_padded].
# Kernel pads beyond num_tokens with PAD_SLOT_ID.
slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.idx_mapping,
input_batch.query_start_loc,
input_batch.positions,
num_tokens_padded=input_batch.num_tokens_after_padding,
)
return block_tables, slot_mappings
@@ -851,27 +871,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
empty_output = self.kv_connector.no_forward(scheduler_output)
return empty_output
# Get local cudagraph mode and size.
local_cudagraph_mode, local_cudagraph_size = (
self.cudagraph_manager.get_cudagraph_runtime_mode(
num_reqs=len(scheduler_output.num_scheduled_tokens),
num_tokens=scheduler_output.total_num_scheduled_tokens,
max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
)
)
# Get batch descriptor and sync across DP ranks.
num_reqs = len(scheduler_output.num_scheduled_tokens)
num_toks = scheduler_output.total_num_scheduled_tokens
max_query_len = max(scheduler_output.num_scheduled_tokens.values())
uniform_tok_count = get_uniform_token_count(num_reqs, num_toks, max_query_len)
# DP sync: num_tokens + cudagraph_size + cudagraph_mode
num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
get_cudagraph_and_dp_padding(
scheduler_output.total_num_scheduled_tokens,
local_cudagraph_size,
local_cudagraph_mode.value,
self.parallel_config.data_parallel_size,
self.parallel_config.data_parallel_rank,
)
batch_desc = self.cudagraph_manager.dispatch(
num_reqs, num_toks, uniform_tok_count
)
cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
if num_tokens_after_padding == 0:
num_tokens_across_dp = None
if self.dp_size > 1:
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
self.cudagraph_manager,
batch_desc,
num_toks,
num_reqs,
uniform_tok_count,
self.dp_size,
self.dp_rank,
)
if batch_desc.num_tokens == 0:
# All DP ranks have zero tokens to run.
empty_output = self.kv_connector.no_forward(scheduler_output)
return empty_output
@@ -879,9 +901,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not dummy_run:
# Common case.
# Prepare all the inputs and copy to the input buffers.
input_batch = self.prepare_inputs(
scheduler_output, num_tokens_after_padding
)
input_batch = self.prepare_inputs(scheduler_output, batch_desc)
block_tables, slot_mappings = self.prepare_attn(input_batch)
if self.lora_config:
@@ -894,9 +914,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self._set_active_loras(*lora_inputs)
else:
# No actual tokens to run. A dummy run for DP or memory profiling.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs, num_tokens_after_padding, self.input_buffers
batch_desc.num_reqs or num_reqs,
batch_desc.num_tokens,
self.input_buffers,
)
if not skip_attn_for_dummy_run:
block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
@@ -948,14 +969,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
model_inputs["intermediate_tensors"] = intermediate_tensors
# Run model.
if cudagraph_runtime_mode == CUDAGraphMode.FULL:
if batch_desc.cg_mode == CUDAGraphMode.FULL:
# Use explicit cudagraph replay for FULL mode.
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
self.kv_connector.pre_forward(scheduler_output)
model_output = self.cudagraph_manager.run_fullgraph(
input_batch.num_tokens_after_padding
)
model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
@@ -972,7 +991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata,
self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding,
cudagraph_runtime_mode=cudagraph_runtime_mode,
cudagraph_runtime_mode=batch_desc.cg_mode,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings_by_layer,

View File

@@ -142,12 +142,15 @@ class DefaultModelState(ModelState):
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
# Use padded sizes - padding is handled by model_runner.prepare_attn.
num_reqs = input_batch.num_reqs_after_padding
num_tokens = input_batch.num_tokens_after_padding
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item()
attn_metadata = build_attn_metadata(
attn_groups=attn_groups,
num_reqs=input_batch.num_reqs,
num_tokens=input_batch.num_tokens,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=input_batch.query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,

View File

@@ -1,181 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.model_executor.offloader.base import get_offloader
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cudagraph_utils import (
capture_graphs,
get_cudagraph_sizes,
BatchExecutionDescriptor,
CudaGraphManager,
prepare_inputs_to_capture,
)
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup
class EagleCudaGraphManager:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.scheduler_config = vllm_config.scheduler_config
self.device = device
class EagleCudaGraphManager(CudaGraphManager):
"""CudaGraphManager for Eagle speculative decoding (FULL mode only)."""
self.max_model_len = vllm_config.model_config.max_model_len
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode()
# only need to capture uniform decode cudagraph sizes (the 2nd return value)
_, self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
uniform_decode_query_len=1,
uniform_decode_cudagraph=True,
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
cudagraph_mode: CUDAGraphMode,
draft_tokens: torch.Tensor,
):
assert not cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE), (
"EagleCudaGraphManager does not support PIECEWISE mode yet"
)
# Eagle always uses uniform decode with query_len=1
super().__init__(vllm_config, device, cudagraph_mode, decode_query_len=1)
self.draft_tokens = draft_tokens
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = None
if self.cudagraph_mode != CUDAGraphMode.NONE:
# Use a dedicated pool for Eagle to avoid memory overlap with the main
# model's cudagraph. The base class uses a shared global pool, but Eagle's
# internal allocations (e.g., gumbel_sample temporaries) can conflict with
# the main model's allocations when sharing the same pool.
if cudagraph_mode:
self.pool = torch.cuda.graph_pool_handle()
def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens)
def get_cudagraph_runtime_mode(
self, num_tokens: int
) -> tuple[CUDAGraphMode, int | None]:
cudagraph_size = self.get_cudagraph_size(num_tokens)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
else:
cudagraph_mode = self.cudagraph_mode
if (
cudagraph_mode == CUDAGraphMode.FULL
and cudagraph_size is not None
and cudagraph_size not in self.graphs
):
# If graph wasn't captured yet, fall back to eager.
# This might happen when the dummy run is called before capture.
cudagraph_mode = CUDAGraphMode.NONE
cudagraph_size = None
return cudagraph_mode, cudagraph_size
def capture_graph(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
) -> None:
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
num_reqs = min(num_tokens, self.max_num_reqs)
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
model_state,
input_buffers,
block_tables,
attn_groups,
kv_cache_config,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
# Capture the graph.
capture_fn(
num_reqs=num_reqs,
num_tokens=num_tokens,
generate_fn=generate_fn,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
)
def _capture_full_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
with torch.cuda.graph(graph, self.pool):
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
# Join offloader's copy stream after forward to avoid unjoined
# stream error. The last layer's start_prefetch forks copy_stream,
# but wait_prefetch only happens in the next forward pass.
get_offloader().join_after_forward()
self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.PIECEWISE,
)
@torch.inference_mode()
def capture(
self,
generate_fn: Callable,
@@ -184,31 +50,42 @@ class EagleCudaGraphManager:
block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig,
progress_bar_desc: str = "Capturing CUDA graphs",
) -> None:
if self.cudagraph_mode == CUDAGraphMode.NONE:
return
"""Capture CUDA graphs for Eagle speculative decoding (FULL mode only)."""
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
capture_cudagraph_mode=self.cudagraph_mode,
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
generate_fn=generate_fn,
model_state=model_state,
input_buffers=input_buffers,
block_tables=block_tables,
attn_groups=attn_groups,
kv_cache_config=kv_cache_config,
)
def create_forward_fn(
desc: BatchExecutionDescriptor,
) -> Callable[[CUDAGraphMode], None]:
num_tokens = desc.num_tokens
num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
num_tokens_across_dp = (
torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
if self.dp_size > 1
else None
)
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
num_tokens,
model_state,
input_buffers,
block_tables,
attn_groups,
kv_cache_config,
)
def run_fullgraph(self, num_tokens: int) -> None:
assert num_tokens in self.graphs
# Sync offloader before replay - needed when transitioning from
# eager/piecewise to full cudagraph (e.g., prefill → decode).
# The previous eager iteration's start_prefetch may have queued
# H2D copies on copy_stream that the graph's captured events
# cannot see. Without this, replay could overwrite static buffers
# while those copies are still in flight.
get_offloader().sync_prev_onload()
self.graphs[num_tokens].replay()
return lambda cg_mode: generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
cg_mode,
)
super().capture(create_forward_fn, progress_bar_desc)
def run_fullgraph(self, desc: BatchExecutionDescriptor) -> torch.Tensor:
"""Replay a captured FULL cudagraph and return draft tokens."""
super().run_fullgraph(desc)
return self.draft_tokens

View File

@@ -16,7 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import (
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
@@ -75,7 +75,16 @@ class EagleSpeculator:
device=device,
)
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
# currently we don't support PIECEWISE for Eagle.
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
else:
cudagraph_mode = CUDAGraphMode.NONE
self.cudagraph_manager = EagleCudaGraphManager(
vllm_config, device, cudagraph_mode, self.draft_tokens
)
def load_model(self, target_model: nn.Module) -> None:
self.model = load_eagle_model(target_model, self.vllm_config)
@@ -171,7 +180,7 @@ class EagleSpeculator:
)
if attn_metadata is not None:
self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
idx_mapping, query_start_loc, pos, num_tokens_padded
)
def capture_model(self) -> None:
@@ -185,6 +194,7 @@ class EagleSpeculator:
self.block_tables,
self.attn_groups,
self.kv_cache_config,
progress_bar_desc="Capturing eagle CUDA graphs",
)
@torch.inference_mode()
@@ -251,6 +261,7 @@ class EagleSpeculator:
logits = self.model.compute_logits(sample_hidden_states)
num_reqs = input_batch.num_reqs
num_reqs_padded = input_batch.num_reqs_after_padding
# NOTE(woosuk): For draft sampling, we only consider the temperature
# and ignore the other sampling parameters such as top_k and top_p,
# for simplicity and performance.
@@ -292,48 +303,52 @@ class EagleSpeculator:
self.max_num_reqs,
)
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos
)
# Get batch descriptor and sync across DP ranks.
# Eagle uses FULL-only mode, dispatch with uniform_token_count=1 for decode
cudagraph_mode, cudagraph_size = (
self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs)
)
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = (
get_cudagraph_and_dp_padding(
batch_desc = self.cudagraph_manager.dispatch(num_reqs, num_reqs, 1)
num_tokens_across_dp = None
if self.dp_size > 1:
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
self.cudagraph_manager,
batch_desc,
num_reqs,
cudagraph_size,
cudagraph_mode.value,
num_reqs,
1, # uniform_token_count
self.dp_size,
self.dp_rank,
)
)
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
if cudagraph_mode == CUDAGraphMode.FULL:
# Run full CUDA graph.
self.cudagraph_manager.run_fullgraph(num_tokens_padded)
return self.draft_tokens[:num_reqs]
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping, query_start_loc, pos, batch_desc.num_tokens
)
if batch_desc.cg_mode == CUDAGraphMode.FULL:
return self.cudagraph_manager.run_fullgraph(batch_desc)[:num_reqs]
# Run eager or piecewise CUDA graph.
attn_metadata_updated = None
slot_mappings_updated = None
if not (dummy_run and skip_attn_for_dummy_run):
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
num_reqs_padded + 1, dtype=torch.int32, device="cpu"
)
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
block_tables = [
x[:num_reqs_padded] for x in self.block_tables.input_block_tables
]
# FIXME(woosuk): This is UNSAFE!!
attn_metadata_updated = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_reqs,
num_reqs=num_reqs_padded,
num_tokens=num_reqs_padded,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=1,
seq_lens=self.input_buffers.seq_lens[:num_reqs],
seq_lens=self.input_buffers.seq_lens[:num_reqs_padded],
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
@@ -345,11 +360,11 @@ class EagleSpeculator:
self.generate_draft(
num_reqs,
num_tokens_padded,
batch_desc.num_tokens,
attn_metadata_updated,
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_mode,
cudagraph_runtime_mode=batch_desc.cg_mode,
)
return self.draft_tokens[:num_reqs]