[Model Runner V2] support piecewise & mixed cudagraph (#32771)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
zhrrr
2026-02-19 07:03:17 +08:00
committed by GitHub
parent 40da9625a1
commit 11d3976b88
5 changed files with 343 additions and 108 deletions

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable
from collections.abc import Callable
from typing import Any
import numpy as np
@@ -11,7 +11,8 @@ from tqdm import tqdm
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
@@ -34,14 +35,27 @@ class CudaGraphManager:
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size
self.uniform_decode_query_len = 1
spec_config = vllm_config.speculative_config
if spec_config is not None:
self.uniform_decode_query_len += spec_config.num_speculative_tokens
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.cudagraph_sizes = get_cudagraph_sizes(
use_uniform_decode_cudagraph = (
self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.cudagraph_mode.separate_routine()
)
self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
self.uniform_decode_query_len,
use_uniform_decode_cudagraph,
)
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
@@ -54,20 +68,16 @@ class CudaGraphManager:
return len(self.cudagraph_sizes) > 0
def get_cudagraph_size(
self,
num_tokens_after_padding: int,
num_tokens_per_request: Iterable[int],
self, num_tokens: int, uniform_decode: bool = False
) -> int | None:
return get_cudagraph_size(
num_tokens_after_padding,
num_tokens_per_request,
self.cudagraph_sizes,
self.cudagraph_mode,
)
if uniform_decode and self.uniform_decode_cudagraph_sizes:
return self.uniform_decode_cudagraph_sizes.get(num_tokens)
return self.cudagraph_sizes.get(num_tokens)
def capture_graph(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
model: nn.Module,
input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
@@ -75,8 +85,25 @@ class CudaGraphManager:
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
has_lora: bool = False,
uniform_decode: bool = False,
) -> None:
num_reqs = min(num_tokens, self.max_num_reqs)
# select and check capture function
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
# prepare inputs
if uniform_decode:
num_reqs = min(
cdiv(num_tokens, self.uniform_decode_query_len),
self.max_num_reqs,
)
else:
num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
if self.uses_mrope:
@@ -92,6 +119,9 @@ class CudaGraphManager:
attn_metadata_builders,
self.max_model_len,
kv_cache_config,
uniform_decode_query_len=(
self.uniform_decode_query_len if uniform_decode else 0
),
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
@@ -112,13 +142,40 @@ class CudaGraphManager:
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)
capture_fn(
num_tokens=num_tokens,
num_reqs=num_reqs,
model=model,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
num_tokens_across_dp=num_tokens_across_dp,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
has_lora=has_lora,
)
def _capture_full_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
assert attn_metadata is not None
# Capture the graph.
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
with (
set_forward_context(
attn_metadata,
self.vllm_config,
attn_metadata=attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
@@ -131,9 +188,44 @@ class CudaGraphManager:
positions=positions,
inputs_embeds=inputs_embeds,
)
assert self.hidden_states is not None
self.hidden_states[:num_tokens] = hidden_states
self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
# create batch descriptor for piecewise cudagraph dispatch key
batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)
# Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
with set_forward_context(
attn_metadata=None,  # piecewise mode does not need attn_metadata
vllm_config=self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
assert self.hidden_states is not None
self.hidden_states[:num_tokens] = hidden_states
@torch.inference_mode()
def capture(
self,
@@ -144,11 +236,11 @@ class CudaGraphManager:
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
has_lora: bool = False,
) -> None:
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
common_kwargs = dict(
device=self.device,
capture_fn=self.capture_graph,
model=model,
input_buffers=input_buffers,
mrope_positions=mrope_positions,
@@ -156,10 +248,50 @@ class CudaGraphManager:
block_tables=block_tables,
attn_metadata_builders=attn_metadata_builders,
kv_cache_config=kv_cache_config,
has_lora=has_lora,
)
def run(self, num_tokens: int) -> torch.Tensor:
assert num_tokens in self.graphs
# Phase 1: Capture for mixed prefill-decode batches if needed.
mixed_mode = self.cudagraph_mode.mixed_mode()
if mixed_mode != CUDAGraphMode.NONE:
capture_graphs(
cudagraph_sizes=self.cudagraph_sizes,
capture_cudagraph_mode=mixed_mode,
desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
uniform_decode=False,
**common_kwargs,
)
# Phase 2: Capture FULL graphs for uniform decode batches if needed.
# This is only needed if we use a separate routine for decode batches
# and the decode_mode is FULL.
if self.uniform_decode_cudagraph_sizes:
capture_graphs(
cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
capture_cudagraph_mode=CUDAGraphMode.FULL,
desc="Capturing CUDA graphs (decode, FULL)",
uniform_decode=True,
**common_kwargs,
)
def get_cudagraph_runtime_mode(
self, num_reqs: int, num_tokens: int, max_query_len: int
) -> tuple[CUDAGraphMode, int | None]:
is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_tokens == max_query_len * num_reqs
)
cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
elif is_uniform_decode:
cudagraph_mode = self.cudagraph_mode.decode_mode()
else:
cudagraph_mode = self.cudagraph_mode.mixed_mode()
return cudagraph_mode, cudagraph_size
def run_fullgraph(self, num_tokens: int) -> torch.Tensor:
assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
self.graphs[num_tokens].replay()
assert self.hidden_states is not None
return self.hidden_states[:num_tokens]
@@ -170,22 +302,18 @@ def get_cudagraph_sizes(
max_num_reqs: int,
max_num_tokens: int,
cudagraph_mode: CUDAGraphMode,
) -> dict[int, int]:
if not cudagraph_mode.has_full_cudagraphs():
return {}
uniform_decode_query_len: int = 1,
uniform_decode_cudagraph: bool = False,
) -> tuple[dict[int, int], dict[int, int]]:
# Support both FULL and PIECEWISE cudagraph modes
if cudagraph_mode == CUDAGraphMode.NONE:
return {}, {}
if not capture_sizes:
return {}
return {}, {}
capture_sizes = sorted(capture_sizes)
# Limit the capture sizes to the max number of requests or tokens.
upper_bound = (
max_num_reqs
if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
else max_num_tokens
)
capture_sizes = [x for x in capture_sizes if x <= upper_bound]
if not capture_sizes:
return {}
return {}, {}
cudagraph_sizes: dict[int, int] = {}
for i in range(1, capture_sizes[-1] + 1):
@@ -193,45 +321,34 @@ def get_cudagraph_sizes(
if i <= x:
cudagraph_sizes[i] = x
break
return cudagraph_sizes
def get_cudagraph_size(
num_tokens_after_dp_padding: int,
num_tokens_per_request: Iterable[int],
cudagraph_sizes: dict[int, int],
cudagraph_mode: CUDAGraphMode,
) -> int | None:
if not cudagraph_mode.has_full_cudagraphs():
# No full CUDA graph is used.
return None
size = cudagraph_sizes.get(num_tokens_after_dp_padding)
if size is None:
# No CUDA graph for this size.
return None
is_mixed = any(x > 1 for x in num_tokens_per_request)
if is_mixed and cudagraph_mode.mixed_mode() != CUDAGraphMode.FULL:
# Prefill is included, and this mode doesn't use CUDA graph for it.
return None
return size
uniform_decode_cudagraph_sizes: dict[int, int] = {}
if uniform_decode_cudagraph:
max_num_tokens = max_num_reqs * uniform_decode_query_len
uniform_decode_cudagraph_sizes = {
k: v
for k, v in cudagraph_sizes.items()
if v <= max_num_tokens and v >= uniform_decode_query_len
}
return cudagraph_sizes, uniform_decode_cudagraph_sizes
def capture_graphs(
cudagraph_sizes: dict[int, int],
device: torch.device,
capture_fn: Callable,
capture_cudagraph_mode: CUDAGraphMode,
desc: str = "Capturing CUDA graphs",
**capture_kwargs,
) -> None:
# Capture larger graphs first.
sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
sizes_to_capture = tqdm(sizes_to_capture, desc=desc)
with graph_capture(device=device):
for size in sizes_to_capture:
capture_fn(size, **capture_kwargs)
capture_fn(size, capture_cudagraph_mode, **capture_kwargs)
def prepare_inputs_to_capture(
@@ -242,8 +359,12 @@ def prepare_inputs_to_capture(
attn_metadata_builders: list[AttentionMetadataBuilder],
max_model_len: int,
kv_cache_config: KVCacheConfig,
uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
num_tokens_per_req = num_tokens // num_reqs
if uniform_decode_query_len > 0:
num_tokens_per_req = uniform_decode_query_len
else:
num_tokens_per_req = num_tokens // num_reqs
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc_np[-1] = num_tokens

View File

@@ -13,48 +13,65 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
def get_batch_metadata_across_dp(
num_tokens: int, cudagraph_size: int, dp_size: int, dp_rank: int
) -> tuple[torch.Tensor, torch.Tensor]:
num_tokens: int,
cudagraph_size: int,
cudagraph_runtime_mode: int,
dp_size: int,
dp_rank: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert dp_size > 1
# Use CPU group to avoid CPU-GPU synchronization.
group = get_dp_group().cpu_group
tensor = torch.zeros(2, dp_size, dtype=torch.int32, device="cpu")
tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
tensor[0][dp_rank] = num_tokens
tensor[1][dp_rank] = cudagraph_size
tensor[2][dp_rank] = cudagraph_runtime_mode
dist.all_reduce(tensor, group=group)
return tensor[0], tensor[1]
return tensor[0], tensor[1], tensor[2]
def get_cudagraph_and_dp_padding(
num_tokens: int, cudagraph_size: int | None, dp_size: int, dp_rank: int
) -> tuple[bool, int, torch.Tensor | None]:
num_tokens: int,
cudagraph_size: int | None,
cudagraph_runtime_mode: int,
dp_size: int,
dp_rank: int,
) -> tuple[int, torch.Tensor | None, int]:
if dp_size == 1:
if cudagraph_size is not None:
return True, cudagraph_size, None
return cudagraph_size, None, cudagraph_runtime_mode
else:
return False, num_tokens, None
return num_tokens, None, cudagraph_runtime_mode
# Convert None to -1 for sync (indicates no cudagraph available)
if num_tokens == 0:
cudagraph_size = 0
elif cudagraph_size is None:
cudagraph_size = -1
num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp(
num_tokens, cudagraph_size, dp_size, dp_rank
num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = (
get_batch_metadata_across_dp(
num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank
)
)
if torch.all(num_tokens_across_dp == 0).item():
# All ranks have zero tokens to run.
return False, 0, None
return 0, None, 0
if torch.all(cudagraph_size_across_dp != -1).item():
# All ranks use CUDA graph or have zero tokens.
# Use CUDA graph for all ranks.
# Pad all ranks to the maximum CUDA graph size.
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item())
# Check if all ranks have valid cudagraph_size.
all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
if synced_cudagraph_mode != 0 and all_have_cudagraph:
# All ranks use cudagraph. Pad to max cudagraph_size.
max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
num_tokens_across_dp[:] = max_cudagraph_size
return True, max_cudagraph_size, num_tokens_across_dp
return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
else:
# Some ranks do not use CUDA graph. Use eager mode for all ranks.
# No padding is needed except for ranks that have no tokens to run.
# Fall back to eager mode (no cudagraph).
# Either some rank doesn't have cudagraph size or mode is NONE.
synced_cudagraph_mode = 0
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
return False, num_tokens_after_padding, num_tokens_across_dp
return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode

View File

@@ -15,7 +15,7 @@ from vllm.distributed.parallel_state import (
get_pp_group,
prepare_communication_buffer_for_model,
)
from vllm.forward_context import set_forward_context
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -140,7 +140,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.do_spec_decode = False
self.num_speculative_steps = 0
self.speculator = None
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
@@ -458,6 +457,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
block_tables=self.block_tables,
attn_metadata_builders=self.attn_metadata_builders,
kv_cache_config=self.kv_cache_config,
has_lora=self.lora_config is not None,
)
if self.do_spec_decode:
self.speculator.capture_model()
@@ -884,19 +884,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
empty_output = self.kv_connector.no_forward(scheduler_output)
return empty_output
# Get the CUDA graph size. None means no CUDA graph is used.
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(
scheduler_output.total_num_scheduled_tokens,
scheduler_output.num_scheduled_tokens.values(),
# Get local cudagraph mode and size.
local_cudagraph_mode, local_cudagraph_size = (
self.cudagraph_manager.get_cudagraph_runtime_mode(
num_reqs=len(scheduler_output.num_scheduled_tokens),
num_tokens=scheduler_output.total_num_scheduled_tokens,
max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
)
)
use_cudagraph, num_tokens_after_padding, num_tokens_across_dp = (
# DP sync: num_tokens + cudagraph_size + cudagraph_mode
num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
get_cudagraph_and_dp_padding(
scheduler_output.total_num_scheduled_tokens,
cudagraph_size,
local_cudagraph_size,
local_cudagraph_mode.value,
self.parallel_config.data_parallel_size,
self.parallel_config.data_parallel_rank,
)
)
cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
if num_tokens_after_padding == 0:
# All DP ranks have zero tokens to run.
empty_output = self.kv_connector.no_forward(scheduler_output)
@@ -946,16 +953,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# FIXME(woosuk): Fix warmup for LoRA.
# Run model.
if use_cudagraph:
# Run CUDA graph.
if cudagraph_runtime_mode == CUDAGraphMode.FULL:
# Use explicit cudagraph replay for FULL mode.
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
self.kv_connector.pre_forward(scheduler_output)
hidden_states = self.cudagraph_manager.run(
hidden_states = self.cudagraph_manager.run_fullgraph(
input_batch.num_tokens_after_padding
)
else:
# Run PyTorch model in eager mode.
# For piecewise and eager mode, just call model().
positions = input_batch.positions
if self.uses_mrope:
assert input_batch.mrope_positions is not None
@@ -970,13 +977,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds = None
assert intermediate_tensors is not None
batch_descriptor = BatchDescriptor(
num_tokens=input_batch.num_tokens_after_padding,
has_lora=self.lora_config is not None,
)
with set_forward_context(
input_batch.attn_metadata,
self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding,
# TODO(woosuk): Support piecewise CUDA graph.
cudagraph_runtime_mode=CUDAGraphMode.NONE,
cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=input_batch.slot_mappings,
):
self.kv_connector.pre_forward(scheduler_output)

View File

@@ -7,7 +7,7 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import set_forward_context
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton
@@ -103,14 +103,17 @@ class EagleSpeculator:
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> tuple[torch.Tensor, torch.Tensor]:
batch_descriptor = BatchDescriptor(num_tokens=num_tokens)
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
batch_descriptor=batch_descriptor,
):
ret_hidden_states = self.model(
input_ids=self.input_buffers.input_ids[:num_tokens],
@@ -127,9 +130,11 @@ class EagleSpeculator:
def generate_draft(
self,
num_reqs: int,
num_tokens_padded: int,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
@@ -137,8 +142,14 @@ class EagleSpeculator:
for step in range(1, self.num_speculative_steps):
# Run the eagle model.
last_hidden_states, hidden_states = self.run_model(
num_reqs, attn_metadata, slot_mappings, num_tokens_across_dp
num_tokens_padded,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
cudagraph_runtime_mode,
)
last_hidden_states = last_hidden_states[:num_reqs]
hidden_states = hidden_states[:num_reqs]
logits = self.model.compute_logits(last_hidden_states)
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
@@ -283,12 +294,14 @@ class EagleSpeculator:
)
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
if cudagraph_size is not None:
# Run CUDA graph.
self.cudagraph_manager.run(cudagraph_size)
cudagraph_mode = self.cudagraph_manager.cudagraph_mode
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
# Run full CUDA graph.
self.cudagraph_manager.run_fullgraph(cudagraph_size)
return self.draft_tokens[:num_reqs]
# Run eager mode.
# Run eager or piecewise CUDA graph.
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
query_start_loc_cpu = torch.arange(
num_reqs + 1, dtype=torch.int32, device="cpu"
)
@@ -312,8 +325,13 @@ class EagleSpeculator:
slot_mappings, self.kv_cache_config
)
self.generate_draft(
num_reqs, attn_metadata, slot_mappings_by_layer, num_tokens_across_dp=None
) # FIXME
num_reqs,
num_tokens_padded,
attn_metadata,
slot_mappings_by_layer,
num_tokens_across_dp=None, # FIXME
cudagraph_runtime_mode=cudagraph_mode,
)
return self.draft_tokens[:num_reqs]

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any
import torch
@@ -31,16 +32,17 @@ class EagleCudaGraphManager:
self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_mode = self.compilation_config.cudagraph_mode
if self.cudagraph_mode == CUDAGraphMode.FULL:
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
self.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode()
self.cudagraph_sizes = get_cudagraph_sizes(
# Only the uniform-decode cudagraph sizes (the 2nd return value) need capturing.
_, self.cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
uniform_decode_query_len=1,
uniform_decode_cudagraph=True,
)
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
@@ -54,12 +56,21 @@ class EagleCudaGraphManager:
def capture_graph(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
generate_fn: Callable,
input_buffers: InputBuffers,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
num_reqs = min(num_tokens, self.max_num_reqs)
attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs,
@@ -69,19 +80,70 @@ class EagleCudaGraphManager:
attn_metadata_builders,
self.max_model_len,
kv_cache_config,
uniform_decode_query_len=1,
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up.
generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp)
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
# Capture the graph.
capture_fn(
num_reqs=num_reqs,
num_tokens=num_tokens,
generate_fn=generate_fn,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
num_tokens_across_dp=num_tokens_across_dp,
)
def _capture_full_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.pool):
generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp)
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.NONE,
)
self.graphs[num_tokens] = graph
def _capture_piecewise_graph(
self,
num_reqs: int,
num_tokens: int,
generate_fn: Callable,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor,
) -> None:
generate_fn(
num_reqs,
num_tokens,
attn_metadata,
slot_mappings,
num_tokens_across_dp,
CUDAGraphMode.PIECEWISE,
)
@torch.inference_mode()
def capture(
self,
@@ -91,10 +153,15 @@ class EagleCudaGraphManager:
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
) -> None:
if self.cudagraph_mode == CUDAGraphMode.NONE:
return
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
capture_cudagraph_mode=self.cudagraph_mode,
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
generate_fn=generate_fn,
input_buffers=input_buffers,
block_tables=block_tables,
@@ -102,6 +169,6 @@ class EagleCudaGraphManager:
kv_cache_config=kv_cache_config,
)
def run(self, num_tokens: int) -> None:
def run_fullgraph(self, num_tokens: int) -> None:
assert num_tokens in self.graphs
self.graphs[num_tokens].replay()