[Feature] ViT Full CUDA Graph (#35963)

Signed-off-by: Baorun Mu <bmu@nvidia.com>
This commit is contained in:
Baorun (Lauren) Mu
2026-03-23 01:01:10 -04:00
committed by GitHub
parent 1f0d210641
commit f85e479e66
7 changed files with 1584 additions and 31 deletions

View File

@@ -13,6 +13,7 @@ from collections.abc import (
from contextlib import ExitStack, contextmanager, nullcontext
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
Protocol,
@@ -46,6 +47,11 @@ if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.multimodal.registry import _ProcessorFactories
from vllm.sequence import IntermediateTensors
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
EncoderCudaGraphCaptureInputs,
EncoderCudaGraphConfig,
EncoderCudaGraphReplayBuffers,
)
else:
VllmConfig = object
WeightsMapper = object
@@ -1494,3 +1500,138 @@ def supports_xdrope(
model: type[object] | object,
) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]:
return isinstance(model, SupportsXDRoPE)
@runtime_checkable
class SupportsEncoderCudaGraph(Protocol):
"""Interface for models whose vision encoder supports CUDA graph
capture/replay.
Models implement these methods to provide the
:class:`EncoderCudaGraphManager` with all model-specific logic
(input handling, metadata computation, forward pass) without the
manager needing to know model internals.
"""
supports_encoder_cudagraph: ClassVar[Literal[True]] = True
def get_encoder_cudagraph_config(self) -> "EncoderCudaGraphConfig": ...
def get_encoder_cudagraph_budget_range(
self,
vllm_config: "VllmConfig",
) -> tuple[int, int]:
"""Return (min_token_budget, max_token_budget) for auto-inference.
- min_token_budget: estimated smallest possible encoder input
(e.g. 64 for a 224x224 image)
- max_token_budget: estimated largest budget worth capturing
(e.g. max_num_batched_tokens)
Used when ``encoder_cudagraph_token_budgets`` and/or
``encoder_cudagraph_max_images_per_batch`` are not explicitly
specified by the user.
"""
...
def get_encoder_cudagraph_num_items(
self,
mm_kwargs: dict[str, Any],
) -> int:
"""Return the number of items (e.g. images) in the batch."""
...
def get_encoder_cudagraph_per_item_output_tokens(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
"""Return output token count for each item.
Used for greedy packing and DP load balancing.
"""
...
def get_encoder_cudagraph_per_item_input_sizes(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
"""Return input size (e.g. patch count) for each item.
Used for input tensor slicing offsets.
"""
...
def select_encoder_cudagraph_items(
self,
mm_kwargs: dict[str, Any],
indices: list[int],
) -> dict[str, Any]:
"""Select a subset of items and return mm_kwargs for the sub-batch.
Called by the manager during greedy packing and DP sharding to
extract inputs for a specific set of items (e.g. images at
indices [0, 3, 5]). The implementation is model-specific
because input formats differ:
- Qwen-family: slice concatenated pixel_values by cumulative
patch offsets, subset grid_thw by indices.
- Batched models (CLIP): index pixel_values along dim 0.
"""
...
def prepare_encoder_cudagraph_capture_inputs(
self,
token_budget: int,
max_batch_size: int,
device: torch.device,
dtype: torch.dtype,
) -> "EncoderCudaGraphCaptureInputs":
"""Create dummy inputs and buffers for CUDA graph capture."""
...
def prepare_encoder_cudagraph_replay_buffers(
self,
mm_kwargs: dict[str, Any],
max_batch_size: int,
) -> "EncoderCudaGraphReplayBuffers":
"""Compute buffer values from actual batch inputs for replay."""
...
def encoder_cudagraph_forward(
self,
mm_kwargs: dict[str, Any],
buffers: dict[str, torch.Tensor],
) -> torch.Tensor:
"""Run the encoder forward pass with precomputed buffers.
Used during both CUDA graph capture and replay.
"""
...
def encoder_eager_forward(
self,
mm_kwargs: dict[str, Any],
) -> torch.Tensor:
"""Run the encoder forward pass without precomputed buffers.
Used as eager fallback when inputs exceed all budgets.
"""
...
@overload
def supports_encoder_cudagraph(
model: type[object],
) -> TypeIs[type[SupportsEncoderCudaGraph]]: ...
@overload
def supports_encoder_cudagraph(
model: object,
) -> TypeIs[SupportsEncoderCudaGraph]: ...
def supports_encoder_cudagraph(
model: type[object] | object,
) -> TypeIs[type[SupportsEncoderCudaGraph]] | TypeIs[SupportsEncoderCudaGraph]:
return isinstance(model, SupportsEncoderCudaGraph)

View File

@@ -103,6 +103,7 @@ from .interfaces import (
MultiModalEmbeddings,
SupportsEagle,
SupportsEagle3,
SupportsEncoderCudaGraph,
SupportsLoRA,
SupportsMRoPE,
SupportsMultiModal,
@@ -528,54 +529,120 @@ class Qwen3_VisionTransformer(nn.Module):
return torch.cat(outputs, dim=0)
def forward(
def prepare_encoder_metadata(
self,
x: torch.Tensor,
grid_thw: torch.Tensor | list[list[int]],
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
hidden_states = self.patch_embed(hidden_states)
grid_thw_list: list[list[int]],
*,
max_batch_size: int | None = None,
max_seqlen_override: int | None = None,
device: torch.device | None = None,
) -> dict[str, torch.Tensor | None]:
"""Compute encoder metadata from grid_thw_list.
if isinstance(grid_thw, list):
grid_thw_list = grid_thw
grid_thw = np.array(grid_thw, dtype=np.int32)
else:
grid_thw_list = grid_thw.tolist()
grid_thw = grid_thw.numpy()
Shared by the eager forward path, CUDA graph capture, and
CUDA graph replay to avoid duplicated implementation.
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
hidden_states = hidden_states + pos_embeds
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
Args:
grid_thw_list: Grid configurations as list of [t, h, w].
max_batch_size: If set, pad cu_seqlens to this size
(needed for CUDA graph capture/replay).
max_seqlen_override: If set, use this value for max_seqlen
instead of computing from cu_seqlens (needed for CUDA
graph capture to cover worst-case replay scenarios).
device: Device to place tensors on. Defaults to self.device.
"""
if device is None:
device = self.device
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
axis=0, dtype=np.int32
metadata: dict[str, torch.Tensor | None] = {}
# Positional embeddings
metadata["pos_embeds"] = self.fast_pos_embed_interpolate(grid_thw_list)
rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw_list)
metadata["rotary_pos_emb_cos"] = rotary_cos
metadata["rotary_pos_emb_sin"] = rotary_sin
# cu_seqlens from grid_thw
grid_thw_np = np.array(grid_thw_list, dtype=np.int32)
patches_per_frame = grid_thw_np[:, 1] * grid_thw_np[:, 2]
cu_seqlens = np.repeat(patches_per_frame, grid_thw_np[:, 0]).cumsum(
dtype=np.int32
)
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
self.attn_backend, cu_seqlens, self.device
# Pad cu_seqlens if max_batch_size specified
if max_batch_size is not None:
num_seqs = len(cu_seqlens) - 1
if num_seqs < max_batch_size:
cu_seqlens = np.concatenate(
[
cu_seqlens,
np.full(
max_batch_size - num_seqs,
cu_seqlens[-1],
dtype=np.int32,
),
]
)
# sequence_lengths (backend-specific)
metadata["sequence_lengths"] = MMEncoderAttention.maybe_compute_seq_lens(
self.attn_backend, cu_seqlens, device
)
max_seqlen = torch.tensor(
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
dtype=torch.int32,
)
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
# max_seqlen
if max_seqlen_override is not None:
max_seqlen_val = max_seqlen_override
else:
max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
self.attn_backend, cu_seqlens
)
# Keep max_seqlen on CPU: attention wrappers call .item() on it,
# and having it on GPU would capture a wasteful D2H copy in CUDA
# graphs without changing behavior (the scalar is baked at capture).
metadata["max_seqlen"] = torch.tensor(max_seqlen_val, dtype=torch.int32)
# Recompute cu_seqlens (backend-specific transformation)
metadata["cu_seqlens"] = MMEncoderAttention.maybe_recompute_cu_seqlens(
self.attn_backend,
cu_seqlens,
self.hidden_size,
self.tp_size,
self.device,
device,
)
return metadata
def forward(
self,
x: torch.Tensor,
grid_thw: torch.Tensor | list[list[int]],
*,
encoder_metadata: dict[str, torch.Tensor] | None = None,
) -> torch.Tensor:
hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
hidden_states = self.patch_embed(hidden_states)
if encoder_metadata is None:
if isinstance(grid_thw, list):
grid_thw_list = grid_thw
else:
grid_thw_list = grid_thw.tolist()
encoder_metadata = self.prepare_encoder_metadata(grid_thw_list)
pos_embeds = encoder_metadata["pos_embeds"]
hidden_states = hidden_states + pos_embeds
hidden_states = hidden_states.unsqueeze(1)
deepstack_feature_lists = []
for layer_num, blk in enumerate(self.blocks):
hidden_states = blk(
hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb_cos=rotary_pos_emb_cos,
rotary_pos_emb_sin=rotary_pos_emb_sin,
max_seqlen=max_seqlen,
sequence_lengths=sequence_lengths,
cu_seqlens=encoder_metadata["cu_seqlens"],
rotary_pos_emb_cos=encoder_metadata["rotary_pos_emb_cos"],
rotary_pos_emb_sin=encoder_metadata["rotary_pos_emb_sin"],
max_seqlen=encoder_metadata["max_seqlen"],
sequence_lengths=encoder_metadata.get("sequence_lengths"),
)
if layer_num in self.deepstack_visual_indexes:
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
@@ -1358,6 +1425,7 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
class Qwen3VLForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsEncoderCudaGraph,
SupportsLoRA,
SupportsPP,
SupportsMRoPE,
@@ -1507,6 +1575,178 @@ class Qwen3VLForConditionalGeneration(
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_()
# -- SupportsEncoderCudaGraph protocol methods --
def get_encoder_cudagraph_config(self):
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
EncoderCudaGraphConfig,
)
return EncoderCudaGraphConfig(
modalities=["image"],
input_key="pixel_values",
buffer_keys=[
"pos_embeds",
"rotary_pos_emb_cos",
"rotary_pos_emb_sin",
"cu_seqlens",
"max_seqlen",
"sequence_lengths",
],
out_hidden_size=self.visual.out_hidden_size,
)
def get_encoder_cudagraph_budget_range(
self,
vllm_config,
) -> tuple[int, int]:
# Min: estimated smallest possible encoder input.
# 224x224 image → 16x16 patches, spatial_merge_size=2 → 8x8 = 64 tokens
min_budget = 64
# Max: capped by max_num_batched_tokens
max_budget = vllm_config.scheduler_config.max_num_batched_tokens
return (min_budget, max_budget)
def get_encoder_cudagraph_num_items(
self,
mm_kwargs: dict[str, Any],
) -> int:
return len(mm_kwargs["image_grid_thw"])
def get_encoder_cudagraph_per_item_output_tokens(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
m = self.visual.spatial_merge_size
return [t * (h // m) * (w // m) for t, h, w in mm_kwargs["image_grid_thw"]]
def get_encoder_cudagraph_per_item_input_sizes(
self,
mm_kwargs: dict[str, Any],
) -> list[int]:
return [t * h * w for t, h, w in mm_kwargs["image_grid_thw"]]
def select_encoder_cudagraph_items(
self,
mm_kwargs: dict[str, Any],
indices: list[int],
) -> dict[str, Any]:
grid_thw = mm_kwargs["image_grid_thw"]
pixel_values = mm_kwargs["pixel_values"]
if len(indices) == 0:
return {
"pixel_values": pixel_values[:0],
"image_grid_thw": [],
}
# Compute cumulative patch offsets for slicing pixel_values
patches_per_item = [t * h * w for t, h, w in grid_thw]
cum_patches = [0]
for p in patches_per_item:
cum_patches.append(cum_patches[-1] + p)
selected_pv = torch.cat(
[pixel_values[cum_patches[i] : cum_patches[i + 1]] for i in indices]
)
selected_grid = [grid_thw[i] for i in indices]
return {
"pixel_values": selected_pv,
"image_grid_thw": selected_grid,
}
def prepare_encoder_cudagraph_capture_inputs(
self,
token_budget: int,
max_batch_size: int,
device: torch.device,
dtype: torch.dtype,
):
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
EncoderCudaGraphCaptureInputs,
)
spatial_merge_size = self.visual.spatial_merge_size
per_image_output = token_budget // max_batch_size
# Synthetic rectangular grid: [1, merge, per_image_output * merge]
# produces exactly per_image_output tokens per image.
grid_config = [
[1, spatial_merge_size, per_image_output * spatial_merge_size]
for _ in range(max_batch_size)
]
# Create dummy pixel_values
patch_embed = self.visual.patch_embed
in_channels = patch_embed.proj.in_channels
patch_size = patch_embed.patch_size
temporal_patch_size = patch_embed.temporal_patch_size
total_patches = sum(t * h * w for t, h, w in grid_config)
flattened_patch_size = (
in_channels * temporal_patch_size * patch_size * patch_size
)
dummy_pixel_values = torch.randn(
total_patches, flattened_patch_size, device=device, dtype=dtype
)
# Override max_seqlen with a safe upper bound for capture.
# max_seqlen.item() gets baked into the CUDA graph (not replayed),
# so the capture value must cover any replay scenario.
# Worst case: 1 image consuming the full budget ->
# seq_len = token_budget * spatial_merge_size^2.
buffers = self.visual.prepare_encoder_metadata(
grid_config,
max_batch_size=max_batch_size,
max_seqlen_override=token_budget * (spatial_merge_size**2),
device=device,
)
mm_kwargs = {
"pixel_values": dummy_pixel_values,
"image_grid_thw": grid_config,
}
return EncoderCudaGraphCaptureInputs(
mm_kwargs=mm_kwargs,
buffers=buffers,
)
def prepare_encoder_cudagraph_replay_buffers(
self,
mm_kwargs: dict[str, Any],
max_batch_size: int,
):
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
EncoderCudaGraphReplayBuffers,
)
grid_thw_list = mm_kwargs["image_grid_thw"]
buffers = self.visual.prepare_encoder_metadata(
grid_thw_list,
max_batch_size=max_batch_size,
)
return EncoderCudaGraphReplayBuffers(buffers=buffers)
def encoder_cudagraph_forward(
self,
mm_kwargs: dict[str, Any],
buffers: dict[str, torch.Tensor],
) -> torch.Tensor:
pixel_values = mm_kwargs["pixel_values"]
grid_thw = mm_kwargs["image_grid_thw"]
return self.visual(pixel_values, grid_thw, encoder_metadata=buffers)
def encoder_eager_forward(
self,
mm_kwargs: dict[str, Any],
) -> torch.Tensor:
pixel_values = mm_kwargs["pixel_values"]
grid_thw = mm_kwargs["image_grid_thw"]
return self.visual(pixel_values, grid_thw)
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Qwen2_5_VLImageInputs | None: