[Model Runner V2] Refactor get_cudagraph_and_dp_padding (#32625)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon
2026-01-19 18:25:02 -08:00
committed by GitHub
parent 12dab78f49
commit 7b7cdce968
2 changed files with 54 additions and 67 deletions

vllm/v1/worker/gpu/dp_utils.py

@@ -6,6 +6,12 @@ import torch.distributed as dist
 from vllm.distributed.parallel_state import get_dp_group
 
 
+def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None:
+    if dp_size == 1:
+        return None
+    return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
+
+
 def get_batch_metadata_across_dp(
     num_tokens: int,
     cudagraph_size: int,
@@ -22,10 +28,39 @@ def get_batch_metadata_across_dp(
     return tensor[0], tensor[1]
 
 
-def make_num_tokens_across_dp(
-    dp_size: int,
-    num_tokens: int,
-) -> torch.Tensor | None:
-    if dp_size == 1:
-        return None
-    return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
+def get_cudagraph_and_dp_padding(
+    num_tokens: int,
+    cudagraph_size: int | None,
+    dp_size: int,
+    dp_rank: int,
+) -> tuple[bool, int, torch.Tensor | None]:
+    if dp_size == 1:
+        if cudagraph_size is not None:
+            return True, cudagraph_size, None
+        else:
+            return False, num_tokens, None
+
+    if num_tokens == 0:
+        cudagraph_size = 0
+    elif cudagraph_size is None:
+        cudagraph_size = -1
+    num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp(
+        num_tokens, cudagraph_size, dp_size, dp_rank
+    )
+    if torch.all(num_tokens_across_dp == 0).item():
+        # All ranks have zero tokens to run.
+        return False, 0, None
+
+    if torch.all(cudagraph_size_across_dp != -1).item():
+        # All ranks use CUDA graph or have zero tokens.
+        # Use CUDA graph for all ranks.
+        # Pad all ranks to the maximum CUDA graph size.
+        max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
+        num_tokens_across_dp[:] = max_cudagraph_size
+        return True, max_cudagraph_size, num_tokens_across_dp
+    else:
+        # Some ranks do not use CUDA graph. Use eager mode for all ranks.
+        # No padding is needed except for ranks that have no tokens to run.
+        num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
+        num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
+        return False, num_tokens_after_padding, num_tokens_across_dp
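The decision the new helper makes can be illustrated without a DP process group. The sketch below is a minimal, self-contained mock: fake_padding_decision is hypothetical and stands in for get_cudagraph_and_dp_padding after the per-rank metadata has already been exchanged (i.e., it skips the all-gather that get_batch_metadata_across_dp performs), using the same sentinel encoding as the diff: 0 for ranks with no tokens, -1 for ranks that cannot use a CUDA graph.

import torch

def fake_padding_decision(
    num_tokens: list[int], cudagraph_sizes: list[int], dp_rank: int
) -> tuple[bool, int, torch.Tensor | None]:
    # Hypothetical stand-in: the per-rank metadata is passed in directly
    # instead of being gathered across the DP group.
    num_tokens_across_dp = torch.tensor(num_tokens, dtype=torch.int32)
    cudagraph_size_across_dp = torch.tensor(cudagraph_sizes, dtype=torch.int32)
    if torch.all(num_tokens_across_dp == 0).item():
        # All ranks are idle: nothing to run anywhere.
        return False, 0, None
    if torch.all(cudagraph_size_across_dp != -1).item():
        # Every rank either fits a captured graph or is idle:
        # pad everyone to the largest graph size so shapes match.
        max_size = int(cudagraph_size_across_dp.max().item())
        num_tokens_across_dp[:] = max_size
        return True, max_size, num_tokens_across_dp
    # At least one rank must run eagerly, so all ranks run eagerly.
    # Idle ranks are padded to 1 token so collectives still see work.
    num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
    return False, int(num_tokens_across_dp[dp_rank].item()), num_tokens_across_dp

# Rank 0 fits a 128-token graph, rank 1 fits a 256-token graph:
# both replay a CUDA graph padded to 256 tokens.
print(fake_padding_decision([100, 200], [128, 256], dp_rank=0))   # (True, 256, tensor([256, 256]))
# Rank 1 is too large for any captured graph (-1): everyone falls back to
# eager mode, and the idle rank 0 is padded from 0 to 1 token.
print(fake_padding_decision([0, 5000], [0, -1], dp_rank=0))       # (False, 1, tensor([1, 5000]))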

vllm/v1/worker/gpu/model_runner.py

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import gc
 import time
-from collections.abc import Iterable
 from copy import deepcopy
 from typing import Any
@@ -37,7 +36,7 @@ from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.buffer_utils import UvaBufferPool
 from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
 from vllm.v1.worker.gpu.dp_utils import (
-    get_batch_metadata_across_dp,
+    get_cudagraph_and_dp_padding,
     make_num_tokens_across_dp,
 )
 from vllm.v1.worker.gpu.input_batch import (
@@ -877,60 +876,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         )
         return draft_tokens
 
-    def get_cudagraph_and_dp_padding(
-        self,
-        total_num_scheduled_tokens: int,
-        num_tokens_per_request: Iterable[int],
-    ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]:
-        dp_size = self.parallel_config.data_parallel_size
-        if dp_size == 1:
-            # No DP. Only consider CUDA graphs.
-            if total_num_scheduled_tokens == 0:
-                # Special case: no tokens to run.
-                return CUDAGraphMode.NONE, 0, None
-            cudagraph_size = self.cudagraph_manager.get_cudagraph_size(
-                total_num_scheduled_tokens, num_tokens_per_request
-            )
-            if cudagraph_size is not None:
-                # Use full CUDA graph.
-                return CUDAGraphMode.FULL, cudagraph_size, None
-            # Fall back to eager mode.
-            # TODO(woosuk): Support piecewise CUDA graphs.
-            return CUDAGraphMode.NONE, total_num_scheduled_tokens, None
-
-        # Consider DP padding and CUDA graph.
-        if total_num_scheduled_tokens == 0:
-            # Special handling is needed for 0.
-            cudagraph_size_before_dp: int | None = 0
-        else:
-            cudagraph_size_before_dp = self.cudagraph_manager.get_cudagraph_size(
-                total_num_scheduled_tokens, num_tokens_per_request
-            )
-            if cudagraph_size_before_dp is None:
-                cudagraph_size_before_dp = -1
-        assert cudagraph_size_before_dp is not None
-
-        dp_rank = self.parallel_config.data_parallel_rank
-        num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp(
-            total_num_scheduled_tokens,
-            cudagraph_size_before_dp,
-            dp_size,
-            dp_rank,
-        )
-
-        if all(cudagraph_size_across_dp >= 0):
-            # If all ranks can use CUDA graph, pad to the maximum number of tokens
-            # across DP and use CUDA graph.
-            num_tokens_after_padding = int(cudagraph_size_across_dp.max().item())
-            cudagraph_mode = CUDAGraphMode.FULL
-        else:
-            # If any of the ranks cannot use CUDA graph, use eager mode for all ranks.
-            # No padding is needed except for ranks that have no tokens to run.
-            num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
-            num_tokens_after_padding = num_tokens_across_dp[dp_rank]
-            cudagraph_mode = CUDAGraphMode.NONE
-        return cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp
-
     @torch.inference_mode()
     def execute_model(
         self,
@@ -951,10 +896,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # No need to run the model.
             return EMPTY_MODEL_RUNNER_OUTPUT
-        cudagraph_mode, num_tokens_after_padding, num_tokens_across_dp = (
-            self.get_cudagraph_and_dp_padding(
+        # Get the CUDA graph size. None means no CUDA graph is used.
+        cudagraph_size = self.cudagraph_manager.get_cudagraph_size(
+            scheduler_output.total_num_scheduled_tokens,
+            scheduler_output.num_scheduled_tokens.values(),
+        )
+        use_cudagraph, num_tokens_after_padding, num_tokens_across_dp = (
+            get_cudagraph_and_dp_padding(
                 scheduler_output.total_num_scheduled_tokens,
-                scheduler_output.num_scheduled_tokens.values(),
+                cudagraph_size,
+                self.parallel_config.data_parallel_size,
+                self.parallel_config.data_parallel_rank,
             )
         )
 
         if num_tokens_after_padding == 0:
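With this change, execute_model first asks the CUDA graph manager for a padded size (None meaning no graph can be used) and only then hands the result to the dp_utils helper. The sketch below imitates that two-step flow in isolation; FakeCudaGraphManager, its bucket list, and its round-up rule are assumptions for illustration (the real get_cudagraph_size also takes the per-request token counts), and single_rank_decision simply mirrors the dp_size == 1 branch of get_cudagraph_and_dp_padding shown above.

import bisect

class FakeCudaGraphManager:
    # Hypothetical mock: captured batch sizes and round-up logic are assumed,
    # not taken from the real CudaGraphManager.
    def __init__(self, captured_sizes: list[int]) -> None:
        self.captured_sizes = sorted(captured_sizes)

    def get_cudagraph_size(self, num_tokens: int) -> int | None:
        # Round up to the smallest captured size; None means "run eagerly".
        i = bisect.bisect_left(self.captured_sizes, num_tokens)
        return self.captured_sizes[i] if i < len(self.captured_sizes) else None

def single_rank_decision(num_tokens: int, cudagraph_size: int | None):
    # Mirrors the dp_size == 1 branch: (use_cudagraph, padded tokens, across-DP tensor).
    if cudagraph_size is not None:
        return True, cudagraph_size, None
    return False, num_tokens, None

manager = FakeCudaGraphManager([1, 2, 4, 8, 16, 32, 64, 128, 256, 512])

# Step 1: size lookup, Step 2: padding decision.
print(single_rank_decision(100, manager.get_cudagraph_size(100)))    # (True, 128, None)
print(single_rank_decision(4096, manager.get_cudagraph_size(4096)))  # (False, 4096, None)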
@@ -1006,7 +958,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # FIXME(woosuk): Fix warmup for LoRA.
 
         # Run model.
-        if cudagraph_mode == CUDAGraphMode.FULL:
+        if use_cudagraph:
             # Run CUDA graph.
             # NOTE(woosuk): Here, we don't need to pass the input tensors,
             # because they are already copied to the CUDA graph input buffers.
@@ -1015,7 +967,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             )
         else:
             # Run PyTorch model in eager mode.
-            # TODO(woosuk): Support piecewise CUDA graph.
             positions = input_batch.positions
             if self.uses_mrope:
                 assert input_batch.mrope_positions is not None
@@ -1024,7 +975,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 input_batch.attn_metadata,
                 self.vllm_config,
                 num_tokens=input_batch.num_tokens_after_padding,
-                cudagraph_runtime_mode=cudagraph_mode,
+                # TODO(woosuk): Support piecewise CUDA graph.
+                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                 num_tokens_across_dp=num_tokens_across_dp,
             ):
                 hidden_states = self.model(