[Feature] ViT Full CUDA Graph (#35963)

Signed-off-by: Baorun Mu <bmu@nvidia.com>
This commit is contained in:
Baorun (Lauren) Mu
2026-03-23 01:01:10 -04:00
committed by GitHub
parent 1f0d210641
commit f85e479e66
7 changed files with 1584 additions and 31 deletions

View File

@@ -0,0 +1,451 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for EncoderCudaGraphManager.
Test organization:
No GPU required:
- TestFindBudgetGraph — greedy budget selection logic
- TestGetCumulativeStats — hit/miss rate statistics
GPU required:
- TestEncoderCudaGraphCaptureReplay — capture, replay, fallback, counters, chunking
"""
from typing import Any
import pytest
import torch
from vllm.platforms import current_platform
from vllm.v1.worker.gpu.mm.encoder_cudagraph import (
EncoderCudaGraphManager,
)
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
EncoderCudaGraphCaptureInputs,
EncoderCudaGraphConfig,
EncoderCudaGraphReplayBuffers,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_manager_with_budgets(budgets: list[int]) -> EncoderCudaGraphManager:
"""Create a minimal EncoderCudaGraphManager with only token_budgets set.
Skips the parts of __init__ that require a real VllmConfig / model
by patching the attributes directly after construction.
"""
mgr = object.__new__(EncoderCudaGraphManager)
mgr.token_budgets = sorted(budgets)
mgr.max_batch_size = 16
mgr.use_dp = False
mgr.budget_graphs = {}
mgr.graph_hits = 0
mgr.graph_misses = 0
mgr.log_stats_interval = 100
return mgr
# ---------------------------------------------------------------------------
# _generate_budgets
# ---------------------------------------------------------------------------
class TestGenerateBudgets:
"""Auto-generate power-of-2 budgets from min to max."""
def test_exact_powers_of_2(self):
result = EncoderCudaGraphManager._generate_budgets(64, 1024)
assert result == [64, 128, 256, 512, 1024]
def test_max_not_power_of_2(self):
result = EncoderCudaGraphManager._generate_budgets(64, 800)
assert result == [64, 128, 256, 512, 800]
def test_min_equals_max(self):
result = EncoderCudaGraphManager._generate_budgets(64, 64)
assert result == [64]
def test_large_range(self):
result = EncoderCudaGraphManager._generate_budgets(64, 8192)
assert result == [64, 128, 256, 512, 1024, 2048, 4096, 8192]
# ---------------------------------------------------------------------------
# _find_smallest_fitting_budget_given_tokens
# ---------------------------------------------------------------------------
class TestFindBudgetGraph:
"""Budget greedy selection: smallest budget >= total_tokens."""
@pytest.mark.parametrize(
"total_tokens,budgets,expected",
[
# Exact match
(2048, [2048, 4096, 8192], 2048),
# Below smallest budget — picks smallest
(100, [2048, 4096, 8192], 2048),
# Zero tokens — picks smallest
(0, [2048, 4096, 8192], 2048),
# Between budgets — picks next one up
(2049, [2048, 4096, 8192], 4096),
(4097, [2048, 4096, 8192], 8192),
# Exceeds all budgets — returns None (eager fallback)
(9000, [2048, 4096, 8192], None),
# Single budget, fits
(1000, [2048], 2048),
# Single budget, does not fit
(3000, [2048], None),
],
)
def test_find_budget(self, total_tokens, budgets, expected):
mgr = _make_manager_with_budgets(budgets)
result = mgr._find_smallest_fitting_budget_given_tokens(total_tokens)
assert result == expected
def test_budgets_are_sorted(self):
"""Manager always sorts budgets ascending at init."""
mgr = _make_manager_with_budgets([8192, 2048, 4096])
assert mgr.token_budgets == [2048, 4096, 8192]
# Budget selection still works correctly after sorting
assert mgr._find_smallest_fitting_budget_given_tokens(3000) == 4096
# ---------------------------------------------------------------------------
# get_cumulative_stats
# ---------------------------------------------------------------------------
class TestGetCumulativeStats:
"""Statistics tracking and reporting."""
def test_initial_stats_are_zero(self):
mgr = _make_manager_with_budgets([2048])
stats = mgr.get_cumulative_stats()
assert stats["graph_hits"] == 0
assert stats["graph_misses"] == 0
assert stats["hit_rate"] == 0.0
def test_hit_rate_calculation(self):
mgr = _make_manager_with_budgets([2048])
mgr.graph_hits = 75
mgr.graph_misses = 25
stats = mgr.get_cumulative_stats()
assert stats["graph_hits"] == 75
assert stats["graph_misses"] == 25
assert stats["hit_rate"] == pytest.approx(0.75)
def test_all_hits(self):
mgr = _make_manager_with_budgets([2048])
mgr.graph_hits = 100
mgr.graph_misses = 0
assert mgr.get_cumulative_stats()["hit_rate"] == pytest.approx(1.0)
def test_all_misses(self):
mgr = _make_manager_with_budgets([2048])
mgr.graph_hits = 0
mgr.graph_misses = 50
assert mgr.get_cumulative_stats()["hit_rate"] == pytest.approx(0.0)
def test_stats_report_budget_info(self):
budgets = [2048, 4096, 8192]
mgr = _make_manager_with_budgets(budgets)
stats = mgr.get_cumulative_stats()
assert stats["num_budgets"] == 0 # no graphs captured yet
assert stats["token_budgets"] == budgets
# ---------------------------------------------------------------------------
# GPU fixtures and helpers
# ---------------------------------------------------------------------------
# Mock encoder parameters (kept small for fast capture)
_SPATIAL_MERGE = 2
_HIDDEN = 32
_PATCH_SIZE = 4 # H/W per patch in grid_thw units
_TEMPORAL_PATCH = 1
_IN_CHANNELS = 3
# flattened_patch_size = in_channels * temporal_patch * patch_size^2
_FLAT = _IN_CHANNELS * _TEMPORAL_PATCH * _PATCH_SIZE * _PATCH_SIZE # 48
# Test budgets: small to keep capture fast
_BUDGETS = [16, 64]
_MAX_BATCH = 4
def _count_input_patches(grid_thw_list: list[list[int]]) -> int:
return sum(t * h * w for t, h, w in grid_thw_list)
def _count_output_tokens(
grid_thw_list: list[list[int]], spatial_merge_size: int
) -> int:
m = spatial_merge_size
return sum(t * (h // m) * (w // m) for t, h, w in grid_thw_list)
class SimpleMockViTModel(torch.nn.Module):
"""Minimal ViT model for CUDA graph tests.
Implements the SupportsEncoderCudaGraph protocol by providing
all required methods. The forward pass projects patches and
simulates spatial merge by averaging groups of m^2 patches.
"""
supports_encoder_cudagraph = True
def __init__(self):
super().__init__()
self.proj = torch.nn.Linear(_FLAT, _HIDDEN)
self.spatial_merge_size = _SPATIAL_MERGE
self.out_hidden_size = _HIDDEN
def get_encoder_cudagraph_config(self) -> EncoderCudaGraphConfig:
return EncoderCudaGraphConfig(
modalities=["image"],
input_key="pixel_values",
buffer_keys=["dummy_buf"],
out_hidden_size=_HIDDEN,
)
def get_encoder_cudagraph_budget_range(
self,
vllm_config,
) -> tuple[int, int]:
# For tests: min=4, max=128 (small values for fast capture)
return (4, 128)
def get_encoder_cudagraph_num_items(
self,
mm_kwargs: dict[str, Any],
) -> int:
return len(mm_kwargs["image_grid_thw"])
def get_encoder_cudagraph_per_item_output_tokens(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
m = _SPATIAL_MERGE
return [t * (h // m) * (w // m) for t, h, w in mm_kwargs["image_grid_thw"]]
def get_encoder_cudagraph_per_item_input_sizes(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
return [t * h * w for t, h, w in mm_kwargs["image_grid_thw"]]
def select_encoder_cudagraph_items(
self,
mm_kwargs: dict[str, Any],
indices: list[int],
) -> dict[str, Any]:
grid_thw = mm_kwargs["image_grid_thw"]
pixel_values = mm_kwargs["pixel_values"]
if len(indices) == 0:
return {
"pixel_values": pixel_values[:0],
"image_grid_thw": [],
}
patches_per_item = [t * h * w for t, h, w in grid_thw]
cum_patches = [0]
for p in patches_per_item:
cum_patches.append(cum_patches[-1] + p)
selected_pv = torch.cat(
[pixel_values[cum_patches[i] : cum_patches[i + 1]] for i in indices]
)
selected_grid = [grid_thw[i] for i in indices]
return {
"pixel_values": selected_pv,
"image_grid_thw": selected_grid,
}
def prepare_encoder_cudagraph_capture_inputs(
self,
token_budget: int,
max_batch_size: int,
device: torch.device,
dtype: torch.dtype,
) -> EncoderCudaGraphCaptureInputs:
per_image_output = token_budget // max_batch_size
grid_config = [
[1, _SPATIAL_MERGE, per_image_output * _SPATIAL_MERGE]
for _ in range(max_batch_size)
]
total_patches = _count_input_patches(grid_config)
dummy_pixel_values = torch.randn(
total_patches, _FLAT, device=device, dtype=dtype
)
n_out = _count_output_tokens(grid_config, _SPATIAL_MERGE)
dummy_buf = torch.zeros(n_out, _HIDDEN, device=device, dtype=dtype)
return EncoderCudaGraphCaptureInputs(
mm_kwargs={
"pixel_values": dummy_pixel_values,
"image_grid_thw": grid_config,
},
buffers={"dummy_buf": dummy_buf},
)
def prepare_encoder_cudagraph_replay_buffers(
self,
mm_kwargs: dict[str, Any],
max_batch_size: int,
) -> EncoderCudaGraphReplayBuffers:
grid_thw = mm_kwargs["image_grid_thw"]
n_out = _count_output_tokens(grid_thw, _SPATIAL_MERGE)
p = next(self.parameters())
dummy_buf = torch.zeros(n_out, _HIDDEN, device=p.device, dtype=p.dtype)
return EncoderCudaGraphReplayBuffers(buffers={"dummy_buf": dummy_buf})
def encoder_cudagraph_forward(
self,
mm_kwargs: dict[str, Any],
buffers: dict[str, torch.Tensor],
) -> torch.Tensor:
return self._forward(mm_kwargs["pixel_values"])
def encoder_eager_forward(
self,
mm_kwargs: dict[str, Any],
) -> torch.Tensor:
return self._forward(mm_kwargs["pixel_values"])
def _forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
m2 = _SPATIAL_MERGE**2
out = self.proj(pixel_values)
n_out = out.shape[0] // m2
return out[: n_out * m2].view(n_out, m2, _HIDDEN).mean(dim=1)
def _make_manager_for_gpu(
model: SimpleMockViTModel,
token_budgets: list[int],
max_batch_size: int,
device: torch.device,
dtype: torch.dtype,
) -> EncoderCudaGraphManager:
"""Create EncoderCudaGraphManager bypassing VllmConfig for GPU tests."""
mgr = object.__new__(EncoderCudaGraphManager)
mgr.token_budgets = sorted(token_budgets)
mgr.max_batch_size = max_batch_size
mgr.use_dp = False
mgr.budget_graphs = {}
mgr.graph_hits = 0
mgr.graph_misses = 0
mgr.log_stats_interval = 100
mgr.model = model
mgr.config = model.get_encoder_cudagraph_config()
mgr.device = device
mgr.dtype = dtype
return mgr
def _make_pixel_values(
grid_thw_list: list[list[int]],
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
"""Random pixel_values matching the total input patch count."""
n = _count_input_patches(grid_thw_list)
return torch.randn(n, _FLAT, device=device, dtype=dtype)
def _make_mm_kwargs(
grid_thw_list: list[list[int]],
device: torch.device,
dtype: torch.dtype,
) -> dict[str, Any]:
"""Create mm_kwargs for testing."""
return {
"pixel_values": _make_pixel_values(grid_thw_list, device, dtype),
"image_grid_thw": grid_thw_list,
}
# ---------------------------------------------------------------------------
# GPU tests — capture, replay, fallback, counters, chunking
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestEncoderCudaGraphCaptureReplay:
def setup_method(self):
self.device = torch.device("cuda:0")
self.dtype = torch.float16
self.model = SimpleMockViTModel().to(self.device).half()
self.mgr = _make_manager_for_gpu(
self.model, _BUDGETS, _MAX_BATCH, self.device, self.dtype
)
self.mgr.capture()
# --- capture ---
def test_capture_creates_one_graph_per_budget(self):
assert len(self.mgr.budget_graphs) == len(_BUDGETS)
assert set(self.mgr.budget_graphs.keys()) == set(_BUDGETS)
# --- output shape ---
def test_execute_returns_one_tensor_per_image(self):
grid_thw = [[1, 4, 4], [1, 4, 4]]
mm_kwargs = _make_mm_kwargs(grid_thw, self.device, self.dtype)
result = self.mgr.execute(mm_kwargs)
assert result is not None
assert len(result) == 2
def test_execute_output_tokens_per_image(self):
# [1,4,4] → 1*(4//2)*(4//2) = 4 tokens; [1,8,8] → 16 tokens
grid_thw = [[1, 4, 4], [1, 8, 8]]
mm_kwargs = _make_mm_kwargs(grid_thw, self.device, self.dtype)
result = self.mgr.execute(mm_kwargs)
assert result is not None
assert result[0].shape == (4, _HIDDEN)
assert result[1].shape == (16, _HIDDEN)
# --- budget fallback ---
def test_eager_fallback_when_tokens_exceed_all_budgets(self):
# [1,18,18] → 1*(18//2)*(18//2) = 81 tokens > max budget 64.
# Greedy packing handles the fallback internally: the oversized image
# gets an eager forward pass and is returned as part of the output list
# (execute() no longer returns None for individual image misses).
grid_thw = [[1, 18, 18]]
mm_kwargs = _make_mm_kwargs(grid_thw, self.device, self.dtype)
result = self.mgr.execute(mm_kwargs)
assert result is not None
assert len(result) == 1
# Eager output: SimpleMockViTModel produces n_out = 81 tokens
assert result[0].shape == (81, _HIDDEN)
assert self.mgr.graph_misses == 1
# --- counters ---
def test_hit_counter_increments_by_num_images(self):
grid_thw = [[1, 4, 4], [1, 4, 4]]
mm_kwargs = _make_mm_kwargs(grid_thw, self.device, self.dtype)
self.mgr.execute(mm_kwargs)
assert self.mgr.graph_hits == 2
def test_miss_counter_increments_by_num_images(self):
grid_thw = [[1, 18, 18]] # 81 tokens > 64
mm_kwargs = _make_mm_kwargs(grid_thw, self.device, self.dtype)
self.mgr.execute(mm_kwargs)
assert self.mgr.graph_misses == 1
# --- chunking ---
def test_chunking_when_images_exceed_max_batch(self):
# 8 images > max_batch_size=4 → 2 chunks of 4
# each chunk: 4 * 4 = 16 tokens → fits budget 16
n_images = _MAX_BATCH * 2
grid_thw = [[1, 4, 4]] * n_images
mm_kwargs = _make_mm_kwargs(grid_thw, self.device, self.dtype)
result = self.mgr.execute(mm_kwargs)
assert result is not None
assert len(result) == n_images
for out in result:
assert out.shape == (4, _HIDDEN)

View File

@@ -489,6 +489,28 @@ class CompilationConfig:
on selected platforms. Disabled by default until more models
are supported/tested to work."""
# Vision encoder CUDA graph
cudagraph_mm_encoder: bool = False
"""Enable CUDA graph capture for multimodal encoder (ViT).
When enabled, captures full encoder forward as CUDA graph
for each token budget level."""
encoder_cudagraph_token_budgets: list[int] = field(default_factory=list)
"""Token budget levels for encoder CUDA graph capture.
Each budget defines a fixed token capacity. At runtime, images are greedy-packed
into the smallest fitting budget and the corresponding CUDA graph is replayed.
If empty (default), auto-inferred from model architecture as power-of-2
levels from the model's estimated min budget to max budget.
User-provided values override auto-inference.
Example: [2048, 4096, 8192, 13824]"""
encoder_cudagraph_max_images_per_batch: int = 0
"""Maximum number of images per batch for encoder CUDA graph capture.
Determines the fixed batch size used during graph capture.
If 0 (default), auto-inferred as max_budget // min_budget from the
model's budget range. User-provided positive value overrides
auto-inference."""
# Inductor capture
compile_sizes: list[int | str] | None = None
"""Sizes to compile for inductor. In addition
@@ -906,6 +928,16 @@ class CompilationConfig:
f"Invalid backend for piecewise compilation: {self.backend}"
)
# Validate encoder CUDA graph configuration
if (
self.cudagraph_mm_encoder
and self.encoder_cudagraph_max_images_per_batch < 0
):
raise ValueError(
"encoder_cudagraph_max_images_per_batch must be "
"non-negative (0 = auto-infer)"
)
if self.backend == "":
self.backend = current_platform.get_compile_backend()

View File

@@ -13,6 +13,7 @@ from collections.abc import (
from contextlib import ExitStack, contextmanager, nullcontext
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
Protocol,
@@ -46,6 +47,11 @@ if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.multimodal.registry import _ProcessorFactories
from vllm.sequence import IntermediateTensors
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
EncoderCudaGraphCaptureInputs,
EncoderCudaGraphConfig,
EncoderCudaGraphReplayBuffers,
)
else:
VllmConfig = object
WeightsMapper = object
@@ -1494,3 +1500,138 @@ def supports_xdrope(
model: type[object] | object,
) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]:
return isinstance(model, SupportsXDRoPE)
@runtime_checkable
class SupportsEncoderCudaGraph(Protocol):
"""Interface for models whose vision encoder supports CUDA graph
capture/replay.
Models implement these methods to provide the
:class:`EncoderCudaGraphManager` with all model-specific logic
(input handling, metadata computation, forward pass) without the
manager needing to know model internals.
"""
supports_encoder_cudagraph: ClassVar[Literal[True]] = True
def get_encoder_cudagraph_config(self) -> "EncoderCudaGraphConfig": ...
def get_encoder_cudagraph_budget_range(
self,
vllm_config: "VllmConfig",
) -> tuple[int, int]:
"""Return (min_token_budget, max_token_budget) for auto-inference.
- min_token_budget: estimated smallest possible encoder input
(e.g. 64 for a 224x224 image)
- max_token_budget: estimated largest budget worth capturing
(e.g. max_num_batched_tokens)
Used when ``encoder_cudagraph_token_budgets`` and/or
``encoder_cudagraph_max_images_per_batch`` are not explicitly
specified by the user.
"""
...
def get_encoder_cudagraph_num_items(
self,
mm_kwargs: dict[str, Any],
) -> int:
"""Return the number of items (e.g. images) in the batch."""
...
def get_encoder_cudagraph_per_item_output_tokens(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
"""Return output token count for each item.
Used for greedy packing and DP load balancing.
"""
...
def get_encoder_cudagraph_per_item_input_sizes(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
"""Return input size (e.g. patch count) for each item.
Used for input tensor slicing offsets.
"""
...
def select_encoder_cudagraph_items(
self,
mm_kwargs: dict[str, Any],
indices: list[int],
) -> dict[str, Any]:
"""Select a subset of items and return mm_kwargs for the sub-batch.
Called by the manager during greedy packing and DP sharding to
extract inputs for a specific set of items (e.g. images at
indices [0, 3, 5]). The implementation is model-specific
because input formats differ:
- Qwen-family: slice concatenated pixel_values by cumulative
patch offsets, subset grid_thw by indices.
- Batched models (CLIP): index pixel_values along dim 0.
"""
...
def prepare_encoder_cudagraph_capture_inputs(
self,
token_budget: int,
max_batch_size: int,
device: torch.device,
dtype: torch.dtype,
) -> "EncoderCudaGraphCaptureInputs":
"""Create dummy inputs and buffers for CUDA graph capture."""
...
def prepare_encoder_cudagraph_replay_buffers(
self,
mm_kwargs: dict[str, Any],
max_batch_size: int,
) -> "EncoderCudaGraphReplayBuffers":
"""Compute buffer values from actual batch inputs for replay."""
...
def encoder_cudagraph_forward(
self,
mm_kwargs: dict[str, Any],
buffers: dict[str, torch.Tensor],
) -> torch.Tensor:
"""Run the encoder forward pass with precomputed buffers.
Used during both CUDA graph capture and replay.
"""
...
def encoder_eager_forward(
self,
mm_kwargs: dict[str, Any],
) -> torch.Tensor:
"""Run the encoder forward pass without precomputed buffers.
Used as eager fallback when inputs exceed all budgets.
"""
...
@overload
def supports_encoder_cudagraph(
model: type[object],
) -> TypeIs[type[SupportsEncoderCudaGraph]]: ...
@overload
def supports_encoder_cudagraph(
model: object,
) -> TypeIs[SupportsEncoderCudaGraph]: ...
def supports_encoder_cudagraph(
model: type[object] | object,
) -> TypeIs[type[SupportsEncoderCudaGraph]] | TypeIs[SupportsEncoderCudaGraph]:
return isinstance(model, SupportsEncoderCudaGraph)

View File

@@ -103,6 +103,7 @@ from .interfaces import (
MultiModalEmbeddings,
SupportsEagle,
SupportsEagle3,
SupportsEncoderCudaGraph,
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
@@ -528,54 +529,120 @@ class Qwen3_VisionTransformer(nn.Module):
return torch.cat(outputs, dim=0)
def forward(
def prepare_encoder_metadata(
self,
x: torch.Tensor,
grid_thw: torch.Tensor | list[list[int]],
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
hidden_states = self.patch_embed(hidden_states)
grid_thw_list: list[list[int]],
*,
max_batch_size: int | None = None,
max_seqlen_override: int | None = None,
device: torch.device | None = None,
) -> dict[str, torch.Tensor | None]:
"""Compute encoder metadata from grid_thw_list.
if isinstance(grid_thw, list):
grid_thw_list = grid_thw
grid_thw = np.array(grid_thw, dtype=np.int32)
else:
grid_thw_list = grid_thw.tolist()
grid_thw = grid_thw.numpy()
Shared by the eager forward path, CUDA graph capture, and
CUDA graph replay to avoid duplicated implementation.
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
Args:
grid_thw_list: Grid configurations as list of [t, h, w].
max_batch_size: If set, pad cu_seqlens to this size
(needed for CUDA graph capture/replay).
max_seqlen_override: If set, use this value for max_seqlen
instead of computing from cu_seqlens (needed for CUDA
graph capture to cover worst-case replay scenarios).
device: Device to place tensors on. Defaults to self.device.
"""
if device is None:
device = self.device
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
axis=0, dtype=np.int32
metadata: dict[str, torch.Tensor | None] = {}
# Positional embeddings
metadata["pos_embeds"] = self.fast_pos_embed_interpolate(grid_thw_list)
rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw_list)
metadata["rotary_pos_emb_cos"] = rotary_cos
metadata["rotary_pos_emb_sin"] = rotary_sin
# cu_seqlens from grid_thw
grid_thw_np = np.array(grid_thw_list, dtype=np.int32)
patches_per_frame = grid_thw_np[:, 1] * grid_thw_np[:, 2]
cu_seqlens = np.repeat(patches_per_frame, grid_thw_np[:, 0]).cumsum(
dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
self.attn_backend, cu_seqlens, self.device
# Pad cu_seqlens if max_batch_size specified
if max_batch_size is not None:
num_seqs = len(cu_seqlens) - 1
if num_seqs < max_batch_size:
cu_seqlens = np.concatenate(
[
cu_seqlens,
np.full(
max_batch_size - num_seqs,
cu_seqlens[-1],
dtype=np.int32,
),
]
)
# sequence_lengths (backend-specific)
metadata["sequence_lengths"] = MMEncoderAttention.maybe_compute_seq_lens(
self.attn_backend, cu_seqlens, device
)
max_seqlen = torch.tensor(
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
dtype=torch.int32,
)
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
# max_seqlen
if max_seqlen_override is not None:
max_seqlen_val = max_seqlen_override
else:
max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
self.attn_backend, cu_seqlens
)
# Keep max_seqlen on CPU: attention wrappers call .item() on it,
# and having it on GPU would capture a wasteful D2H copy in CUDA
# graphs without changing behavior (the scalar is baked at capture).
metadata["max_seqlen"] = torch.tensor(max_seqlen_val, dtype=torch.int32)
# Recompute cu_seqlens (backend-specific transformation)
metadata["cu_seqlens"] = MMEncoderAttention.maybe_recompute_cu_seqlens(
self.attn_backend,
cu_seqlens,
self.hidden_size,
self.tp_size,
self.device,
device,
)
return metadata
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor | list[list[int]],
*,
encoder_metadata: dict[str, torch.Tensor] | None = None,
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
hidden_states = self.patch_embed(hidden_states)
if encoder_metadata is None:
if isinstance(grid_thw, list):
grid_thw_list = grid_thw
else:
grid_thw_list = grid_thw.tolist()
encoder_metadata = self.prepare_encoder_metadata(grid_thw_list)
pos_embeds = encoder_metadata["pos_embeds"]
hidden_states = hidden_states + pos_embeds
hidden_states = hidden_states.unsqueeze(1)
deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
cu_seqlens=encoder_metadata["cu_seqlens"],
rotary_pos_emb_cos=encoder_metadata["rotary_pos_emb_cos"],
rotary_pos_emb_sin=encoder_metadata["rotary_pos_emb_sin"],
max_seqlen=encoder_metadata["max_seqlen"],
sequence_lengths=encoder_metadata.get("sequence_lengths"),
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
@@ -1358,6 +1425,7 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
class Qwen3VLForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsEncoderCudaGraph,
SupportsLoRA,
SupportsPP,
SupportsMRoPE,
@@ -1507,6 +1575,178 @@ class Qwen3VLForConditionalGeneration(
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_()
# -- SupportsEncoderCudaGraph protocol methods --
def get_encoder_cudagraph_config(self):
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
EncoderCudaGraphConfig,
)
return EncoderCudaGraphConfig(
modalities=["image"],
input_key="pixel_values",
buffer_keys=[
"pos_embeds",
"rotary_pos_emb_cos",
"rotary_pos_emb_sin",
"cu_seqlens",
"max_seqlen",
"sequence_lengths",
],
out_hidden_size=self.visual.out_hidden_size,
)
def get_encoder_cudagraph_budget_range(
self,
vllm_config,
) -> tuple[int, int]:
# Min: estimated smallest possible encoder input.
# 224x224 image → 16x16 patches, spatial_merge_size=2 → 8x8 = 64 tokens
min_budget = 64
# Max: capped by max_num_batched_tokens
max_budget = vllm_config.scheduler_config.max_num_batched_tokens
return (min_budget, max_budget)
def get_encoder_cudagraph_num_items(
self,
mm_kwargs: dict[str, Any],
) -> int:
return len(mm_kwargs["image_grid_thw"])
def get_encoder_cudagraph_per_item_output_tokens(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
m = self.visual.spatial_merge_size
return [t * (h // m) * (w // m) for t, h, w in mm_kwargs["image_grid_thw"]]
def get_encoder_cudagraph_per_item_input_sizes(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
return [t * h * w for t, h, w in mm_kwargs["image_grid_thw"]]
def select_encoder_cudagraph_items(
self,
mm_kwargs: dict[str, Any],
indices: list[int],
) -> dict[str, Any]:
grid_thw = mm_kwargs["image_grid_thw"]
pixel_values = mm_kwargs["pixel_values"]
if len(indices) == 0:
return {
"pixel_values": pixel_values[:0],
"image_grid_thw": [],
}
# Compute cumulative patch offsets for slicing pixel_values
patches_per_item = [t * h * w for t, h, w in grid_thw]
cum_patches = [0]
for p in patches_per_item:
cum_patches.append(cum_patches[-1] + p)
selected_pv = torch.cat(
[pixel_values[cum_patches[i] : cum_patches[i + 1]] for i in indices]
)
selected_grid = [grid_thw[i] for i in indices]
return {
"pixel_values": selected_pv,
"image_grid_thw": selected_grid,
}
def prepare_encoder_cudagraph_capture_inputs(
self,
token_budget: int,
max_batch_size: int,
device: torch.device,
dtype: torch.dtype,
):
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
EncoderCudaGraphCaptureInputs,
)
spatial_merge_size = self.visual.spatial_merge_size
per_image_output = token_budget // max_batch_size
# Synthetic rectangular grid: [1, merge, per_image_output * merge]
# produces exactly per_image_output tokens per image.
grid_config = [
[1, spatial_merge_size, per_image_output * spatial_merge_size]
for _ in range(max_batch_size)
]
# Create dummy pixel_values
patch_embed = self.visual.patch_embed
in_channels = patch_embed.proj.in_channels
patch_size = patch_embed.patch_size
temporal_patch_size = patch_embed.temporal_patch_size
total_patches = sum(t * h * w for t, h, w in grid_config)
flattened_patch_size = (
in_channels * temporal_patch_size * patch_size * patch_size
)
dummy_pixel_values = torch.randn(
total_patches, flattened_patch_size, device=device, dtype=dtype
)
# Override max_seqlen with a safe upper bound for capture.
# max_seqlen.item() gets baked into the CUDA graph (not replayed),
# so the capture value must cover any replay scenario.
# Worst case: 1 image consuming the full budget ->
# seq_len = token_budget * spatial_merge_size^2.
buffers = self.visual.prepare_encoder_metadata(
grid_config,
max_batch_size=max_batch_size,
max_seqlen_override=token_budget * (spatial_merge_size**2),
device=device,
)
mm_kwargs = {
"pixel_values": dummy_pixel_values,
"image_grid_thw": grid_config,
}
return EncoderCudaGraphCaptureInputs(
mm_kwargs=mm_kwargs,
buffers=buffers,
)
def prepare_encoder_cudagraph_replay_buffers(
self,
mm_kwargs: dict[str, Any],
max_batch_size: int,
):
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
EncoderCudaGraphReplayBuffers,
)
grid_thw_list = mm_kwargs["image_grid_thw"]
buffers = self.visual.prepare_encoder_metadata(
grid_thw_list,
max_batch_size=max_batch_size,
)
return EncoderCudaGraphReplayBuffers(buffers=buffers)
def encoder_cudagraph_forward(
self,
mm_kwargs: dict[str, Any],
buffers: dict[str, torch.Tensor],
) -> torch.Tensor:
pixel_values = mm_kwargs["pixel_values"]
grid_thw = mm_kwargs["image_grid_thw"]
return self.visual(pixel_values, grid_thw, encoder_metadata=buffers)
def encoder_eager_forward(
self,
mm_kwargs: dict[str, Any],
) -> torch.Tensor:
pixel_values = mm_kwargs["pixel_values"]
grid_thw = mm_kwargs["image_grid_thw"]
return self.visual(pixel_values, grid_thw)
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Qwen2_5_VLImageInputs | None:

View File

@@ -0,0 +1,576 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""CUDA graph manager for vision encoder budget-batch execution."""
from dataclasses import dataclass
from typing import Any
import torch
from vllm.config import VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsEncoderCudaGraph
from vllm.model_executor.models.vision import get_load_balance_assignment
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
EncoderCudaGraphConfig,
)
logger = init_logger(__name__)
@dataclass
class BudgetGraphMetadata:
"""Metadata for a single budget graph.
CUDA graph replay pattern:
1. Copy new batch data into input_buffer (e.g. pixel_values)
2. Copy precomputed values into metadata_buffers
3. Replay graph
4. Read encoder outputs from output_buffer
"""
token_budget: int
max_batch_size: int # Max number of images/videos per batch
graph: torch.cuda.CUDAGraph
# The input tensor updated before replay (e.g. pixel_values)
input_buffer: torch.Tensor
# Buffers recorded into the CUDA graph (e.g. embeddings, sequence metadata).
# Before replay the manager zeros then slice-copies new data into these.
metadata_buffers: dict[str, torch.Tensor]
# Output written by graph, read after replay
output_buffer: torch.Tensor
class EncoderCudaGraphManager:
"""Budget-based CUDA graph capture/replay for vision encoders."""
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
dtype: torch.dtype,
model: SupportsEncoderCudaGraph,
):
"""Initialize CUDA graph manager with provided token budgets
and max batch size."""
self.vllm_config = vllm_config
self.device = device
self.dtype = dtype
self.model = model
self.config: EncoderCudaGraphConfig = model.get_encoder_cudagraph_config()
comp_config = vllm_config.compilation_config
user_budgets = comp_config.encoder_cudagraph_token_budgets
user_max_images = comp_config.encoder_cudagraph_max_images_per_batch
if user_budgets and user_max_images > 0:
# Fully user-specified
self.token_budgets = sorted(user_budgets)
self.max_batch_size = user_max_images
else:
# Auto-infer missing values from model
min_budget, max_budget = model.get_encoder_cudagraph_budget_range(
vllm_config
)
self.token_budgets = (
sorted(user_budgets)
if user_budgets
else self._generate_budgets(min_budget, max_budget)
)
self.max_batch_size = (
user_max_images if user_max_images > 0 else max_budget // min_budget
)
mm_config = vllm_config.model_config.multimodal_config
self.use_dp = (
mm_config is not None
and mm_config.mm_encoder_tp_mode == "data"
and vllm_config.parallel_config.tensor_parallel_size > 1
)
self.budget_graphs: dict[int, BudgetGraphMetadata] = {}
self.graph_hits = 0
self.graph_misses = 0
self.log_stats_interval = 100
logger.info(
"EncoderCudaGraphManager initialized with "
"budgets=%s, max_batch_size=%d, use_dp=%s",
self.token_budgets,
self.max_batch_size,
self.use_dp,
)
@staticmethod
def _generate_budgets(min_budget: int, max_budget: int) -> list[int]:
"""Generate power-of-2 token budgets from min_budget to max_budget."""
budgets: list[int] = []
b = min_budget
while b <= max_budget:
budgets.append(b)
b *= 2
# Always include max_budget if it's not already a power-of-2 boundary
if not budgets or budgets[-1] < max_budget:
budgets.append(max_budget)
return budgets
def supports_modality(self, modality: str) -> bool:
"""Check if a modality is supported by this manager."""
return modality in self.config.modalities
def capture(self):
"""Capture CUDA graphs for all token budgets."""
for token_budget in self.token_budgets:
self._capture_budget_graph(token_budget)
logger.info(
"Encoder CUDA graph capture complete. Captured %d budget graphs.",
len(self.budget_graphs),
)
def _capture_budget_graph(self, token_budget: int):
"""Capture CUDA graph for a single token budget."""
logger.debug(
"Capturing encoder cudagraph for budget=%d, max_batch_size=%d",
token_budget,
self.max_batch_size,
)
capture_inputs = self.model.prepare_encoder_cudagraph_capture_inputs(
token_budget, self.max_batch_size, self.device, self.dtype
)
mm_kwargs = capture_inputs.mm_kwargs
buffers = capture_inputs.buffers
with torch.inference_mode():
output = self.model.encoder_cudagraph_forward(mm_kwargs, buffers)
output_buffer = torch.empty_like(output)
graph = torch.cuda.CUDAGraph()
with torch.inference_mode(), torch.cuda.graph(graph):
output = self.model.encoder_cudagraph_forward(mm_kwargs, buffers)
output_buffer.copy_(output)
input_key = self.config.input_key
self.budget_graphs[token_budget] = BudgetGraphMetadata(
token_budget=token_budget,
max_batch_size=self.max_batch_size,
graph=graph,
input_buffer=mm_kwargs[input_key],
metadata_buffers=buffers,
output_buffer=output_buffer,
)
def _find_smallest_fitting_budget_given_tokens(
self, total_tokens: int
) -> int | None:
"""Find smallest budget >= total_tokens.
Returns:
Token budget if found, None if no fitting budget.
"""
for budget in self.token_budgets:
if budget >= total_tokens:
return budget
return None
def _get_per_item_out_tokens(self, mm_kwargs: dict[str, Any]) -> list[int]:
"""Get per-item output token counts as plain ints."""
return [
int(t)
for t in self.model.get_encoder_cudagraph_per_item_output_tokens(mm_kwargs)
]
@staticmethod
def _scatter_output_slices(
output: torch.Tensor,
indices: list[int],
per_item_out_tokens: list[int],
dest: dict[int, torch.Tensor] | list[torch.Tensor | None],
clone: bool = False,
) -> None:
"""Slice a concatenated output tensor and scatter into dest by index."""
offset = 0
for idx in indices:
n_tok = per_item_out_tokens[idx]
sliced = output[offset : offset + n_tok]
dest[idx] = sliced.clone() if clone else sliced
offset += n_tok
def _run_budget_graph(
self,
mm_kwargs: dict[str, Any],
token_budget: int,
replay_buffers: dict[str, torch.Tensor | None],
) -> torch.Tensor | None:
"""Execute budget graph.
Args:
mm_kwargs: Multimodal inputs for the batch.
token_budget: Token budget to use.
replay_buffers: Buffer values to copy into captured buffers.
None values leave the corresponding buffer unchanged.
Returns:
Encoder outputs, or None if graph not captured.
"""
num_items = self.model.get_encoder_cudagraph_num_items(mm_kwargs)
if token_budget not in self.budget_graphs:
self.graph_misses += num_items
return None
graph_meta = self.budget_graphs[token_budget]
# Copy the input tensor. Buffers are sized for the full budget;
# actual inputs may be smaller. Zero then slice-copy so padded
# positions are invisible to attention (cu_seqlens masks them out).
input_key = self.config.input_key
src = mm_kwargs[input_key]
n = src.shape[0]
graph_meta.input_buffer.zero_()
graph_meta.input_buffer[:n].copy_(src)
# Copy metadata buffers using keys from config.buffer_keys.
for key in self.config.buffer_keys:
src = replay_buffers.get(key)
if src is None:
continue
buf = graph_meta.metadata_buffers[key]
if src.ndim == 0:
buf.copy_(src)
else:
n = src.shape[0]
buf.zero_()
buf[:n].copy_(src)
graph_meta.graph.replay()
self.graph_hits += num_items
return graph_meta.output_buffer
def _execute_local(
self,
mm_kwargs: dict[str, Any],
) -> list[torch.Tensor]:
"""Execute encoder on local inputs using greedy-packed CUDA graphs.
Sort images by output token count (smallest first), then greedily pack
as many images as possible into each batch while staying within
max_budget tokens and max_batch_size. Once a batch is finalised (next
image would overflow either constraint), find the smallest fitting
budget once for that batch.
By exchange argument, greedy smallest-first packing minimises eager
fallbacks -- any other ordering yields a higher token sum in some batch,
making that batch more likely to exceed the budget.
Stats note:
graph_hits -- counted inside _run_budget_graph after successful replay.
graph_misses -- counted here for single-image batches where the image
exceeds max_budget. Batches split due to max_batch_size
always satisfy total_tokens <= max_budget and therefore
always find a valid budget (no miss).
"""
num_items = self.model.get_encoder_cudagraph_num_items(mm_kwargs)
max_budget = self.token_budgets[-1]
per_item_out_tokens = self._get_per_item_out_tokens(mm_kwargs)
# Sort ascending by output token count (smallest first)
sorted_indices = sorted(range(num_items), key=lambda i: per_item_out_tokens[i])
# Greedy pack against max_budget and max_batch_size.
# _find_smallest_fitting_budget_given_tokens is called once per
# finalised batch, not per image.
batches: list[tuple[list[int], int | None]] = []
current_batch: list[int] = []
current_batch_tokens = 0
for orig_idx in sorted_indices:
item_tokens = per_item_out_tokens[orig_idx]
if (
current_batch_tokens + item_tokens <= max_budget
and len(current_batch) < self.max_batch_size
):
current_batch.append(orig_idx)
current_batch_tokens += item_tokens
else:
if current_batch:
batches.append(
(
current_batch,
self._find_smallest_fitting_budget_given_tokens(
current_batch_tokens
),
)
)
current_batch = [orig_idx]
current_batch_tokens = item_tokens
if current_batch:
batches.append(
(
current_batch,
self._find_smallest_fitting_budget_given_tokens(
current_batch_tokens
),
)
)
# outputs_by_orig_idx maps each original image index to its output
# tensor. Needed because greedy packing reorders images; we restore
# the original order before returning.
outputs_by_orig_idx: dict[int, torch.Tensor] = {}
for batch_orig_indices, token_budget in batches:
batch_mm_kwargs = self.model.select_encoder_cudagraph_items(
mm_kwargs, batch_orig_indices
)
batch_out_tokens = sum(per_item_out_tokens[i] for i in batch_orig_indices)
if token_budget is None:
# Single oversized image: item_tokens > max_budget.
# graph_misses counted here for this eager fallback.
logger.debug(
"Encoder CUDA graph fallback to eager: no budget for "
"%d tokens from %d images",
batch_out_tokens,
len(batch_orig_indices),
)
self.graph_misses += len(batch_orig_indices)
with torch.inference_mode():
raw = self.model.encoder_eager_forward(batch_mm_kwargs)
self._scatter_output_slices(
raw,
batch_orig_indices,
per_item_out_tokens,
outputs_by_orig_idx,
)
else:
logger.debug(
"Encoder CUDA graph: batch_size=%d, tokens=%d, "
"budget=%d, waste=%.1f%%",
len(batch_orig_indices),
batch_out_tokens,
token_budget,
(token_budget - batch_out_tokens) / token_budget * 100,
)
replay = self.model.prepare_encoder_cudagraph_replay_buffers(
batch_mm_kwargs, self.max_batch_size
)
# graph_hits counted inside _run_budget_graph after replay.
output = self._run_budget_graph(
batch_mm_kwargs, token_budget, replay.buffers
)
assert output is not None
self._scatter_output_slices(
output,
batch_orig_indices,
per_item_out_tokens,
outputs_by_orig_idx,
clone=True,
)
# Return in original batch order (caller maps outputs to token positions)
return [outputs_by_orig_idx[i] for i in range(num_items)]
def _dp_shard(
self,
mm_kwargs: dict[str, Any],
per_item_out_tokens: list[int],
) -> tuple[dict[str, Any], list[int], list[int], int]:
"""Distribute items across TP ranks for data-parallel execution.
Uses get_load_balance_assignment() to balance load by input size,
then select_encoder_cudagraph_items() to extract each rank's inputs.
Returns:
local_mm_kwargs: Inputs for this rank.
image_rank_assignment: Flattened assignment order across all ranks.
images_per_rank: Number of items per rank.
max_output_tokens_per_rank: Max output tokens across all ranks
(for padding during all_gather).
"""
tp_size = get_tensor_model_parallel_world_size()
current_rank = get_tensor_model_parallel_rank()
per_item_input_sizes = self.model.get_encoder_cudagraph_per_item_input_sizes(
mm_kwargs
)
(image_rank_assignment, images_per_rank, input_patches_per_rank) = (
get_load_balance_assignment(per_item_input_sizes, tp_size)
)
# Extract local indices for this rank
cum_images_per_rank = [0]
for count in images_per_rank:
cum_images_per_rank.append(cum_images_per_rank[-1] + count)
local_indices = image_rank_assignment[
cum_images_per_rank[current_rank] : cum_images_per_rank[current_rank + 1]
]
if len(local_indices) > 0:
local_mm_kwargs = self.model.select_encoder_cudagraph_items(
mm_kwargs, local_indices
)
else:
local_mm_kwargs = self.model.select_encoder_cudagraph_items(mm_kwargs, [])
max_output_tokens_per_rank = (
max(
sum(
per_item_out_tokens[i]
for i in image_rank_assignment[
cum_images_per_rank[r] : cum_images_per_rank[r + 1]
]
)
for r in range(tp_size)
)
if len(per_item_out_tokens) > 0
else 0
)
return (
local_mm_kwargs,
image_rank_assignment,
images_per_rank,
max_output_tokens_per_rank,
)
def _dp_gather(
self,
local_outputs: list[torch.Tensor],
per_item_out_tokens: list[int],
image_rank_assignment: list[int],
images_per_rank: list[int],
max_output_tokens_per_rank: int,
) -> list[torch.Tensor]:
"""Gather outputs from all TP ranks and reorder to original sequence.
Assumes 2D output tensors [tokens, hidden]. Follows the same
pad -> all_gather -> unpad -> reorder algorithm as
run_dp_sharded_mrope_vision_model() in the eager path.
"""
hidden_size = self.config.out_hidden_size
tp_size = len(images_per_rank)
if len(local_outputs) > 0:
local_concat = torch.cat(local_outputs, dim=0)
else:
local_concat = torch.empty(
(0, hidden_size), device=self.device, dtype=self.dtype
)
# Pad to max_output_tokens_per_rank for all_gather
current_len = local_concat.shape[0]
if current_len < max_output_tokens_per_rank:
padding = torch.empty(
(max_output_tokens_per_rank - current_len, hidden_size),
dtype=self.dtype,
device=self.device,
)
local_padded = torch.cat([local_concat, padding], dim=0)
else:
local_padded = local_concat
gathered = tensor_model_parallel_all_gather(local_padded, dim=0)
# Unpad each rank's contribution
rank_outputs: list[torch.Tensor] = []
current_idx = 0
for rank in range(tp_size):
start = rank * max_output_tokens_per_rank
rank_count = images_per_rank[rank]
rank_indices = image_rank_assignment[current_idx : current_idx + rank_count]
rank_tokens = sum(per_item_out_tokens[i] for i in rank_indices)
current_idx += rank_count
rank_outputs.append(gathered[start : start + rank_tokens])
# Reorder to original sequence
total_items = len(per_item_out_tokens)
result: list[torch.Tensor | None] = [None] * total_items
current_idx = 0
for rank in range(tp_size):
count = images_per_rank[rank]
if count > 0:
rank_items = image_rank_assignment[current_idx : current_idx + count]
self._scatter_output_slices(
rank_outputs[rank],
rank_items,
per_item_out_tokens,
result,
)
current_idx += count
return [t for t in result if t is not None]
def execute(
self,
mm_kwargs: dict[str, Any],
) -> list[torch.Tensor]:
"""Execute encoder using CUDA graph with optional DP.
Args:
mm_kwargs: Multimodal keyword arguments containing the
input tensor and grid dimensions.
Returns:
List of encoder outputs (one per item).
"""
if self.use_dp:
per_item_out_tokens = self._get_per_item_out_tokens(mm_kwargs)
(
local_mm_kwargs,
image_rank_assignment,
images_per_rank,
max_output_tokens_per_rank,
) = self._dp_shard(mm_kwargs, per_item_out_tokens)
local_outputs = self._execute_local(local_mm_kwargs)
result = self._dp_gather(
local_outputs,
per_item_out_tokens,
image_rank_assignment,
images_per_rank,
max_output_tokens_per_rank,
)
else:
result = self._execute_local(mm_kwargs)
# Log cumulative stats periodically
stats = self.get_cumulative_stats()
total_requests = self.graph_hits + self.graph_misses
if total_requests > 0 and total_requests % self.log_stats_interval == 0:
logger.debug(
"Encoder CUDA graph cumulative stats: "
"hits=%d, misses=%d, hit_rate=%.1f%%",
stats["graph_hits"],
stats["graph_misses"],
stats["hit_rate"] * 100,
)
return result
def get_cumulative_stats(self) -> dict[str, Any]:
"""Get cumulative CUDA graph statistics."""
total_requests = self.graph_hits + self.graph_misses
hit_rate = self.graph_hits / total_requests if total_requests > 0 else 0.0
return {
"graph_hits": self.graph_hits,
"graph_misses": self.graph_misses,
"hit_rate": hit_rate,
"num_budgets": len(self.budget_graphs),
"token_budgets": self.token_budgets,
}

View File

@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Data transfer objects for encoder CUDA graph management."""
from dataclasses import dataclass
from typing import Any
import torch
@dataclass
class EncoderCudaGraphConfig:
"""Configuration for encoder CUDA graph management.
Provided by the model at init time via
``get_encoder_cudagraph_config()``. Values are fixed for the
lifetime of the manager.
"""
modalities: list[str]
"""Supported modalities (e.g. ["image"])."""
input_key: str
"""Key in mm_kwargs for the input tensor (e.g. "pixel_values")."""
buffer_keys: list[str]
"""Keys for the tensor buffers recorded into the CUDA graph.
Before replay the manager zeros then slice-copies new data
into these buffers."""
out_hidden_size: int
"""Output hidden dim of the vision encoder.
Used for DP gather buffer allocation."""
@dataclass
class EncoderCudaGraphCaptureInputs:
"""Everything needed for one CUDA graph capture.
Returned by ``prepare_encoder_cudagraph_capture_inputs()``.
"""
mm_kwargs: dict[str, Any]
"""Dummy forward inputs (model-specific keys).
For Qwen3-VL this contains pixel_values and grid_thw."""
buffers: dict[str, torch.Tensor]
"""Precomputed tensor buffers that will be recorded into the
CUDA graph. The manager stores references to these exact
tensor objects and copies new data into them before each
``graph.replay()`` call (buffer identity invariant)."""
@dataclass
class EncoderCudaGraphReplayBuffers:
"""New buffer values for graph replay, computed by the model from
actual batch inputs.
Returned by ``prepare_encoder_cudagraph_replay_buffers()``.
Keys match ``EncoderCudaGraphConfig.buffer_keys``.
"""
buffers: dict[str, torch.Tensor | None]
"""Data to copy into the captured buffers before replay.
``None`` values leave the corresponding captured buffer
unchanged."""

View File

@@ -207,6 +207,7 @@ from .utils import (
if TYPE_CHECKING:
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.gpu.mm.encoder_cudagraph import EncoderCudaGraphManager
logger = init_logger(__name__)
@@ -499,6 +500,9 @@ class GPUModelRunner(
self.encoder_cache: dict[str, torch.Tensor] = {}
self.late_interaction_runner = LateInteractionRunner()
# Encoder CUDA graph manager (initialized after model load if enabled)
self.encoder_cudagraph_manager: EncoderCudaGraphManager | None = None
self.use_aux_hidden_state_outputs = False
# Set up speculative decoding.
# NOTE(Jiayi): currently we put the entire draft model on
@@ -2664,7 +2668,19 @@ class GPUModelRunner(
with self.timed_encoder_operation(
should_time, mm_lora_refs, current_item_idx, num_items
):
batch_outputs = model.embed_multimodal(**mm_kwargs_batch)
cudagraph_output = None
if (
self.encoder_cudagraph_manager is not None
and self.encoder_cudagraph_manager.supports_modality(modality)
):
cudagraph_output = self.encoder_cudagraph_manager.execute(
mm_kwargs_batch,
)
if cudagraph_output is not None:
batch_outputs = cudagraph_output
else:
batch_outputs = model.embed_multimodal(**mm_kwargs_batch)
sanity_check_mm_encoder_outputs(batch_outputs, expected_num_items=num_items)
encoder_outputs.extend(batch_outputs)
@@ -5715,6 +5731,33 @@ class GPUModelRunner(
)
return 0
# Initialize encoder CUDA graph manager if enabled.
# Use get_model() to unwrap CUDAGraphWrapper/UBatchWrapper,
# because @runtime_checkable Protocol isinstance() checks do not
# work through __getattr__ forwarding.
if (
self.compilation_config.cudagraph_mm_encoder
and self.supports_mm_inputs
and self.encoder_cudagraph_manager is None
):
from vllm.model_executor.models.interfaces import (
SupportsEncoderCudaGraph,
supports_encoder_cudagraph,
)
from vllm.v1.worker.gpu.mm.encoder_cudagraph import (
EncoderCudaGraphManager,
)
raw_model = self.get_model()
if supports_encoder_cudagraph(raw_model):
self.encoder_cudagraph_manager = EncoderCudaGraphManager(
vllm_config=self.vllm_config,
device=self.device,
dtype=self.dtype,
model=cast(SupportsEncoderCudaGraph, raw_model),
)
logger.info("Initialized EncoderCudaGraphManager for vision encoder")
compilation_counter.num_gpu_runner_capture_triggers += 1
start_time = time.perf_counter()
@@ -5738,6 +5781,10 @@ class GPUModelRunner(
)
torch.accelerator.synchronize()
# Capture encoder CUDA graphs if enabled
if self.encoder_cudagraph_manager is not None:
self.encoder_cudagraph_manager.capture()
torch.accelerator.synchronize()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]