diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index 0c5a93abc..66da081b4 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable, Iterable +from collections.abc import Callable from typing import Any import numpy as np @@ -11,7 +11,8 @@ from tqdm import tqdm from vllm.config import VllmConfig from vllm.config.compilation import CUDAGraphMode from vllm.distributed.parallel_state import graph_capture, is_global_first_rank -from vllm.forward_context import set_forward_context +from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.utils.math_utils import cdiv from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.attn_utils import ( @@ -34,14 +35,27 @@ class CudaGraphManager: self.max_num_reqs = self.scheduler_config.max_num_seqs self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.dp_size = vllm_config.parallel_config.data_parallel_size + + self.uniform_decode_query_len = 1 + spec_config = vllm_config.speculative_config + if spec_config is not None: + self.uniform_decode_query_len += spec_config.num_speculative_tokens + self.compilation_config = vllm_config.compilation_config assert self.compilation_config is not None self.cudagraph_mode = self.compilation_config.cudagraph_mode - self.cudagraph_sizes = get_cudagraph_sizes( + + use_uniform_decode_cudagraph = ( + self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and self.cudagraph_mode.separate_routine() + ) + self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes( self.compilation_config.cudagraph_capture_sizes, self.max_num_reqs, self.max_num_tokens, self.cudagraph_mode, + self.uniform_decode_query_len, + use_uniform_decode_cudagraph, ) self.graphs: dict[int, torch.cuda.CUDAGraph] = {} @@ -54,20 +68,16 @@ class CudaGraphManager: return len(self.cudagraph_sizes) > 0 def get_cudagraph_size( - self, - num_tokens_after_padding: int, - num_tokens_per_request: Iterable[int], + self, num_tokens: int, uniform_decode: bool = False ) -> int | None: - return get_cudagraph_size( - num_tokens_after_padding, - num_tokens_per_request, - self.cudagraph_sizes, - self.cudagraph_mode, - ) + if uniform_decode and self.uniform_decode_cudagraph_sizes: + return self.uniform_decode_cudagraph_sizes.get(num_tokens) + return self.cudagraph_sizes.get(num_tokens) def capture_graph( self, num_tokens: int, + capture_cg_mode: CUDAGraphMode, model: nn.Module, input_buffers: InputBuffers, mrope_positions: torch.Tensor | None, @@ -75,8 +85,25 @@ class CudaGraphManager: block_tables: BlockTables, attn_metadata_builders: list[AttentionMetadataBuilder], kv_cache_config: KVCacheConfig, + has_lora: bool = False, + uniform_decode: bool = False, ) -> None: - num_reqs = min(num_tokens, self.max_num_reqs) + # select and check capture function + assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], ( + f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}" + ) + if capture_cg_mode == CUDAGraphMode.PIECEWISE: + capture_fn = self._capture_piecewise_graph + else: + capture_fn = self._capture_full_graph + # prepare inputs + if uniform_decode: + num_reqs = min( + cdiv(num_tokens, self.uniform_decode_query_len), + self.max_num_reqs, + ) + else: + num_reqs = min(num_tokens, self.max_num_reqs) input_ids = input_buffers.input_ids[:num_tokens] positions = input_buffers.positions[:num_tokens] if self.uses_mrope: @@ -92,6 +119,9 @@ class CudaGraphManager: attn_metadata_builders, self.max_model_len, kv_cache_config, + uniform_decode_query_len=( + self.uniform_decode_query_len if uniform_decode else 0 + ), ) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) @@ -112,13 +142,40 @@ class CudaGraphManager: if self.hidden_states is None: self.hidden_states = torch.empty_like(hidden_states) + capture_fn( + num_tokens=num_tokens, + num_reqs=num_reqs, + model=model, + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + num_tokens_across_dp=num_tokens_across_dp, + attn_metadata=attn_metadata, + slot_mappings=slot_mappings, + has_lora=has_lora, + ) + + def _capture_full_graph( + self, + num_tokens: int, + num_reqs: int, + model: nn.Module, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None, + num_tokens_across_dp: torch.Tensor, + attn_metadata: dict[str, Any] | None, + slot_mappings: dict[str, torch.Tensor] | None, + has_lora: bool = False, + ) -> None: + assert attn_metadata is not None # Capture the graph. assert num_tokens not in self.graphs graph = torch.cuda.CUDAGraph() with ( set_forward_context( - attn_metadata, - self.vllm_config, + attn_metadata=attn_metadata, + vllm_config=self.vllm_config, num_tokens=num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, num_tokens_across_dp=num_tokens_across_dp, @@ -131,9 +188,44 @@ class CudaGraphManager: positions=positions, inputs_embeds=inputs_embeds, ) + assert self.hidden_states is not None self.hidden_states[:num_tokens] = hidden_states self.graphs[num_tokens] = graph + def _capture_piecewise_graph( + self, + num_tokens: int, + num_reqs: int, + model: nn.Module, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None, + num_tokens_across_dp: torch.Tensor, + attn_metadata: dict[str, Any] | None, + slot_mappings: dict[str, torch.Tensor] | None, + has_lora: bool = False, + ) -> None: + # create batch descriptor for piecewise cudagraph dispatch key + batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora) + + # Capture run - CUDAGraphWrapper inside torch.compile will auto capture. + with set_forward_context( + attn_metadata=None, # piecewise no need attn_metadata + vllm_config=self.vllm_config, + num_tokens=num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, + num_tokens_across_dp=num_tokens_across_dp, + batch_descriptor=batch_descriptor, + slot_mapping=slot_mappings, + ): + hidden_states = model( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + ) + assert self.hidden_states is not None + self.hidden_states[:num_tokens] = hidden_states + @torch.inference_mode() def capture( self, @@ -144,11 +236,11 @@ class CudaGraphManager: block_tables: BlockTables, attn_metadata_builders: list[AttentionMetadataBuilder], kv_cache_config: KVCacheConfig, + has_lora: bool = False, ) -> None: - capture_graphs( - self.cudagraph_sizes, - self.device, - self.capture_graph, + common_kwargs = dict( + device=self.device, + capture_fn=self.capture_graph, model=model, input_buffers=input_buffers, mrope_positions=mrope_positions, @@ -156,10 +248,50 @@ class CudaGraphManager: block_tables=block_tables, attn_metadata_builders=attn_metadata_builders, kv_cache_config=kv_cache_config, + has_lora=has_lora, ) - def run(self, num_tokens: int) -> torch.Tensor: - assert num_tokens in self.graphs + # Phase 1: Capture for mixed prefill-decode batches if needed. + mixed_mode = self.cudagraph_mode.mixed_mode() + if mixed_mode != CUDAGraphMode.NONE: + capture_graphs( + cudagraph_sizes=self.cudagraph_sizes, + capture_cudagraph_mode=mixed_mode, + desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})", + uniform_decode=False, + **common_kwargs, + ) + + # Phase 2: Capture FULL graphs for uniform decode batches if needed. + # This is only needed if we use a separate routine for decode batches + # and the decode_mode is FULL. + if self.uniform_decode_cudagraph_sizes: + capture_graphs( + cudagraph_sizes=self.uniform_decode_cudagraph_sizes, + capture_cudagraph_mode=CUDAGraphMode.FULL, + desc="Capturing CUDA graphs (decode, FULL)", + uniform_decode=True, + **common_kwargs, + ) + + def get_cudagraph_runtime_mode( + self, num_reqs: int, num_tokens: int, max_query_len: int + ) -> tuple[CUDAGraphMode, int | None]: + is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( + num_tokens == max_query_len * num_reqs + ) + + cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode) + if cudagraph_size is None: + cudagraph_mode = CUDAGraphMode.NONE + elif is_uniform_decode: + cudagraph_mode = self.cudagraph_mode.decode_mode() + else: + cudagraph_mode = self.cudagraph_mode.mixed_mode() + return cudagraph_mode, cudagraph_size + + def run_fullgraph(self, num_tokens: int) -> torch.Tensor: + assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens" self.graphs[num_tokens].replay() assert self.hidden_states is not None return self.hidden_states[:num_tokens] @@ -170,22 +302,18 @@ def get_cudagraph_sizes( max_num_reqs: int, max_num_tokens: int, cudagraph_mode: CUDAGraphMode, -) -> dict[int, int]: - if not cudagraph_mode.has_full_cudagraphs(): - return {} + uniform_decode_query_len: int = 1, + uniform_decode_cudagraph: bool = False, +) -> tuple[dict[int, int], dict[int, int]]: + # Support both FULL and PIECEWISE cudagraph modes + if cudagraph_mode == CUDAGraphMode.NONE: + return {}, {} if not capture_sizes: - return {} + return {}, {} capture_sizes = sorted(capture_sizes) - # Limit the capture sizes to the max number of requests or tokens. - upper_bound = ( - max_num_reqs - if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY - else max_num_tokens - ) - capture_sizes = [x for x in capture_sizes if x <= upper_bound] if not capture_sizes: - return {} + return {}, {} cudagraph_sizes: dict[int, int] = {} for i in range(1, capture_sizes[-1] + 1): @@ -193,45 +321,34 @@ def get_cudagraph_sizes( if i <= x: cudagraph_sizes[i] = x break - return cudagraph_sizes - -def get_cudagraph_size( - num_tokens_after_dp_padding: int, - num_tokens_per_request: Iterable[int], - cudagraph_sizes: dict[int, int], - cudagraph_mode: CUDAGraphMode, -) -> int | None: - if not cudagraph_mode.has_full_cudagraphs(): - # No full CUDA graph is used. - return None - - size = cudagraph_sizes.get(num_tokens_after_dp_padding) - if size is None: - # No CUDA graph for this size. - return None - - is_mixed = any(x > 1 for x in num_tokens_per_request) - if is_mixed and cudagraph_mode.mixed_mode() != CUDAGraphMode.FULL: - # Prefill is included, and this mode doesn't use CUDA graph for it. - return None - return size + uniform_decode_cudagraph_sizes: dict[int, int] = {} + if uniform_decode_cudagraph: + max_num_tokens = max_num_reqs * uniform_decode_query_len + uniform_decode_cudagraph_sizes = { + k: v + for k, v in cudagraph_sizes.items() + if v <= max_num_tokens and v >= uniform_decode_query_len + } + return cudagraph_sizes, uniform_decode_cudagraph_sizes def capture_graphs( cudagraph_sizes: dict[int, int], device: torch.device, capture_fn: Callable, + capture_cudagraph_mode: CUDAGraphMode, + desc: str = "Capturing CUDA graphs", **capture_kwargs, ) -> None: # Capture larger graphs first. sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True) if is_global_first_rank(): - sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs") + sizes_to_capture = tqdm(sizes_to_capture, desc=desc) with graph_capture(device=device): for size in sizes_to_capture: - capture_fn(size, **capture_kwargs) + capture_fn(size, capture_cudagraph_mode, **capture_kwargs) def prepare_inputs_to_capture( @@ -242,8 +359,12 @@ def prepare_inputs_to_capture( attn_metadata_builders: list[AttentionMetadataBuilder], max_model_len: int, kv_cache_config: KVCacheConfig, + uniform_decode_query_len: int = 0, ) -> tuple[dict[str, Any], dict[str, torch.Tensor]]: - num_tokens_per_req = num_tokens // num_reqs + if uniform_decode_query_len > 0: + num_tokens_per_req = uniform_decode_query_len + else: + num_tokens_per_req = num_tokens // num_reqs query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req query_start_loc_np[-1] = num_tokens diff --git a/vllm/v1/worker/gpu/dp_utils.py b/vllm/v1/worker/gpu/dp_utils.py index 9794d3af0..724a6c39f 100644 --- a/vllm/v1/worker/gpu/dp_utils.py +++ b/vllm/v1/worker/gpu/dp_utils.py @@ -13,48 +13,65 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N def get_batch_metadata_across_dp( - num_tokens: int, cudagraph_size: int, dp_size: int, dp_rank: int -) -> tuple[torch.Tensor, torch.Tensor]: + num_tokens: int, + cudagraph_size: int, + cudagraph_runtime_mode: int, + dp_size: int, + dp_rank: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert dp_size > 1 # Use CPU group to avoid CPU-GPU synchronization. group = get_dp_group().cpu_group - tensor = torch.zeros(2, dp_size, dtype=torch.int32, device="cpu") + tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu") tensor[0][dp_rank] = num_tokens tensor[1][dp_rank] = cudagraph_size + tensor[2][dp_rank] = cudagraph_runtime_mode dist.all_reduce(tensor, group=group) - return tensor[0], tensor[1] + return tensor[0], tensor[1], tensor[2] def get_cudagraph_and_dp_padding( - num_tokens: int, cudagraph_size: int | None, dp_size: int, dp_rank: int -) -> tuple[bool, int, torch.Tensor | None]: + num_tokens: int, + cudagraph_size: int | None, + cudagraph_runtime_mode: int, + dp_size: int, + dp_rank: int, +) -> tuple[int, torch.Tensor | None, int]: if dp_size == 1: if cudagraph_size is not None: - return True, cudagraph_size, None + return cudagraph_size, None, cudagraph_runtime_mode else: - return False, num_tokens, None + return num_tokens, None, cudagraph_runtime_mode + # Convert None to -1 for sync (indicates no cudagraph available) if num_tokens == 0: cudagraph_size = 0 elif cudagraph_size is None: cudagraph_size = -1 - num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp( - num_tokens, cudagraph_size, dp_size, dp_rank + + num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = ( + get_batch_metadata_across_dp( + num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank + ) ) if torch.all(num_tokens_across_dp == 0).item(): # All ranks have zero tokens to run. - return False, 0, None + return 0, None, 0 - if torch.all(cudagraph_size_across_dp != -1).item(): - # All ranks use CUDA graph or have zero tokens. - # Use CUDA graph for all ranks. - # Pad all ranks to the maximum CUDA graph size. + # Synchronize cudagraph_runtime_mode across ranks by taking the minimum. + synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item()) + # Check if all ranks have valid cudagraph_size. + all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item() + + if synced_cudagraph_mode != 0 and all_have_cudagraph: + # All ranks use cudagraph. Pad to max cudagraph_size. max_cudagraph_size = int(cudagraph_size_across_dp.max().item()) num_tokens_across_dp[:] = max_cudagraph_size - return True, max_cudagraph_size, num_tokens_across_dp + return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode else: - # Some ranks do not use CUDA graph. Use eager mode for all ranks. - # No padding is needed except for ranks that have no tokens to run. + # Fall back to eager mode (no cudagraph). + # Either some rank doesn't have cudagraph size or mode is NONE. + synced_cudagraph_mode = 0 num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1) num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item()) - return False, num_tokens_after_padding, num_tokens_across_dp + return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index be620b0cc..cbae001c2 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -15,7 +15,7 @@ from vllm.distributed.parallel_state import ( get_pp_group, prepare_communication_buffer_for_model, ) -from vllm.forward_context import set_forward_context +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model_loader from vllm.multimodal import MULTIMODAL_REGISTRY @@ -140,7 +140,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.do_spec_decode = False self.num_speculative_steps = 0 self.speculator = None - self.req_states = RequestState( max_num_reqs=self.max_num_reqs, max_model_len=self.max_model_len, @@ -458,6 +457,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): block_tables=self.block_tables, attn_metadata_builders=self.attn_metadata_builders, kv_cache_config=self.kv_cache_config, + has_lora=self.lora_config is not None, ) if self.do_spec_decode: self.speculator.capture_model() @@ -884,19 +884,26 @@ class GPUModelRunner(LoRAModelRunnerMixin): empty_output = self.kv_connector.no_forward(scheduler_output) return empty_output - # Get the CUDA graph size. None means no CUDA graph is used. - cudagraph_size = self.cudagraph_manager.get_cudagraph_size( - scheduler_output.total_num_scheduled_tokens, - scheduler_output.num_scheduled_tokens.values(), + # Get local cudagraph mode and size. + local_cudagraph_mode, local_cudagraph_size = ( + self.cudagraph_manager.get_cudagraph_runtime_mode( + num_reqs=len(scheduler_output.num_scheduled_tokens), + num_tokens=scheduler_output.total_num_scheduled_tokens, + max_query_len=max(scheduler_output.num_scheduled_tokens.values()), + ) ) - use_cudagraph, num_tokens_after_padding, num_tokens_across_dp = ( + + # DP sync: num_tokens + cudagraph_size + cudagraph_mode + num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = ( get_cudagraph_and_dp_padding( scheduler_output.total_num_scheduled_tokens, - cudagraph_size, + local_cudagraph_size, + local_cudagraph_mode.value, self.parallel_config.data_parallel_size, self.parallel_config.data_parallel_rank, ) ) + cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode) if num_tokens_after_padding == 0: # All DP ranks have zero tokens to run. empty_output = self.kv_connector.no_forward(scheduler_output) @@ -946,16 +953,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): # FIXME(woosuk): Fix warmup for LoRA. # Run model. - if use_cudagraph: - # Run CUDA graph. + if cudagraph_runtime_mode == CUDAGraphMode.FULL: + # Use explicit cudagraph replay for FULL mode. # NOTE(woosuk): Here, we don't need to pass the input tensors, # because they are already copied to the CUDA graph input buffers. self.kv_connector.pre_forward(scheduler_output) - hidden_states = self.cudagraph_manager.run( + hidden_states = self.cudagraph_manager.run_fullgraph( input_batch.num_tokens_after_padding ) else: - # Run PyTorch model in eager mode. + # For piecewise and eager mode, just call model(). positions = input_batch.positions if self.uses_mrope: assert input_batch.mrope_positions is not None @@ -970,13 +977,18 @@ class GPUModelRunner(LoRAModelRunnerMixin): inputs_embeds = None assert intermediate_tensors is not None + batch_descriptor = BatchDescriptor( + num_tokens=input_batch.num_tokens_after_padding, + has_lora=self.lora_config is not None, + ) + with set_forward_context( input_batch.attn_metadata, self.vllm_config, num_tokens=input_batch.num_tokens_after_padding, - # TODO(woosuk): Support piecewise CUDA graph. - cudagraph_runtime_mode=CUDAGraphMode.NONE, + cudagraph_runtime_mode=cudagraph_runtime_mode, num_tokens_across_dp=num_tokens_across_dp, + batch_descriptor=batch_descriptor, slot_mapping=input_batch.slot_mappings, ): self.kv_connector.pre_forward(scheduler_output) diff --git a/vllm/v1/worker/gpu/spec_decode/eagle.py b/vllm/v1/worker/gpu/spec_decode/eagle.py index af56c23bf..abbde270f 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle.py @@ -7,7 +7,7 @@ import torch.nn as nn from vllm.config import VllmConfig from vllm.config.compilation import CUDAGraphMode -from vllm.forward_context import set_forward_context +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.triton_utils import tl, triton @@ -103,14 +103,17 @@ class EagleSpeculator: attn_metadata: dict[str, Any] | None, slot_mappings: dict[str, torch.Tensor] | None, num_tokens_across_dp: torch.Tensor | None, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, ) -> tuple[torch.Tensor, torch.Tensor]: + batch_descriptor = BatchDescriptor(num_tokens=num_tokens) with set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, + cudagraph_runtime_mode=cudagraph_runtime_mode, num_tokens_across_dp=num_tokens_across_dp, slot_mapping=slot_mappings, + batch_descriptor=batch_descriptor, ): ret_hidden_states = self.model( input_ids=self.input_buffers.input_ids[:num_tokens], @@ -127,9 +130,11 @@ class EagleSpeculator: def generate_draft( self, num_reqs: int, + num_tokens_padded: int, attn_metadata: dict[str, Any], slot_mappings: dict[str, torch.Tensor], num_tokens_across_dp: torch.Tensor | None, + cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, ) -> None: pos = self.input_buffers.positions[:num_reqs] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] @@ -137,8 +142,14 @@ class EagleSpeculator: for step in range(1, self.num_speculative_steps): # Run the eagle model. last_hidden_states, hidden_states = self.run_model( - num_reqs, attn_metadata, slot_mappings, num_tokens_across_dp + num_tokens_padded, + attn_metadata, + slot_mappings, + num_tokens_across_dp, + cudagraph_runtime_mode, ) + last_hidden_states = last_hidden_states[:num_reqs] + hidden_states = hidden_states[:num_reqs] logits = self.model.compute_logits(last_hidden_states) # NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise @@ -283,12 +294,14 @@ class EagleSpeculator: ) cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs) - if cudagraph_size is not None: - # Run CUDA graph. - self.cudagraph_manager.run(cudagraph_size) + cudagraph_mode = self.cudagraph_manager.cudagraph_mode + if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL: + # Run full CUDA graph. + self.cudagraph_manager.run_fullgraph(cudagraph_size) return self.draft_tokens[:num_reqs] - # Run eager mode. + # Run eager or piecewise CUDA graph. + num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs query_start_loc_cpu = torch.arange( num_reqs + 1, dtype=torch.int32, device="cpu" ) @@ -312,8 +325,13 @@ class EagleSpeculator: slot_mappings, self.kv_cache_config ) self.generate_draft( - num_reqs, attn_metadata, slot_mappings_by_layer, num_tokens_across_dp=None - ) # FIXME + num_reqs, + num_tokens_padded, + attn_metadata, + slot_mappings_by_layer, + num_tokens_across_dp=None, # FIXME + cudagraph_runtime_mode=cudagraph_mode, + ) return self.draft_tokens[:num_reqs] diff --git a/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py b/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py index 1ea7ffcb5..ae7aa4078 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable +from typing import Any import torch @@ -31,16 +32,17 @@ class EagleCudaGraphManager: self.compilation_config = vllm_config.compilation_config assert self.compilation_config is not None - self.cudagraph_mode = self.compilation_config.cudagraph_mode - if self.cudagraph_mode == CUDAGraphMode.FULL: - # NOTE(woosuk): For Eagle, we only use CUDA graphs for decode. - self.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY + # NOTE(woosuk): For Eagle, we only use CUDA graphs for decode. + self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode() - self.cudagraph_sizes = get_cudagraph_sizes( + # only need to capture uniform decode cudagraph sizes (the 2nd return value) + _, self.cudagraph_sizes = get_cudagraph_sizes( self.compilation_config.cudagraph_capture_sizes, self.max_num_reqs, self.max_num_tokens, self.cudagraph_mode, + uniform_decode_query_len=1, + uniform_decode_cudagraph=True, ) self.graphs: dict[int, torch.cuda.CUDAGraph] = {} @@ -54,12 +56,21 @@ class EagleCudaGraphManager: def capture_graph( self, num_tokens: int, + capture_cg_mode: CUDAGraphMode, generate_fn: Callable, input_buffers: InputBuffers, block_tables: BlockTables, attn_metadata_builders: list[AttentionMetadataBuilder], kv_cache_config: KVCacheConfig, ) -> None: + assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], ( + f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}" + ) + if capture_cg_mode == CUDAGraphMode.PIECEWISE: + capture_fn = self._capture_piecewise_graph + else: + capture_fn = self._capture_full_graph + num_reqs = min(num_tokens, self.max_num_reqs) attn_metadata, slot_mappings = prepare_inputs_to_capture( num_reqs, @@ -69,19 +80,70 @@ class EagleCudaGraphManager: attn_metadata_builders, self.max_model_len, kv_cache_config, + uniform_decode_query_len=1, ) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) # Warm up. - generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp) + generate_fn( + num_reqs, + num_tokens, + attn_metadata, + slot_mappings, + num_tokens_across_dp, + CUDAGraphMode.NONE, + ) # Capture the graph. + capture_fn( + num_reqs=num_reqs, + num_tokens=num_tokens, + generate_fn=generate_fn, + attn_metadata=attn_metadata, + slot_mappings=slot_mappings, + num_tokens_across_dp=num_tokens_across_dp, + ) + + def _capture_full_graph( + self, + num_reqs: int, + num_tokens: int, + generate_fn: Callable, + attn_metadata: dict[str, Any], + slot_mappings: dict[str, torch.Tensor], + num_tokens_across_dp: torch.Tensor, + ) -> None: assert num_tokens not in self.graphs graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph, self.pool): - generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp) + generate_fn( + num_reqs, + num_tokens, + attn_metadata, + slot_mappings, + num_tokens_across_dp, + CUDAGraphMode.NONE, + ) self.graphs[num_tokens] = graph + def _capture_piecewise_graph( + self, + num_reqs: int, + num_tokens: int, + generate_fn: Callable, + attn_metadata: dict[str, Any], + slot_mappings: dict[str, torch.Tensor], + num_tokens_across_dp: torch.Tensor, + ) -> None: + generate_fn( + num_reqs, + num_tokens, + attn_metadata, + slot_mappings, + num_tokens_across_dp, + CUDAGraphMode.PIECEWISE, + ) + @torch.inference_mode() def capture( self, @@ -91,10 +153,15 @@ class EagleCudaGraphManager: attn_metadata_builders: list[AttentionMetadataBuilder], kv_cache_config: KVCacheConfig, ) -> None: + if self.cudagraph_mode == CUDAGraphMode.NONE: + return + capture_graphs( self.cudagraph_sizes, self.device, self.capture_graph, + capture_cudagraph_mode=self.cudagraph_mode, + desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})", generate_fn=generate_fn, input_buffers=input_buffers, block_tables=block_tables, @@ -102,6 +169,6 @@ class EagleCudaGraphManager: kv_cache_config=kv_cache_config, ) - def run(self, num_tokens: int) -> None: + def run_fullgraph(self, num_tokens: int) -> None: assert num_tokens in self.graphs self.graphs[num_tokens].replay()