[Feature] ViT Full CUDA Graph (#35963)
Signed-off-by: Baorun Mu <bmu@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
1f0d210641
commit
f85e479e66
@@ -13,6 +13,7 @@ from collections.abc import (
|
||||
from contextlib import ExitStack, contextmanager, nullcontext
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Literal,
|
||||
Protocol,
|
||||
@@ -46,6 +47,11 @@ if TYPE_CHECKING:
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.multimodal.registry import _ProcessorFactories
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
|
||||
EncoderCudaGraphCaptureInputs,
|
||||
EncoderCudaGraphConfig,
|
||||
EncoderCudaGraphReplayBuffers,
|
||||
)
|
||||
else:
|
||||
VllmConfig = object
|
||||
WeightsMapper = object
|
||||
@@ -1494,3 +1500,138 @@ def supports_xdrope(
|
||||
model: type[object] | object,
|
||||
) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]:
|
||||
return isinstance(model, SupportsXDRoPE)
|
||||
|
||||
|
||||
@runtime_checkable
class SupportsEncoderCudaGraph(Protocol):
    """Contract for models whose vision encoder supports CUDA graph
    capture and replay.

    The :class:`EncoderCudaGraphManager` drives everything through the
    hooks declared here. All model-specific logic (input handling,
    metadata computation, forward pass) therefore lives in the model,
    and the manager does not need to know model internals.
    """

    # Marker attribute; part of the structural check done by isinstance().
    supports_encoder_cudagraph: ClassVar[Literal[True]] = True

    def get_encoder_cudagraph_config(self) -> "EncoderCudaGraphConfig":
        """Return the model's :class:`EncoderCudaGraphConfig`."""
        ...

    def get_encoder_cudagraph_budget_range(
        self,
        vllm_config: "VllmConfig",
    ) -> tuple[int, int]:
        """Return ``(min_token_budget, max_token_budget)`` for auto-inference.

        * ``min_token_budget`` -- estimated smallest possible encoder
          input (e.g. 64 for a 224x224 image).
        * ``max_token_budget`` -- estimated largest budget worth
          capturing (e.g. ``max_num_batched_tokens``).

        Consulted when ``encoder_cudagraph_token_budgets`` and/or
        ``encoder_cudagraph_max_images_per_batch`` are not explicitly
        specified by the user.
        """
        ...

    def get_encoder_cudagraph_num_items(
        self,
        mm_kwargs: dict[str, Any],
    ) -> int:
        """Return how many items (e.g. images) the batch contains."""
        ...

    def get_encoder_cudagraph_per_item_output_tokens(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        """Return the output token count of every item.

        The values feed greedy packing and DP load balancing.
        """
        ...

    def get_encoder_cudagraph_per_item_input_sizes(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        """Return the input size (e.g. patch count) of every item.

        The values are used as offsets when slicing the input tensor.
        """
        ...

    def select_encoder_cudagraph_items(
        self,
        mm_kwargs: dict[str, Any],
        indices: list[int],
    ) -> dict[str, Any]:
        """Build the ``mm_kwargs`` of a sub-batch made of selected items.

        The manager calls this during greedy packing and DP sharding to
        extract the inputs of a specific set of items (e.g. the images
        at indices ``[0, 3, 5]``). It is model-specific because input
        layouts differ:

        * Qwen-family: slice the concatenated ``pixel_values`` by
          cumulative patch offsets and subset ``grid_thw`` by indices.
        * Batched models (CLIP): index ``pixel_values`` along dim 0.
        """
        ...

    def prepare_encoder_cudagraph_capture_inputs(
        self,
        token_budget: int,
        max_batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "EncoderCudaGraphCaptureInputs":
        """Create the dummy inputs and buffers used for CUDA graph capture."""
        ...

    def prepare_encoder_cudagraph_replay_buffers(
        self,
        mm_kwargs: dict[str, Any],
        max_batch_size: int,
    ) -> "EncoderCudaGraphReplayBuffers":
        """Compute the buffer values for replay from the actual batch inputs."""
        ...

    def encoder_cudagraph_forward(
        self,
        mm_kwargs: dict[str, Any],
        buffers: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Run the encoder forward pass using precomputed buffers.

        Shared by CUDA graph capture and CUDA graph replay.
        """
        ...

    def encoder_eager_forward(
        self,
        mm_kwargs: dict[str, Any],
    ) -> torch.Tensor:
        """Run the encoder forward pass without precomputed buffers.

        This is the eager fallback for inputs that exceed all budgets.
        """
        ...
|
||||
|
||||
|
||||
@overload
def supports_encoder_cudagraph(
    model: type[object],
) -> TypeIs[type[SupportsEncoderCudaGraph]]: ...


@overload
def supports_encoder_cudagraph(
    model: object,
) -> TypeIs[SupportsEncoderCudaGraph]: ...


def supports_encoder_cudagraph(
    model: type[object] | object,
) -> TypeIs[type[SupportsEncoderCudaGraph]] | TypeIs[SupportsEncoderCudaGraph]:
    """Tell whether ``model`` satisfies :class:`SupportsEncoderCudaGraph`.

    Accepts either a model class or a model instance; the answer comes
    from a runtime ``isinstance`` check against the protocol.
    """
    satisfies_protocol = isinstance(model, SupportsEncoderCudaGraph)
    return satisfies_protocol
|
||||
|
||||
@@ -103,6 +103,7 @@ from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsEncoderCudaGraph,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
@@ -528,54 +529,120 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
def forward(
|
||||
def prepare_encoder_metadata(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor | list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
grid_thw_list: list[list[int]],
|
||||
*,
|
||||
max_batch_size: int | None = None,
|
||||
max_seqlen_override: int | None = None,
|
||||
device: torch.device | None = None,
|
||||
) -> dict[str, torch.Tensor | None]:
|
||||
"""Compute encoder metadata from grid_thw_list.
|
||||
|
||||
if isinstance(grid_thw, list):
|
||||
grid_thw_list = grid_thw
|
||||
grid_thw = np.array(grid_thw, dtype=np.int32)
|
||||
else:
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
grid_thw = grid_thw.numpy()
|
||||
Shared by the eager forward path, CUDA graph capture, and
|
||||
CUDA graph replay to avoid duplicated implementation.
|
||||
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
|
||||
Args:
|
||||
grid_thw_list: Grid configurations as list of [t, h, w].
|
||||
max_batch_size: If set, pad cu_seqlens to this size
|
||||
(needed for CUDA graph capture/replay).
|
||||
max_seqlen_override: If set, use this value for max_seqlen
|
||||
instead of computing from cu_seqlens (needed for CUDA
|
||||
graph capture to cover worst-case replay scenarios).
|
||||
device: Device to place tensors on. Defaults to self.device.
|
||||
"""
|
||||
if device is None:
|
||||
device = self.device
|
||||
|
||||
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
||||
axis=0, dtype=np.int32
|
||||
metadata: dict[str, torch.Tensor | None] = {}
|
||||
|
||||
# Positional embeddings
|
||||
metadata["pos_embeds"] = self.fast_pos_embed_interpolate(grid_thw_list)
|
||||
rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw_list)
|
||||
metadata["rotary_pos_emb_cos"] = rotary_cos
|
||||
metadata["rotary_pos_emb_sin"] = rotary_sin
|
||||
|
||||
# cu_seqlens from grid_thw
|
||||
grid_thw_np = np.array(grid_thw_list, dtype=np.int32)
|
||||
patches_per_frame = grid_thw_np[:, 1] * grid_thw_np[:, 2]
|
||||
cu_seqlens = np.repeat(patches_per_frame, grid_thw_np[:, 0]).cumsum(
|
||||
dtype=np.int32
|
||||
)
|
||||
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
|
||||
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
|
||||
self.attn_backend, cu_seqlens, self.device
|
||||
|
||||
# Pad cu_seqlens if max_batch_size specified
|
||||
if max_batch_size is not None:
|
||||
num_seqs = len(cu_seqlens) - 1
|
||||
if num_seqs < max_batch_size:
|
||||
cu_seqlens = np.concatenate(
|
||||
[
|
||||
cu_seqlens,
|
||||
np.full(
|
||||
max_batch_size - num_seqs,
|
||||
cu_seqlens[-1],
|
||||
dtype=np.int32,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# sequence_lengths (backend-specific)
|
||||
metadata["sequence_lengths"] = MMEncoderAttention.maybe_compute_seq_lens(
|
||||
self.attn_backend, cu_seqlens, device
|
||||
)
|
||||
max_seqlen = torch.tensor(
|
||||
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
|
||||
|
||||
# max_seqlen
|
||||
if max_seqlen_override is not None:
|
||||
max_seqlen_val = max_seqlen_override
|
||||
else:
|
||||
max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
|
||||
self.attn_backend, cu_seqlens
|
||||
)
|
||||
# Keep max_seqlen on CPU: attention wrappers call .item() on it,
|
||||
# and having it on GPU would capture a wasteful D2H copy in CUDA
|
||||
# graphs without changing behavior (the scalar is baked at capture).
|
||||
metadata["max_seqlen"] = torch.tensor(max_seqlen_val, dtype=torch.int32)
|
||||
|
||||
# Recompute cu_seqlens (backend-specific transformation)
|
||||
metadata["cu_seqlens"] = MMEncoderAttention.maybe_recompute_cu_seqlens(
|
||||
self.attn_backend,
|
||||
cu_seqlens,
|
||||
self.hidden_size,
|
||||
self.tp_size,
|
||||
self.device,
|
||||
device,
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor | list[list[int]],
|
||||
*,
|
||||
encoder_metadata: dict[str, torch.Tensor] | None = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
|
||||
if encoder_metadata is None:
|
||||
if isinstance(grid_thw, list):
|
||||
grid_thw_list = grid_thw
|
||||
else:
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
encoder_metadata = self.prepare_encoder_metadata(grid_thw_list)
|
||||
|
||||
pos_embeds = encoder_metadata["pos_embeds"]
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
|
||||
deepstack_feature_lists = []
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
sequence_lengths=sequence_lengths,
|
||||
cu_seqlens=encoder_metadata["cu_seqlens"],
|
||||
rotary_pos_emb_cos=encoder_metadata["rotary_pos_emb_cos"],
|
||||
rotary_pos_emb_sin=encoder_metadata["rotary_pos_emb_sin"],
|
||||
max_seqlen=encoder_metadata["max_seqlen"],
|
||||
sequence_lengths=encoder_metadata.get("sequence_lengths"),
|
||||
)
|
||||
if layer_num in self.deepstack_visual_indexes:
|
||||
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
|
||||
@@ -1358,6 +1425,7 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
|
||||
class Qwen3VLForConditionalGeneration(
|
||||
nn.Module,
|
||||
SupportsMultiModal,
|
||||
SupportsEncoderCudaGraph,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
SupportsMRoPE,
|
||||
@@ -1507,6 +1575,178 @@ class Qwen3VLForConditionalGeneration(
|
||||
for idx in range(self.deepstack_num_level):
|
||||
self.deepstack_input_embeds[idx][:num_tokens].zero_()
|
||||
|
||||
# -- SupportsEncoderCudaGraph protocol methods --
|
||||
|
||||
def get_encoder_cudagraph_config(self):
    """Return the encoder CUDA graph configuration for this model.

    The configuration covers the ``"image"`` modality, names
    ``"pixel_values"`` as the input key, lists the metadata buffer keys
    and reports the visual tower's output hidden size.
    """
    # Resolved at call time rather than at module import.
    from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
        EncoderCudaGraphConfig,
    )

    # Names of the metadata entries handed to the encoder forward pass.
    metadata_keys = [
        "pos_embeds",
        "rotary_pos_emb_cos",
        "rotary_pos_emb_sin",
        "cu_seqlens",
        "max_seqlen",
        "sequence_lengths",
    ]
    config = EncoderCudaGraphConfig(
        modalities=["image"],
        input_key="pixel_values",
        buffer_keys=metadata_keys,
        out_hidden_size=self.visual.out_hidden_size,
    )
    return config
|
||||
|
||||
def get_encoder_cudagraph_budget_range(
    self,
    vllm_config,
) -> tuple[int, int]:
    """Return the ``(min, max)`` encoder token budget range.

    The lower bound is a fixed estimate of the smallest encoder input;
    the upper bound is the scheduler's ``max_num_batched_tokens``.
    """
    # Smallest expected input: a 224x224 image gives 16x16 patches,
    # which spatial_merge_size=2 reduces to 8x8 = 64 tokens.
    smallest_budget = 64
    # Largest budget worth capturing: capped by max_num_batched_tokens.
    largest_budget = vllm_config.scheduler_config.max_num_batched_tokens
    return smallest_budget, largest_budget
|
||||
|
||||
def get_encoder_cudagraph_num_items(
    self,
    mm_kwargs: dict[str, Any],
) -> int:
    """Return the number of images, i.e. the length of ``image_grid_thw``."""
    image_grids = mm_kwargs["image_grid_thw"]
    return len(image_grids)
|
||||
|
||||
def get_encoder_cudagraph_per_item_output_tokens(
    self,
    mm_kwargs: dict[str, Any],
) -> list[int]:
    """Return the encoder output token count of each image.

    For a grid ``[t, h, w]`` the count is ``t * (h // m) * (w // m)``,
    where ``m`` is the visual tower's ``spatial_merge_size``.
    """
    merge = self.visual.spatial_merge_size
    token_counts = []
    for frames, rows, cols in mm_kwargs["image_grid_thw"]:
        token_counts.append(frames * (rows // merge) * (cols // merge))
    return token_counts
|
||||
|
||||
def get_encoder_cudagraph_per_item_input_sizes(
    self,
    mm_kwargs: dict[str, Any],
) -> list[int]:
    """Return the patch count ``t * h * w`` of each image grid."""
    patch_counts = []
    for frames, rows, cols in mm_kwargs["image_grid_thw"]:
        patch_counts.append(frames * rows * cols)
    return patch_counts
|
||||
|
||||
def select_encoder_cudagraph_items(
    self,
    mm_kwargs: dict[str, Any],
    indices: list[int],
) -> dict[str, Any]:
    """Return ``mm_kwargs`` restricted to the images at ``indices``.

    ``pixel_values`` holds the patches of all images concatenated along
    dim 0, so each image's rows are located through cumulative patch
    offsets derived from ``image_grid_thw``. Items appear in the result
    in the order given by ``indices``.
    """
    all_grids = mm_kwargs["image_grid_thw"]
    all_pixels = mm_kwargs["pixel_values"]

    # Nothing selected: hand back a zero-row slice and an empty grid list.
    if not len(indices):
        return {
            "pixel_values": all_pixels[:0],
            "image_grid_thw": [],
        }

    # row_bounds[i] .. row_bounds[i + 1] is the row range of image i.
    row_bounds = [0]
    offset = 0
    for frames, rows, cols in all_grids:
        offset = offset + frames * rows * cols
        row_bounds.append(offset)

    pixel_chunks = []
    for item in indices:
        pixel_chunks.append(all_pixels[row_bounds[item] : row_bounds[item + 1]])
    picked_pixels = torch.cat(pixel_chunks)
    picked_grids = [all_grids[item] for item in indices]

    return {
        "pixel_values": picked_pixels,
        "image_grid_thw": picked_grids,
    }
|
||||
|
||||
def prepare_encoder_cudagraph_capture_inputs(
    self,
    token_budget: int,
    max_batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
):
    """Build dummy inputs and metadata buffers for CUDA graph capture.

    Args:
        token_budget: Output token budget of the graph being captured.
        max_batch_size: Number of synthetic images in the dummy batch.
        device: Device for the dummy pixel values and metadata.
        dtype: Dtype of the dummy pixel values.

    Returns:
        An ``EncoderCudaGraphCaptureInputs`` holding the dummy
        ``mm_kwargs`` and the buffers from ``prepare_encoder_metadata``.
    """
    # Resolved at call time rather than at module import.
    from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
        EncoderCudaGraphCaptureInputs,
    )

    visual = self.visual
    merge = visual.spatial_merge_size
    tokens_per_image = token_budget // max_batch_size

    # Synthetic rectangular grid [1, merge, tokens_per_image * merge]:
    # it yields exactly tokens_per_image output tokens per image.
    grid_config = [
        [1, merge, tokens_per_image * merge] for _ in range(max_batch_size)
    ]

    # Dummy pixel_values: one row per patch, one column per flattened
    # patch element.
    embed = visual.patch_embed
    values_per_patch = (
        embed.proj.in_channels
        * embed.temporal_patch_size
        * embed.patch_size
        * embed.patch_size
    )
    num_patches = sum(t * h * w for t, h, w in grid_config)
    dummy_pixel_values = torch.randn(
        num_patches, values_per_patch, device=device, dtype=dtype
    )

    # max_seqlen.item() is baked into the CUDA graph at capture time and
    # is not refreshed on replay, so capture must use a value that covers
    # every replay scenario. Worst case is a single image consuming the
    # whole budget: seq_len = token_budget * spatial_merge_size^2.
    worst_case_seqlen = token_budget * (merge**2)
    buffers = visual.prepare_encoder_metadata(
        grid_config,
        max_batch_size=max_batch_size,
        max_seqlen_override=worst_case_seqlen,
        device=device,
    )

    return EncoderCudaGraphCaptureInputs(
        mm_kwargs={
            "pixel_values": dummy_pixel_values,
            "image_grid_thw": grid_config,
        },
        buffers=buffers,
    )
|
||||
|
||||
def prepare_encoder_cudagraph_replay_buffers(
    self,
    mm_kwargs: dict[str, Any],
    max_batch_size: int,
):
    """Compute replay buffers from the real batch's ``image_grid_thw``.

    The metadata comes from the visual tower's
    ``prepare_encoder_metadata`` called with ``max_batch_size``, and is
    returned wrapped in an ``EncoderCudaGraphReplayBuffers``.
    """
    # Resolved at call time rather than at module import.
    from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
        EncoderCudaGraphReplayBuffers,
    )

    replay_metadata = self.visual.prepare_encoder_metadata(
        mm_kwargs["image_grid_thw"],
        max_batch_size=max_batch_size,
    )
    return EncoderCudaGraphReplayBuffers(buffers=replay_metadata)
|
||||
|
||||
def encoder_cudagraph_forward(
    self,
    mm_kwargs: dict[str, Any],
    buffers: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Run the visual tower with precomputed encoder metadata.

    ``buffers`` is forwarded as ``encoder_metadata`` to the visual
    tower, together with the batch's pixel values and image grids.
    """
    pixels = mm_kwargs["pixel_values"]
    image_grids = mm_kwargs["image_grid_thw"]
    encoded = self.visual(pixels, image_grids, encoder_metadata=buffers)
    return encoded
|
||||
|
||||
def encoder_eager_forward(
    self,
    mm_kwargs: dict[str, Any],
) -> torch.Tensor:
    """Run the visual tower eagerly, without precomputed metadata.

    Only the pixel values and image grids are passed to the visual
    tower; no ``encoder_metadata`` argument is supplied.
    """
    pixels = mm_kwargs["pixel_values"]
    image_grids = mm_kwargs["image_grid_thw"]
    encoded = self.visual(pixels, image_grids)
    return encoded
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> Qwen2_5_VLImageInputs | None:
|
||||
|
||||
Reference in New Issue
Block a user