[Feature] ViT Full CUDA Graph (#35963)
Signed-off-by: Baorun Mu <bmu@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
1f0d210641
commit
f85e479e66
451
tests/v1/cudagraph/test_encoder_cudagraph.py
Normal file
451
tests/v1/cudagraph/test_encoder_cudagraph.py
Normal file
@@ -0,0 +1,451 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Unit tests for EncoderCudaGraphManager.
|
||||
|
||||
Test organization:
|
||||
No GPU required:
|
||||
- TestGenerateBudgets — power-of-2 budget auto-generation
- TestFindBudgetGraph — greedy budget selection logic
|
||||
- TestGetCumulativeStats — hit/miss rate statistics
|
||||
GPU required:
|
||||
- TestEncoderCudaGraphCaptureReplay — capture, replay, fallback, counters, chunking
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.worker.gpu.mm.encoder_cudagraph import (
|
||||
EncoderCudaGraphManager,
|
||||
)
|
||||
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
|
||||
EncoderCudaGraphCaptureInputs,
|
||||
EncoderCudaGraphConfig,
|
||||
EncoderCudaGraphReplayBuffers,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_manager_with_budgets(budgets: list[int]) -> EncoderCudaGraphManager:
    """Build a bare-bones EncoderCudaGraphManager with only token_budgets set.

    The real __init__ needs a full VllmConfig and a model, so we bypass it
    via object.__new__ and fill in just the attributes that the pure-logic
    tests read.
    """
    manager = object.__new__(EncoderCudaGraphManager)
    manager.token_budgets = sorted(budgets)
    manager.max_batch_size = 16
    manager.use_dp = False
    manager.budget_graphs = {}
    manager.graph_hits = 0
    manager.graph_misses = 0
    manager.log_stats_interval = 100
    return manager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _generate_budgets
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGenerateBudgets:
    """Budgets are auto-generated as powers of two from min up to max."""

    def test_exact_powers_of_2(self):
        levels = EncoderCudaGraphManager._generate_budgets(64, 1024)
        assert levels == [64, 128, 256, 512, 1024]

    def test_max_not_power_of_2(self):
        # A non-power-of-two max is still included as the final level.
        levels = EncoderCudaGraphManager._generate_budgets(64, 800)
        assert levels == [64, 128, 256, 512, 800]

    def test_min_equals_max(self):
        assert EncoderCudaGraphManager._generate_budgets(64, 64) == [64]

    def test_large_range(self):
        levels = EncoderCudaGraphManager._generate_budgets(64, 8192)
        assert levels == [64, 128, 256, 512, 1024, 2048, 4096, 8192]
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _find_smallest_fitting_budget_given_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFindBudgetGraph:
    """Greedy budget selection: pick the smallest budget >= total_tokens."""

    @pytest.mark.parametrize(
        "total_tokens,budgets,expected",
        [
            # Exact match
            (2048, [2048, 4096, 8192], 2048),
            # Below smallest budget — picks smallest
            (100, [2048, 4096, 8192], 2048),
            # Zero tokens — picks smallest
            (0, [2048, 4096, 8192], 2048),
            # Between budgets — picks next one up
            (2049, [2048, 4096, 8192], 4096),
            (4097, [2048, 4096, 8192], 8192),
            # Exceeds all budgets — returns None (eager fallback)
            (9000, [2048, 4096, 8192], None),
            # Single budget, fits
            (1000, [2048], 2048),
            # Single budget, does not fit
            (3000, [2048], None),
        ],
    )
    def test_find_budget(self, total_tokens, budgets, expected):
        manager = _make_manager_with_budgets(budgets)
        assert (
            manager._find_smallest_fitting_budget_given_tokens(total_tokens)
            == expected
        )

    def test_budgets_are_sorted(self):
        """token_budgets is normalized to ascending order at construction."""
        manager = _make_manager_with_budgets([8192, 2048, 4096])
        assert manager.token_budgets == [2048, 4096, 8192]
        # Selection respects the sorted order.
        assert manager._find_smallest_fitting_budget_given_tokens(3000) == 4096
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_cumulative_stats
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetCumulativeStats:
    """Hit/miss counters and the derived hit-rate statistic."""

    def test_initial_stats_are_zero(self):
        stats = _make_manager_with_budgets([2048]).get_cumulative_stats()
        assert stats["graph_hits"] == 0
        assert stats["graph_misses"] == 0
        assert stats["hit_rate"] == 0.0

    def test_hit_rate_calculation(self):
        manager = _make_manager_with_budgets([2048])
        manager.graph_hits, manager.graph_misses = 75, 25
        stats = manager.get_cumulative_stats()
        assert stats["graph_hits"] == 75
        assert stats["graph_misses"] == 25
        assert stats["hit_rate"] == pytest.approx(0.75)

    def test_all_hits(self):
        manager = _make_manager_with_budgets([2048])
        manager.graph_hits, manager.graph_misses = 100, 0
        assert manager.get_cumulative_stats()["hit_rate"] == pytest.approx(1.0)

    def test_all_misses(self):
        manager = _make_manager_with_budgets([2048])
        manager.graph_hits, manager.graph_misses = 0, 50
        assert manager.get_cumulative_stats()["hit_rate"] == pytest.approx(0.0)

    def test_stats_report_budget_info(self):
        budgets = [2048, 4096, 8192]
        stats = _make_manager_with_budgets(budgets).get_cumulative_stats()
        # No graphs captured yet, so no per-budget entries.
        assert stats["num_budgets"] == 0
        assert stats["token_budgets"] == budgets
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GPU fixtures and helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Mock encoder parameters (kept small for fast capture)
|
||||
_SPATIAL_MERGE = 2
|
||||
_HIDDEN = 32
|
||||
_PATCH_SIZE = 4 # H/W per patch in grid_thw units
|
||||
_TEMPORAL_PATCH = 1
|
||||
_IN_CHANNELS = 3
|
||||
# flattened_patch_size = in_channels * temporal_patch * patch_size^2
|
||||
_FLAT = _IN_CHANNELS * _TEMPORAL_PATCH * _PATCH_SIZE * _PATCH_SIZE # 48
|
||||
|
||||
# Test budgets: small to keep capture fast
|
||||
_BUDGETS = [16, 64]
|
||||
_MAX_BATCH = 4
|
||||
|
||||
|
||||
def _count_input_patches(grid_thw_list: list[list[int]]) -> int:
|
||||
return sum(t * h * w for t, h, w in grid_thw_list)
|
||||
|
||||
|
||||
def _count_output_tokens(
|
||||
grid_thw_list: list[list[int]], spatial_merge_size: int
|
||||
) -> int:
|
||||
m = spatial_merge_size
|
||||
return sum(t * (h // m) * (w // m) for t, h, w in grid_thw_list)
|
||||
|
||||
|
||||
class SimpleMockViTModel(torch.nn.Module):
    """Tiny stand-in for a ViT encoder used in CUDA graph tests.

    Provides every method of the SupportsEncoderCudaGraph protocol.
    Forward pass: a linear projection of flattened patches followed by a
    fake spatial merge (mean over each group of m^2 consecutive patches).
    """

    # Protocol marker attribute.
    supports_encoder_cudagraph = True

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(_FLAT, _HIDDEN)
        self.spatial_merge_size = _SPATIAL_MERGE
        self.out_hidden_size = _HIDDEN

    def get_encoder_cudagraph_config(self) -> EncoderCudaGraphConfig:
        """Image-only config with a single dummy replay buffer."""
        return EncoderCudaGraphConfig(
            modalities=["image"],
            input_key="pixel_values",
            buffer_keys=["dummy_buf"],
            out_hidden_size=_HIDDEN,
        )

    def get_encoder_cudagraph_budget_range(
        self,
        vllm_config,
    ) -> tuple[int, int]:
        # Deliberately tiny range so graph capture in tests stays fast.
        return (4, 128)

    def get_encoder_cudagraph_num_items(
        self,
        mm_kwargs: dict[str, Any],
    ) -> int:
        """One item per grid_thw entry (i.e. per image)."""
        return len(mm_kwargs["image_grid_thw"])

    def get_encoder_cudagraph_per_item_output_tokens(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        """Post-merge token count for each image."""
        merge = _SPATIAL_MERGE
        return [
            t * (h // merge) * (w // merge)
            for t, h, w in mm_kwargs["image_grid_thw"]
        ]

    def get_encoder_cudagraph_per_item_input_sizes(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        """Raw patch count for each image (pixel_values rows)."""
        return [t * h * w for t, h, w in mm_kwargs["image_grid_thw"]]

    def select_encoder_cudagraph_items(
        self,
        mm_kwargs: dict[str, Any],
        indices: list[int],
    ) -> dict[str, Any]:
        """Slice out the images at `indices` from the concatenated batch."""
        grids = mm_kwargs["image_grid_thw"]
        pv = mm_kwargs["pixel_values"]

        if not indices:
            # Empty selection: zero-row tensor, empty grid list.
            return {"pixel_values": pv[:0], "image_grid_thw": []}

        # Running patch offsets: image i owns pv[offsets[i] : offsets[i + 1]].
        offsets = [0]
        for t, h, w in grids:
            offsets.append(offsets[-1] + t * h * w)

        picked = torch.cat(
            [pv[offsets[i] : offsets[i + 1]] for i in indices]
        )
        return {
            "pixel_values": picked,
            "image_grid_thw": [grids[i] for i in indices],
        }

    def prepare_encoder_cudagraph_capture_inputs(
        self,
        token_budget: int,
        max_batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> EncoderCudaGraphCaptureInputs:
        """Synthesize dummy inputs that exactly fill the token budget."""
        # Each synthetic image is a [1, m, k*m] grid that yields exactly
        # k = token_budget // max_batch_size output tokens.
        tokens_per_image = token_budget // max_batch_size
        grids = [
            [1, _SPATIAL_MERGE, tokens_per_image * _SPATIAL_MERGE]
            for _ in range(max_batch_size)
        ]
        pv = torch.randn(
            _count_input_patches(grids), _FLAT, device=device, dtype=dtype
        )
        out_tokens = _count_output_tokens(grids, _SPATIAL_MERGE)
        return EncoderCudaGraphCaptureInputs(
            mm_kwargs={
                "pixel_values": pv,
                "image_grid_thw": grids,
            },
            buffers={
                "dummy_buf": torch.zeros(
                    out_tokens, _HIDDEN, device=device, dtype=dtype
                )
            },
        )

    def prepare_encoder_cudagraph_replay_buffers(
        self,
        mm_kwargs: dict[str, Any],
        max_batch_size: int,
    ) -> EncoderCudaGraphReplayBuffers:
        """Allocate the dummy buffer sized to the actual batch outputs."""
        out_tokens = _count_output_tokens(
            mm_kwargs["image_grid_thw"], _SPATIAL_MERGE
        )
        # Borrow device/dtype from any parameter of this module.
        ref = next(self.parameters())
        buf = torch.zeros(
            out_tokens, _HIDDEN, device=ref.device, dtype=ref.dtype
        )
        return EncoderCudaGraphReplayBuffers(buffers={"dummy_buf": buf})

    def encoder_cudagraph_forward(
        self,
        mm_kwargs: dict[str, Any],
        buffers: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        # Buffers are ignored by this mock; both paths share _forward.
        return self._forward(mm_kwargs["pixel_values"])

    def encoder_eager_forward(
        self,
        mm_kwargs: dict[str, Any],
    ) -> torch.Tensor:
        return self._forward(mm_kwargs["pixel_values"])

    def _forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        group = _SPATIAL_MERGE**2
        projected = self.proj(pixel_values)
        n_out = projected.shape[0] // group
        # Mean over each group of m^2 consecutive patches -> one token.
        return projected[: n_out * group].view(n_out, group, _HIDDEN).mean(dim=1)
|
||||
|
||||
|
||||
def _make_manager_for_gpu(
    model: SimpleMockViTModel,
    token_budgets: list[int],
    max_batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> EncoderCudaGraphManager:
    """Build an EncoderCudaGraphManager wired to a real model/device.

    Bypasses __init__ (which needs a full VllmConfig) and assigns the
    attributes that capture()/execute() use directly.
    """
    manager = object.__new__(EncoderCudaGraphManager)
    manager.token_budgets = sorted(token_budgets)
    manager.max_batch_size = max_batch_size
    manager.use_dp = False
    manager.budget_graphs = {}
    manager.graph_hits = 0
    manager.graph_misses = 0
    manager.log_stats_interval = 100
    manager.model = model
    manager.config = model.get_encoder_cudagraph_config()
    manager.device = device
    manager.dtype = dtype
    return manager
|
||||
|
||||
|
||||
def _make_pixel_values(
    grid_thw_list: list[list[int]],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Random flattened patches whose row count matches the given grids."""
    num_patches = _count_input_patches(grid_thw_list)
    return torch.randn(num_patches, _FLAT, device=device, dtype=dtype)
|
||||
|
||||
|
||||
def _make_mm_kwargs(
    grid_thw_list: list[list[int]],
    device: torch.device,
    dtype: torch.dtype,
) -> dict[str, Any]:
    """Assemble a minimal mm_kwargs dict for the mock model."""
    pixel_values = _make_pixel_values(grid_thw_list, device, dtype)
    return {
        "pixel_values": pixel_values,
        "image_grid_thw": grid_thw_list,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GPU tests — capture, replay, fallback, counters, chunking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestEncoderCudaGraphCaptureReplay:
    """End-to-end capture/replay behavior against SimpleMockViTModel."""

    def setup_method(self):
        self.device = torch.device("cuda:0")
        self.dtype = torch.float16
        self.model = SimpleMockViTModel().to(self.device).half()
        self.mgr = _make_manager_for_gpu(
            self.model, _BUDGETS, _MAX_BATCH, self.device, self.dtype
        )
        self.mgr.capture()

    # --- capture ---

    def test_capture_creates_one_graph_per_budget(self):
        captured = self.mgr.budget_graphs
        assert len(captured) == len(_BUDGETS)
        assert set(captured.keys()) == set(_BUDGETS)

    # --- output shape ---

    def test_execute_returns_one_tensor_per_image(self):
        kwargs = _make_mm_kwargs(
            [[1, 4, 4], [1, 4, 4]], self.device, self.dtype
        )
        outputs = self.mgr.execute(kwargs)
        assert outputs is not None
        assert len(outputs) == 2

    def test_execute_output_tokens_per_image(self):
        # [1,4,4] → 1*(4//2)*(4//2) = 4 tokens; [1,8,8] → 16 tokens
        kwargs = _make_mm_kwargs(
            [[1, 4, 4], [1, 8, 8]], self.device, self.dtype
        )
        outputs = self.mgr.execute(kwargs)
        assert outputs is not None
        assert outputs[0].shape == (4, _HIDDEN)
        assert outputs[1].shape == (16, _HIDDEN)

    # --- budget fallback ---

    def test_eager_fallback_when_tokens_exceed_all_budgets(self):
        # [1,18,18] → 1*(18//2)*(18//2) = 81 tokens > max budget 64.
        # Greedy packing handles the fallback internally: the oversized image
        # gets an eager forward pass and is returned as part of the output
        # list (execute() no longer returns None for individual misses).
        kwargs = _make_mm_kwargs([[1, 18, 18]], self.device, self.dtype)
        outputs = self.mgr.execute(kwargs)
        assert outputs is not None
        assert len(outputs) == 1
        # Eager output: SimpleMockViTModel produces n_out = 81 tokens.
        assert outputs[0].shape == (81, _HIDDEN)
        assert self.mgr.graph_misses == 1

    # --- counters ---

    def test_hit_counter_increments_by_num_images(self):
        kwargs = _make_mm_kwargs(
            [[1, 4, 4], [1, 4, 4]], self.device, self.dtype
        )
        self.mgr.execute(kwargs)
        assert self.mgr.graph_hits == 2

    def test_miss_counter_increments_by_num_images(self):
        # 81 tokens > 64 (largest budget) → one miss.
        kwargs = _make_mm_kwargs([[1, 18, 18]], self.device, self.dtype)
        self.mgr.execute(kwargs)
        assert self.mgr.graph_misses == 1

    # --- chunking ---

    def test_chunking_when_images_exceed_max_batch(self):
        # 8 images > max_batch_size=4 → 2 chunks of 4;
        # each chunk packs 4 * 4 = 16 tokens → fits budget 16.
        n_images = _MAX_BATCH * 2
        kwargs = _make_mm_kwargs(
            [[1, 4, 4]] * n_images, self.device, self.dtype
        )
        outputs = self.mgr.execute(kwargs)
        assert outputs is not None
        assert len(outputs) == n_images
        for image_out in outputs:
            assert image_out.shape == (4, _HIDDEN)
|
||||
@@ -489,6 +489,28 @@ class CompilationConfig:
|
||||
on selected platforms. Disabled by default until more models
|
||||
are supported/tested to work."""
|
||||
|
||||
# Vision encoder CUDA graph
|
||||
cudagraph_mm_encoder: bool = False
|
||||
"""Enable CUDA graph capture for multimodal encoder (ViT).
|
||||
When enabled, captures full encoder forward as CUDA graph
|
||||
for each token budget level."""
|
||||
|
||||
encoder_cudagraph_token_budgets: list[int] = field(default_factory=list)
|
||||
"""Token budget levels for encoder CUDA graph capture.
|
||||
Each budget defines a fixed token capacity. At runtime, images are greedy-packed
|
||||
into the smallest fitting budget and the corresponding CUDA graph is replayed.
|
||||
If empty (default), auto-inferred from model architecture as power-of-2
|
||||
levels from the model's estimated min budget to max budget.
|
||||
User-provided values override auto-inference.
|
||||
Example: [2048, 4096, 8192, 13824]"""
|
||||
|
||||
encoder_cudagraph_max_images_per_batch: int = 0
|
||||
"""Maximum number of images per batch for encoder CUDA graph capture.
|
||||
Determines the fixed batch size used during graph capture.
|
||||
If 0 (default), auto-inferred as max_budget // min_budget from the
|
||||
model's budget range. User-provided positive value overrides
|
||||
auto-inference."""
|
||||
|
||||
# Inductor capture
|
||||
compile_sizes: list[int | str] | None = None
|
||||
"""Sizes to compile for inductor. In addition
|
||||
@@ -906,6 +928,16 @@ class CompilationConfig:
|
||||
f"Invalid backend for piecewise compilation: {self.backend}"
|
||||
)
|
||||
|
||||
# Validate encoder CUDA graph configuration
|
||||
if (
|
||||
self.cudagraph_mm_encoder
|
||||
and self.encoder_cudagraph_max_images_per_batch < 0
|
||||
):
|
||||
raise ValueError(
|
||||
"encoder_cudagraph_max_images_per_batch must be "
|
||||
"non-negative (0 = auto-infer)"
|
||||
)
|
||||
|
||||
if self.backend == "":
|
||||
self.backend = current_platform.get_compile_backend()
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from collections.abc import (
|
||||
from contextlib import ExitStack, contextmanager, nullcontext
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Literal,
|
||||
Protocol,
|
||||
@@ -46,6 +47,11 @@ if TYPE_CHECKING:
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.multimodal.registry import _ProcessorFactories
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
|
||||
EncoderCudaGraphCaptureInputs,
|
||||
EncoderCudaGraphConfig,
|
||||
EncoderCudaGraphReplayBuffers,
|
||||
)
|
||||
else:
|
||||
VllmConfig = object
|
||||
WeightsMapper = object
|
||||
@@ -1494,3 +1500,138 @@ def supports_xdrope(
|
||||
model: type[object] | object,
|
||||
) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]:
|
||||
return isinstance(model, SupportsXDRoPE)
|
||||
|
||||
|
||||
@runtime_checkable
class SupportsEncoderCudaGraph(Protocol):
    """Interface for models whose vision encoder supports CUDA graph
    capture/replay.

    Models implement these methods to provide the
    :class:`EncoderCudaGraphManager` with all model-specific logic
    (input handling, metadata computation, forward pass) without the
    manager needing to know model internals.
    """

    # Marker attribute declared by implementing models.
    supports_encoder_cudagraph: ClassVar[Literal[True]] = True

    def get_encoder_cudagraph_config(self) -> "EncoderCudaGraphConfig":
        """Return the static capture/replay configuration for this model."""
        ...

    def get_encoder_cudagraph_budget_range(
        self,
        vllm_config: "VllmConfig",
    ) -> tuple[int, int]:
        """Return (min_token_budget, max_token_budget) for auto-inference.

        - min_token_budget: estimated smallest possible encoder input
          (e.g. 64 for a 224x224 image)
        - max_token_budget: estimated largest budget worth capturing
          (e.g. max_num_batched_tokens)

        Used when ``encoder_cudagraph_token_budgets`` and/or
        ``encoder_cudagraph_max_images_per_batch`` are not explicitly
        specified by the user.
        """
        ...

    def get_encoder_cudagraph_num_items(
        self,
        mm_kwargs: dict[str, Any],
    ) -> int:
        """Return the number of items (e.g. images) in the batch."""
        ...

    def get_encoder_cudagraph_per_item_output_tokens(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        """Return output token count for each item.

        Used for greedy packing and DP load balancing.
        """
        ...

    def get_encoder_cudagraph_per_item_input_sizes(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        """Return input size (e.g. patch count) for each item.

        Used for input tensor slicing offsets.
        """
        ...

    def select_encoder_cudagraph_items(
        self,
        mm_kwargs: dict[str, Any],
        indices: list[int],
    ) -> dict[str, Any]:
        """Select a subset of items and return mm_kwargs for the sub-batch.

        Called by the manager during greedy packing and DP sharding to
        extract inputs for a specific set of items (e.g. images at
        indices [0, 3, 5]). The implementation is model-specific
        because input formats differ:

        - Qwen-family: slice concatenated pixel_values by cumulative
          patch offsets, subset grid_thw by indices.
        - Batched models (CLIP): index pixel_values along dim 0.
        """
        ...

    def prepare_encoder_cudagraph_capture_inputs(
        self,
        token_budget: int,
        max_batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> "EncoderCudaGraphCaptureInputs":
        """Create dummy inputs and buffers for CUDA graph capture."""
        ...

    def prepare_encoder_cudagraph_replay_buffers(
        self,
        mm_kwargs: dict[str, Any],
        max_batch_size: int,
    ) -> "EncoderCudaGraphReplayBuffers":
        """Compute buffer values from actual batch inputs for replay."""
        ...

    def encoder_cudagraph_forward(
        self,
        mm_kwargs: dict[str, Any],
        buffers: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Run the encoder forward pass with precomputed buffers.

        Used during both CUDA graph capture and replay.
        """
        ...

    def encoder_eager_forward(
        self,
        mm_kwargs: dict[str, Any],
    ) -> torch.Tensor:
        """Run the encoder forward pass without precomputed buffers.

        Used as eager fallback when inputs exceed all budgets.
        """
        ...
|
||||
|
||||
|
||||
@overload
def supports_encoder_cudagraph(
    model: type[object],
) -> TypeIs[type[SupportsEncoderCudaGraph]]: ...


@overload
def supports_encoder_cudagraph(
    model: object,
) -> TypeIs[SupportsEncoderCudaGraph]: ...


def supports_encoder_cudagraph(
    model: type[object] | object,
) -> TypeIs[type[SupportsEncoderCudaGraph]] | TypeIs[SupportsEncoderCudaGraph]:
    """Narrow *model* (class or instance) to SupportsEncoderCudaGraph.

    Relies on the runtime_checkable Protocol, which checks method
    presence only (not signatures).
    """
    return isinstance(model, SupportsEncoderCudaGraph)
|
||||
|
||||
@@ -103,6 +103,7 @@ from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsEagle,
|
||||
SupportsEagle3,
|
||||
SupportsEncoderCudaGraph,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
@@ -528,54 +529,120 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
|
||||
return torch.cat(outputs, dim=0)
|
||||
|
||||
def forward(
|
||||
def prepare_encoder_metadata(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor | list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
grid_thw_list: list[list[int]],
|
||||
*,
|
||||
max_batch_size: int | None = None,
|
||||
max_seqlen_override: int | None = None,
|
||||
device: torch.device | None = None,
|
||||
) -> dict[str, torch.Tensor | None]:
|
||||
"""Compute encoder metadata from grid_thw_list.
|
||||
|
||||
if isinstance(grid_thw, list):
|
||||
grid_thw_list = grid_thw
|
||||
grid_thw = np.array(grid_thw, dtype=np.int32)
|
||||
else:
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
grid_thw = grid_thw.numpy()
|
||||
Shared by the eager forward path, CUDA graph capture, and
|
||||
CUDA graph replay to avoid duplicated implementation.
|
||||
|
||||
pos_embeds = self.fast_pos_embed_interpolate(grid_thw_list)
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)
|
||||
Args:
|
||||
grid_thw_list: Grid configurations as list of [t, h, w].
|
||||
max_batch_size: If set, pad cu_seqlens to this size
|
||||
(needed for CUDA graph capture/replay).
|
||||
max_seqlen_override: If set, use this value for max_seqlen
|
||||
instead of computing from cu_seqlens (needed for CUDA
|
||||
graph capture to cover worst-case replay scenarios).
|
||||
device: Device to place tensors on. Defaults to self.device.
|
||||
"""
|
||||
if device is None:
|
||||
device = self.device
|
||||
|
||||
cu_seqlens = np.repeat(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
||||
axis=0, dtype=np.int32
|
||||
metadata: dict[str, torch.Tensor | None] = {}
|
||||
|
||||
# Positional embeddings
|
||||
metadata["pos_embeds"] = self.fast_pos_embed_interpolate(grid_thw_list)
|
||||
rotary_cos, rotary_sin = self.rot_pos_emb(grid_thw_list)
|
||||
metadata["rotary_pos_emb_cos"] = rotary_cos
|
||||
metadata["rotary_pos_emb_sin"] = rotary_sin
|
||||
|
||||
# cu_seqlens from grid_thw
|
||||
grid_thw_np = np.array(grid_thw_list, dtype=np.int32)
|
||||
patches_per_frame = grid_thw_np[:, 1] * grid_thw_np[:, 2]
|
||||
cu_seqlens = np.repeat(patches_per_frame, grid_thw_np[:, 0]).cumsum(
|
||||
dtype=np.int32
|
||||
)
|
||||
cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
|
||||
sequence_lengths = MMEncoderAttention.maybe_compute_seq_lens(
|
||||
self.attn_backend, cu_seqlens, self.device
|
||||
|
||||
# Pad cu_seqlens if max_batch_size specified
|
||||
if max_batch_size is not None:
|
||||
num_seqs = len(cu_seqlens) - 1
|
||||
if num_seqs < max_batch_size:
|
||||
cu_seqlens = np.concatenate(
|
||||
[
|
||||
cu_seqlens,
|
||||
np.full(
|
||||
max_batch_size - num_seqs,
|
||||
cu_seqlens[-1],
|
||||
dtype=np.int32,
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# sequence_lengths (backend-specific)
|
||||
metadata["sequence_lengths"] = MMEncoderAttention.maybe_compute_seq_lens(
|
||||
self.attn_backend, cu_seqlens, device
|
||||
)
|
||||
max_seqlen = torch.tensor(
|
||||
MMEncoderAttention.compute_max_seqlen(self.attn_backend, cu_seqlens),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
cu_seqlens = MMEncoderAttention.maybe_recompute_cu_seqlens(
|
||||
|
||||
# max_seqlen
|
||||
if max_seqlen_override is not None:
|
||||
max_seqlen_val = max_seqlen_override
|
||||
else:
|
||||
max_seqlen_val = MMEncoderAttention.compute_max_seqlen(
|
||||
self.attn_backend, cu_seqlens
|
||||
)
|
||||
# Keep max_seqlen on CPU: attention wrappers call .item() on it,
|
||||
# and having it on GPU would capture a wasteful D2H copy in CUDA
|
||||
# graphs without changing behavior (the scalar is baked at capture).
|
||||
metadata["max_seqlen"] = torch.tensor(max_seqlen_val, dtype=torch.int32)
|
||||
|
||||
# Recompute cu_seqlens (backend-specific transformation)
|
||||
metadata["cu_seqlens"] = MMEncoderAttention.maybe_recompute_cu_seqlens(
|
||||
self.attn_backend,
|
||||
cu_seqlens,
|
||||
self.hidden_size,
|
||||
self.tp_size,
|
||||
self.device,
|
||||
device,
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
grid_thw: torch.Tensor | list[list[int]],
|
||||
*,
|
||||
encoder_metadata: dict[str, torch.Tensor] | None = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = x.to(device=self.device, dtype=self.dtype, non_blocking=True)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
|
||||
if encoder_metadata is None:
|
||||
if isinstance(grid_thw, list):
|
||||
grid_thw_list = grid_thw
|
||||
else:
|
||||
grid_thw_list = grid_thw.tolist()
|
||||
encoder_metadata = self.prepare_encoder_metadata(grid_thw_list)
|
||||
|
||||
pos_embeds = encoder_metadata["pos_embeds"]
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
|
||||
deepstack_feature_lists = []
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
hidden_states = blk(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
rotary_pos_emb_cos=rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin=rotary_pos_emb_sin,
|
||||
max_seqlen=max_seqlen,
|
||||
sequence_lengths=sequence_lengths,
|
||||
cu_seqlens=encoder_metadata["cu_seqlens"],
|
||||
rotary_pos_emb_cos=encoder_metadata["rotary_pos_emb_cos"],
|
||||
rotary_pos_emb_sin=encoder_metadata["rotary_pos_emb_sin"],
|
||||
max_seqlen=encoder_metadata["max_seqlen"],
|
||||
sequence_lengths=encoder_metadata.get("sequence_lengths"),
|
||||
)
|
||||
if layer_num in self.deepstack_visual_indexes:
|
||||
deepstack_merger_idx = self.deepstack_visual_indexes.index(layer_num)
|
||||
@@ -1358,6 +1425,7 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
|
||||
class Qwen3VLForConditionalGeneration(
|
||||
nn.Module,
|
||||
SupportsMultiModal,
|
||||
SupportsEncoderCudaGraph,
|
||||
SupportsLoRA,
|
||||
SupportsPP,
|
||||
SupportsMRoPE,
|
||||
@@ -1507,6 +1575,178 @@ class Qwen3VLForConditionalGeneration(
|
||||
for idx in range(self.deepstack_num_level):
|
||||
self.deepstack_input_embeds[idx][:num_tokens].zero_()
|
||||
|
||||
# -- SupportsEncoderCudaGraph protocol methods --
|
||||
|
||||
def get_encoder_cudagraph_config(self):
    """Return the capture/replay config for the Qwen3-VL vision encoder.

    Images only; the listed buffer_keys are the attention-metadata
    tensors recomputed per batch and swapped in at replay time.
    """
    # Deferred import — NOTE(review): presumably avoids importing worker
    # code at model-module load time; confirm before moving to top level.
    from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
        EncoderCudaGraphConfig,
    )

    return EncoderCudaGraphConfig(
        modalities=["image"],
        input_key="pixel_values",
        buffer_keys=[
            "pos_embeds",
            "rotary_pos_emb_cos",
            "rotary_pos_emb_sin",
            "cu_seqlens",
            "max_seqlen",
            "sequence_lengths",
        ],
        out_hidden_size=self.visual.out_hidden_size,
    )
|
||||
|
||||
def get_encoder_cudagraph_budget_range(
    self,
    vllm_config,
) -> tuple[int, int]:
    """Return (min, max) token budgets used to auto-generate capture levels."""
    # Min: estimated smallest possible encoder input.
    # 224x224 image → 16x16 patches, spatial_merge_size=2 → 8x8 = 64 tokens
    min_budget = 64
    # Max: capped by max_num_batched_tokens
    max_budget = vllm_config.scheduler_config.max_num_batched_tokens
    return (min_budget, max_budget)
|
||||
|
||||
def get_encoder_cudagraph_num_items(
    self,
    mm_kwargs: dict[str, Any],
) -> int:
    """Number of images in the batch — one grid_thw entry per image."""
    return len(mm_kwargs["image_grid_thw"])
|
||||
|
||||
def get_encoder_cudagraph_per_item_output_tokens(
    self,
    mm_kwargs: dict[str, Any],
) -> list[int]:
    """Post-merge output token count per image: t * (h // m) * (w // m)."""
    m = self.visual.spatial_merge_size
    return [t * (h // m) * (w // m) for t, h, w in mm_kwargs["image_grid_thw"]]
|
||||
|
||||
def get_encoder_cudagraph_per_item_input_sizes(
    self,
    mm_kwargs: dict[str, Any],
) -> list[int]:
    """Input patch count per image: t * h * w of each grid entry."""
    sizes: list[int] = []
    for t, h, w in mm_kwargs["image_grid_thw"]:
        sizes.append(t * h * w)
    return sizes
|
||||
|
||||
def select_encoder_cudagraph_items(
    self,
    mm_kwargs: dict[str, Any],
    indices: list[int],
) -> dict[str, Any]:
    """Extract a sub-batch of images, selected by original index.

    Slices pixel_values along the patch dimension using cumulative
    patch offsets derived from image_grid_thw, and subsets the grid
    list to match the selection order.
    """
    grid_thw = mm_kwargs["image_grid_thw"]
    pixel_values = mm_kwargs["pixel_values"]

    if not indices:
        # Empty selection: a zero-row slice keeps dtype/device/trailing dims.
        return {
            "pixel_values": pixel_values[:0],
            "image_grid_thw": [],
        }

    # Prefix sums of per-image patch counts -> row offsets into pixel_values.
    offsets = [0]
    for t, h, w in grid_thw:
        offsets.append(offsets[-1] + t * h * w)

    chosen_rows = [
        pixel_values[offsets[i] : offsets[i + 1]] for i in indices
    ]
    return {
        "pixel_values": torch.cat(chosen_rows),
        "image_grid_thw": [grid_thw[i] for i in indices],
    }
|
||||
|
||||
def prepare_encoder_cudagraph_capture_inputs(
    self,
    token_budget: int,
    max_batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
):
    """Build dummy inputs and metadata buffers for one graph capture.

    Constructs max_batch_size synthetic image grids that together produce
    (at most) token_budget output tokens, plus random pixel_values with
    the matching patch count, and precomputes the encoder metadata
    buffers that will be recorded into the CUDA graph.
    """
    from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
        EncoderCudaGraphCaptureInputs,
    )

    spatial_merge_size = self.visual.spatial_merge_size
    # Split the budget evenly across the batch (integer division; any
    # remainder tokens are simply not represented in the capture).
    per_image_output = token_budget // max_batch_size

    # Synthetic rectangular grid: [1, merge, per_image_output * merge]
    # produces exactly per_image_output tokens per image.
    grid_config = [
        [1, spatial_merge_size, per_image_output * spatial_merge_size]
        for _ in range(max_batch_size)
    ]

    # Create dummy pixel_values
    patch_embed = self.visual.patch_embed
    in_channels = patch_embed.proj.in_channels
    patch_size = patch_embed.patch_size
    temporal_patch_size = patch_embed.temporal_patch_size
    total_patches = sum(t * h * w for t, h, w in grid_config)
    # Each flattened patch: channels * temporal * spatial^2 values.
    flattened_patch_size = (
        in_channels * temporal_patch_size * patch_size * patch_size
    )
    dummy_pixel_values = torch.randn(
        total_patches, flattened_patch_size, device=device, dtype=dtype
    )

    # Override max_seqlen with a safe upper bound for capture.
    # max_seqlen.item() gets baked into the CUDA graph (not replayed),
    # so the capture value must cover any replay scenario.
    # Worst case: 1 image consuming the full budget ->
    # seq_len = token_budget * spatial_merge_size^2.
    buffers = self.visual.prepare_encoder_metadata(
        grid_config,
        max_batch_size=max_batch_size,
        max_seqlen_override=token_budget * (spatial_merge_size**2),
        device=device,
    )

    mm_kwargs = {
        "pixel_values": dummy_pixel_values,
        "image_grid_thw": grid_config,
    }

    return EncoderCudaGraphCaptureInputs(
        mm_kwargs=mm_kwargs,
        buffers=buffers,
    )
|
||||
|
||||
def prepare_encoder_cudagraph_replay_buffers(
    self,
    mm_kwargs: dict[str, Any],
    max_batch_size: int,
):
    """Compute fresh metadata buffer values for a replay batch."""
    from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
        EncoderCudaGraphReplayBuffers,
    )

    metadata = self.visual.prepare_encoder_metadata(
        mm_kwargs["image_grid_thw"],
        max_batch_size=max_batch_size,
    )
    return EncoderCudaGraphReplayBuffers(buffers=metadata)
|
||||
|
||||
def encoder_cudagraph_forward(
    self,
    mm_kwargs: dict[str, Any],
    buffers: dict[str, torch.Tensor],
) -> torch.Tensor:
    """Run the visual tower with precomputed encoder metadata buffers."""
    return self.visual(
        mm_kwargs["pixel_values"],
        mm_kwargs["image_grid_thw"],
        encoder_metadata=buffers,
    )
|
||||
|
||||
def encoder_eager_forward(
    self,
    mm_kwargs: dict[str, Any],
) -> torch.Tensor:
    """Run the visual tower eagerly (no CUDA graph, no prebuilt metadata)."""
    return self.visual(
        mm_kwargs["pixel_values"],
        mm_kwargs["image_grid_thw"],
    )
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> Qwen2_5_VLImageInputs | None:
|
||||
|
||||
576
vllm/v1/worker/gpu/mm/encoder_cudagraph.py
Normal file
576
vllm/v1/worker/gpu/mm/encoder_cudagraph.py
Normal file
@@ -0,0 +1,576 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""CUDA graph manager for vision encoder budget-batch execution."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import SupportsEncoderCudaGraph
|
||||
from vllm.model_executor.models.vision import get_load_balance_assignment
|
||||
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
|
||||
EncoderCudaGraphConfig,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class BudgetGraphMetadata:
    """Metadata for a single budget graph.

    CUDA graph replay pattern:
    1. Copy new batch data into input_buffer (e.g. pixel_values)
    2. Copy precomputed values into metadata_buffers
    3. Replay graph
    4. Read encoder outputs from output_buffer
    """

    token_budget: int
    max_batch_size: int  # Max number of images/videos per batch
    graph: torch.cuda.CUDAGraph
    # The input tensor updated before replay (e.g. pixel_values)
    input_buffer: torch.Tensor
    # Buffers recorded into the CUDA graph (e.g. embeddings, sequence metadata).
    # Before replay the manager zeros then slice-copies new data into these.
    metadata_buffers: dict[str, torch.Tensor]
    # Output written by graph, read after replay
    output_buffer: torch.Tensor
|
||||
|
||||
|
||||
class EncoderCudaGraphManager:
    """Budget-based CUDA graph capture/replay for vision encoders."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        dtype: torch.dtype,
        model: SupportsEncoderCudaGraph,
    ):
        """Initialize CUDA graph manager with provided token budgets
        and max batch size."""
        self.vllm_config = vllm_config
        self.device = device
        self.dtype = dtype
        self.model = model
        # Static per-model description: input key, buffer keys, hidden size.
        self.config: EncoderCudaGraphConfig = model.get_encoder_cudagraph_config()

        comp_config = vllm_config.compilation_config
        user_budgets = comp_config.encoder_cudagraph_token_budgets
        user_max_images = comp_config.encoder_cudagraph_max_images_per_batch

        if user_budgets and user_max_images > 0:
            # Fully user-specified
            self.token_budgets = sorted(user_budgets)
            self.max_batch_size = user_max_images
        else:
            # Auto-infer missing values from model
            min_budget, max_budget = model.get_encoder_cudagraph_budget_range(
                vllm_config
            )
            self.token_budgets = (
                sorted(user_budgets)
                if user_budgets
                else self._generate_budgets(min_budget, max_budget)
            )
            # Heuristic default: as many min-sized items as fit in the
            # largest budget.
            self.max_batch_size = (
                user_max_images if user_max_images > 0 else max_budget // min_budget
            )

        # DP vision encoding applies only when the MM encoder is configured
        # for data-parallel TP mode and TP actually spans multiple ranks.
        mm_config = vllm_config.model_config.multimodal_config
        self.use_dp = (
            mm_config is not None
            and mm_config.mm_encoder_tp_mode == "data"
            and vllm_config.parallel_config.tensor_parallel_size > 1
        )

        # token_budget -> captured graph plus its recorded buffers.
        self.budget_graphs: dict[int, BudgetGraphMetadata] = {}
        # Hit/miss counters are per item (image), not per batch.
        self.graph_hits = 0
        self.graph_misses = 0
        self.log_stats_interval = 100

        logger.info(
            "EncoderCudaGraphManager initialized with "
            "budgets=%s, max_batch_size=%d, use_dp=%s",
            self.token_budgets,
            self.max_batch_size,
            self.use_dp,
        )
|
||||
|
||||
@staticmethod
|
||||
def _generate_budgets(min_budget: int, max_budget: int) -> list[int]:
|
||||
"""Generate power-of-2 token budgets from min_budget to max_budget."""
|
||||
budgets: list[int] = []
|
||||
b = min_budget
|
||||
while b <= max_budget:
|
||||
budgets.append(b)
|
||||
b *= 2
|
||||
# Always include max_budget if it's not already a power-of-2 boundary
|
||||
if not budgets or budgets[-1] < max_budget:
|
||||
budgets.append(max_budget)
|
||||
return budgets
|
||||
|
||||
def supports_modality(self, modality: str) -> bool:
|
||||
"""Check if a modality is supported by this manager."""
|
||||
return modality in self.config.modalities
|
||||
|
||||
def capture(self):
    """Capture one CUDA graph per configured token budget."""
    for budget in self.token_budgets:
        self._capture_budget_graph(budget)

    logger.info(
        "Encoder CUDA graph capture complete. Captured %d budget graphs.",
        len(self.budget_graphs),
    )
|
||||
|
||||
def _capture_budget_graph(self, token_budget: int):
    """Capture CUDA graph for a single token budget.

    Runs one eager warmup forward before recording (which also sizes the
    persistent output buffer), then records a second forward plus a copy
    into that buffer. The exact tensor objects in mm_kwargs/buffers are
    what the graph reads at replay, so references to them are stored in
    BudgetGraphMetadata for replay-time updates (buffer identity
    invariant).
    """
    logger.debug(
        "Capturing encoder cudagraph for budget=%d, max_batch_size=%d",
        token_budget,
        self.max_batch_size,
    )

    capture_inputs = self.model.prepare_encoder_cudagraph_capture_inputs(
        token_budget, self.max_batch_size, self.device, self.dtype
    )

    mm_kwargs = capture_inputs.mm_kwargs
    buffers = capture_inputs.buffers

    # Warmup forward outside the graph; its output shape defines the
    # persistent output buffer for this budget.
    with torch.inference_mode():
        output = self.model.encoder_cudagraph_forward(mm_kwargs, buffers)
        output_buffer = torch.empty_like(output)

    graph = torch.cuda.CUDAGraph()
    with torch.inference_mode(), torch.cuda.graph(graph):
        output = self.model.encoder_cudagraph_forward(mm_kwargs, buffers)
        # The copy is recorded too, so every replay lands in output_buffer.
        output_buffer.copy_(output)

    input_key = self.config.input_key
    self.budget_graphs[token_budget] = BudgetGraphMetadata(
        token_budget=token_budget,
        max_batch_size=self.max_batch_size,
        graph=graph,
        input_buffer=mm_kwargs[input_key],
        metadata_buffers=buffers,
        output_buffer=output_buffer,
    )
|
||||
|
||||
def _find_smallest_fitting_budget_given_tokens(
|
||||
self, total_tokens: int
|
||||
) -> int | None:
|
||||
"""Find smallest budget >= total_tokens.
|
||||
|
||||
Returns:
|
||||
Token budget if found, None if no fitting budget.
|
||||
"""
|
||||
for budget in self.token_budgets:
|
||||
if budget >= total_tokens:
|
||||
return budget
|
||||
return None
|
||||
|
||||
def _get_per_item_out_tokens(self, mm_kwargs: dict[str, Any]) -> list[int]:
|
||||
"""Get per-item output token counts as plain ints."""
|
||||
return [
|
||||
int(t)
|
||||
for t in self.model.get_encoder_cudagraph_per_item_output_tokens(mm_kwargs)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _scatter_output_slices(
|
||||
output: torch.Tensor,
|
||||
indices: list[int],
|
||||
per_item_out_tokens: list[int],
|
||||
dest: dict[int, torch.Tensor] | list[torch.Tensor | None],
|
||||
clone: bool = False,
|
||||
) -> None:
|
||||
"""Slice a concatenated output tensor and scatter into dest by index."""
|
||||
offset = 0
|
||||
for idx in indices:
|
||||
n_tok = per_item_out_tokens[idx]
|
||||
sliced = output[offset : offset + n_tok]
|
||||
dest[idx] = sliced.clone() if clone else sliced
|
||||
offset += n_tok
|
||||
|
||||
def _run_budget_graph(
    self,
    mm_kwargs: dict[str, Any],
    token_budget: int,
    replay_buffers: dict[str, torch.Tensor | None],
) -> torch.Tensor | None:
    """Execute budget graph.

    Args:
        mm_kwargs: Multimodal inputs for the batch.
        token_budget: Token budget to use.
        replay_buffers: Buffer values to copy into captured buffers.
            None values leave the corresponding buffer unchanged.

    Returns:
        Encoder outputs, or None if graph not captured.
    """
    num_items = self.model.get_encoder_cudagraph_num_items(mm_kwargs)
    if token_budget not in self.budget_graphs:
        # Miss is counted per item, not per batch.
        self.graph_misses += num_items
        return None

    graph_meta = self.budget_graphs[token_budget]

    # Copy the input tensor. Buffers are sized for the full budget;
    # actual inputs may be smaller. Zero then slice-copy so padded
    # positions are invisible to attention (cu_seqlens masks them out).
    input_key = self.config.input_key
    src = mm_kwargs[input_key]
    n = src.shape[0]
    graph_meta.input_buffer.zero_()
    graph_meta.input_buffer[:n].copy_(src)

    # Copy metadata buffers using keys from config.buffer_keys.
    for key in self.config.buffer_keys:
        src = replay_buffers.get(key)
        if src is None:
            # None = keep the value that was captured into the graph.
            continue
        buf = graph_meta.metadata_buffers[key]
        if src.ndim == 0:
            # Scalar buffers (e.g. max_seqlen) are overwritten whole.
            buf.copy_(src)
        else:
            n = src.shape[0]
            buf.zero_()
            buf[:n].copy_(src)

    graph_meta.graph.replay()

    self.graph_hits += num_items
    return graph_meta.output_buffer
|
||||
|
||||
def _execute_local(
    self,
    mm_kwargs: dict[str, Any],
) -> list[torch.Tensor]:
    """Execute encoder on local inputs using greedy-packed CUDA graphs.

    Sort images by output token count (smallest first), then greedily pack
    as many images as possible into each batch while staying within
    max_budget tokens and max_batch_size. Once a batch is finalised (next
    image would overflow either constraint), find the smallest fitting
    budget once for that batch.

    By exchange argument, greedy smallest-first packing minimises eager
    fallbacks -- any other ordering yields a higher token sum in some batch,
    making that batch more likely to exceed the budget.

    Stats note:
        graph_hits -- counted inside _run_budget_graph after successful replay.
        graph_misses -- counted here for single-image batches where the image
                        exceeds max_budget. Batches split due to max_batch_size
                        always satisfy total_tokens <= max_budget and therefore
                        always find a valid budget (no miss).
    """
    num_items = self.model.get_encoder_cudagraph_num_items(mm_kwargs)
    max_budget = self.token_budgets[-1]

    per_item_out_tokens = self._get_per_item_out_tokens(mm_kwargs)

    # Sort ascending by output token count (smallest first)
    sorted_indices = sorted(range(num_items), key=lambda i: per_item_out_tokens[i])

    # Greedy pack against max_budget and max_batch_size.
    # _find_smallest_fitting_budget_given_tokens is called once per
    # finalised batch, not per image.
    batches: list[tuple[list[int], int | None]] = []
    current_batch: list[int] = []
    current_batch_tokens = 0

    for orig_idx in sorted_indices:
        item_tokens = per_item_out_tokens[orig_idx]
        if (
            current_batch_tokens + item_tokens <= max_budget
            and len(current_batch) < self.max_batch_size
        ):
            current_batch.append(orig_idx)
            current_batch_tokens += item_tokens
        else:
            # Finalise the current batch and start a new one with this item.
            if current_batch:
                batches.append(
                    (
                        current_batch,
                        self._find_smallest_fitting_budget_given_tokens(
                            current_batch_tokens
                        ),
                    )
                )
            current_batch = [orig_idx]
            current_batch_tokens = item_tokens

    # Flush the trailing (possibly only) batch.
    if current_batch:
        batches.append(
            (
                current_batch,
                self._find_smallest_fitting_budget_given_tokens(
                    current_batch_tokens
                ),
            )
        )

    # outputs_by_orig_idx maps each original image index to its output
    # tensor. Needed because greedy packing reorders images; we restore
    # the original order before returning.
    outputs_by_orig_idx: dict[int, torch.Tensor] = {}

    for batch_orig_indices, token_budget in batches:
        batch_mm_kwargs = self.model.select_encoder_cudagraph_items(
            mm_kwargs, batch_orig_indices
        )
        batch_out_tokens = sum(per_item_out_tokens[i] for i in batch_orig_indices)

        if token_budget is None:
            # Single oversized image: item_tokens > max_budget.
            # graph_misses counted here for this eager fallback.
            logger.debug(
                "Encoder CUDA graph fallback to eager: no budget for "
                "%d tokens from %d images",
                batch_out_tokens,
                len(batch_orig_indices),
            )
            self.graph_misses += len(batch_orig_indices)
            with torch.inference_mode():
                raw = self.model.encoder_eager_forward(batch_mm_kwargs)
            self._scatter_output_slices(
                raw,
                batch_orig_indices,
                per_item_out_tokens,
                outputs_by_orig_idx,
            )
        else:
            logger.debug(
                "Encoder CUDA graph: batch_size=%d, tokens=%d, "
                "budget=%d, waste=%.1f%%",
                len(batch_orig_indices),
                batch_out_tokens,
                token_budget,
                (token_budget - batch_out_tokens) / token_budget * 100,
            )
            replay = self.model.prepare_encoder_cudagraph_replay_buffers(
                batch_mm_kwargs, self.max_batch_size
            )

            # graph_hits counted inside _run_budget_graph after replay.
            output = self._run_budget_graph(
                batch_mm_kwargs, token_budget, replay.buffers
            )
            # Budget came from self.token_budgets, all of which were
            # captured, so replay cannot miss here.
            assert output is not None
            # clone=True: the graph's output buffer is overwritten by the
            # next replay, so each batch's slices must be copied out.
            self._scatter_output_slices(
                output,
                batch_orig_indices,
                per_item_out_tokens,
                outputs_by_orig_idx,
                clone=True,
            )

    # Return in original batch order (caller maps outputs to token positions)
    return [outputs_by_orig_idx[i] for i in range(num_items)]
|
||||
|
||||
def _dp_shard(
    self,
    mm_kwargs: dict[str, Any],
    per_item_out_tokens: list[int],
) -> tuple[dict[str, Any], list[int], list[int], int]:
    """Distribute items across TP ranks for data-parallel execution.

    Uses get_load_balance_assignment() to balance load by input size,
    then select_encoder_cudagraph_items() to extract each rank's inputs.

    Returns:
        local_mm_kwargs: Inputs for this rank.
        image_rank_assignment: Flattened assignment order across all ranks.
        images_per_rank: Number of items per rank.
        max_output_tokens_per_rank: Max output tokens across all ranks
            (for padding during all_gather).
    """
    tp_size = get_tensor_model_parallel_world_size()
    current_rank = get_tensor_model_parallel_rank()

    per_item_input_sizes = self.model.get_encoder_cudagraph_per_item_input_sizes(
        mm_kwargs
    )

    # Third element (input patches per rank) is unused here.
    (image_rank_assignment, images_per_rank, _) = get_load_balance_assignment(
        per_item_input_sizes, tp_size
    )

    # Prefix sums of per-rank counts give each rank's slice of the
    # flattened assignment list.
    cum_images_per_rank = [0]
    for count in images_per_rank:
        cum_images_per_rank.append(cum_images_per_rank[-1] + count)

    local_indices = image_rank_assignment[
        cum_images_per_rank[current_rank] : cum_images_per_rank[current_rank + 1]
    ]

    # select_encoder_cudagraph_items handles an empty index list itself,
    # so no separate empty-branch is needed.
    local_mm_kwargs = self.model.select_encoder_cudagraph_items(
        mm_kwargs, local_indices
    )

    # Largest per-rank output token sum; every rank pads its gather
    # contribution to this length.
    max_output_tokens_per_rank = (
        max(
            sum(
                per_item_out_tokens[i]
                for i in image_rank_assignment[
                    cum_images_per_rank[r] : cum_images_per_rank[r + 1]
                ]
            )
            for r in range(tp_size)
        )
        if len(per_item_out_tokens) > 0
        else 0
    )

    return (
        local_mm_kwargs,
        image_rank_assignment,
        images_per_rank,
        max_output_tokens_per_rank,
    )
|
||||
|
||||
def _dp_gather(
    self,
    local_outputs: list[torch.Tensor],
    per_item_out_tokens: list[int],
    image_rank_assignment: list[int],
    images_per_rank: list[int],
    max_output_tokens_per_rank: int,
) -> list[torch.Tensor]:
    """Gather outputs from all TP ranks and reorder to original sequence.

    Assumes 2D output tensors [tokens, hidden]. Follows the same
    pad -> all_gather -> unpad -> reorder algorithm as
    run_dp_sharded_mrope_vision_model() in the eager path.
    """
    hidden_size = self.config.out_hidden_size
    tp_size = len(images_per_rank)

    if len(local_outputs) > 0:
        local_concat = torch.cat(local_outputs, dim=0)
    else:
        # Rank had no items: contribute a (0, hidden) placeholder.
        local_concat = torch.empty(
            (0, hidden_size), device=self.device, dtype=self.dtype
        )

    # Pad to max_output_tokens_per_rank for all_gather
    current_len = local_concat.shape[0]
    if current_len < max_output_tokens_per_rank:
        # Padding content is irrelevant; it is sliced away after the gather.
        padding = torch.empty(
            (max_output_tokens_per_rank - current_len, hidden_size),
            dtype=self.dtype,
            device=self.device,
        )
        local_padded = torch.cat([local_concat, padding], dim=0)
    else:
        local_padded = local_concat

    gathered = tensor_model_parallel_all_gather(local_padded, dim=0)

    # Unpad each rank's contribution
    rank_outputs: list[torch.Tensor] = []
    current_idx = 0
    for rank in range(tp_size):
        # Each rank occupies a fixed max_output_tokens_per_rank stride
        # in the gathered tensor; keep only its real tokens.
        start = rank * max_output_tokens_per_rank
        rank_count = images_per_rank[rank]
        rank_indices = image_rank_assignment[current_idx : current_idx + rank_count]
        rank_tokens = sum(per_item_out_tokens[i] for i in rank_indices)
        current_idx += rank_count
        rank_outputs.append(gathered[start : start + rank_tokens])

    # Reorder to original sequence
    total_items = len(per_item_out_tokens)
    result: list[torch.Tensor | None] = [None] * total_items
    current_idx = 0
    for rank in range(tp_size):
        count = images_per_rank[rank]
        if count > 0:
            rank_items = image_rank_assignment[current_idx : current_idx + count]
            self._scatter_output_slices(
                rank_outputs[rank],
                rank_items,
                per_item_out_tokens,
                result,
            )
        current_idx += count

    return [t for t in result if t is not None]
|
||||
|
||||
def execute(
    self,
    mm_kwargs: dict[str, Any],
) -> list[torch.Tensor]:
    """Execute encoder using CUDA graph with optional DP.

    Args:
        mm_kwargs: Multimodal keyword arguments containing the
            input tensor and grid dimensions.

    Returns:
        List of encoder outputs (one per item).
    """
    if self.use_dp:
        per_item_out_tokens = self._get_per_item_out_tokens(mm_kwargs)

        # Shard items across TP ranks, run locally, then gather and
        # restore the original item order.
        (
            local_mm_kwargs,
            image_rank_assignment,
            images_per_rank,
            max_output_tokens_per_rank,
        ) = self._dp_shard(mm_kwargs, per_item_out_tokens)

        local_outputs = self._execute_local(local_mm_kwargs)

        result = self._dp_gather(
            local_outputs,
            per_item_out_tokens,
            image_rank_assignment,
            images_per_rank,
            max_output_tokens_per_rank,
        )
    else:
        result = self._execute_local(mm_kwargs)

    # Log cumulative stats periodically
    # NOTE(review): counters advance by batch-item counts, so a multiple
    # of log_stats_interval may be skipped over; logging is best-effort.
    stats = self.get_cumulative_stats()
    total_requests = self.graph_hits + self.graph_misses
    if total_requests > 0 and total_requests % self.log_stats_interval == 0:
        logger.debug(
            "Encoder CUDA graph cumulative stats: "
            "hits=%d, misses=%d, hit_rate=%.1f%%",
            stats["graph_hits"],
            stats["graph_misses"],
            stats["hit_rate"] * 100,
        )

    return result
|
||||
|
||||
def get_cumulative_stats(self) -> dict[str, Any]:
|
||||
"""Get cumulative CUDA graph statistics."""
|
||||
total_requests = self.graph_hits + self.graph_misses
|
||||
hit_rate = self.graph_hits / total_requests if total_requests > 0 else 0.0
|
||||
|
||||
return {
|
||||
"graph_hits": self.graph_hits,
|
||||
"graph_misses": self.graph_misses,
|
||||
"hit_rate": hit_rate,
|
||||
"num_budgets": len(self.budget_graphs),
|
||||
"token_budgets": self.token_budgets,
|
||||
}
|
||||
66
vllm/v1/worker/gpu/mm/encoder_cudagraph_defs.py
Normal file
66
vllm/v1/worker/gpu/mm/encoder_cudagraph_defs.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Data transfer objects for encoder CUDA graph management."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
class EncoderCudaGraphConfig:
    """Configuration for encoder CUDA graph management.

    Provided by the model at init time via
    ``get_encoder_cudagraph_config()``. Values are fixed for the
    lifetime of the manager.
    """

    modalities: list[str]
    """Supported modalities (e.g. ["image"])."""

    input_key: str
    """Key in mm_kwargs for the input tensor (e.g. "pixel_values")."""

    buffer_keys: list[str]
    """Keys for the tensor buffers recorded into the CUDA graph.
    Before replay the manager zeros then slice-copies new data
    into these buffers."""

    out_hidden_size: int
    """Output hidden dim of the vision encoder.
    Used for DP gather buffer allocation."""
|
||||
|
||||
|
||||
@dataclass
class EncoderCudaGraphCaptureInputs:
    """Everything needed for one CUDA graph capture.

    Returned by ``prepare_encoder_cudagraph_capture_inputs()``.
    """

    mm_kwargs: dict[str, Any]
    """Dummy forward inputs (model-specific keys).
    For Qwen3-VL this contains pixel_values and grid_thw."""

    buffers: dict[str, torch.Tensor]
    """Precomputed tensor buffers that will be recorded into the
    CUDA graph. The manager stores references to these exact
    tensor objects and copies new data into them before each
    ``graph.replay()`` call (buffer identity invariant)."""
|
||||
|
||||
|
||||
@dataclass
class EncoderCudaGraphReplayBuffers:
    """New buffer values for graph replay, computed by the model from
    actual batch inputs.

    Returned by ``prepare_encoder_cudagraph_replay_buffers()``.
    Keys match ``EncoderCudaGraphConfig.buffer_keys``.
    """

    buffers: dict[str, torch.Tensor | None]
    """Data to copy into the captured buffers before replay.
    ``None`` values leave the corresponding captured buffer
    unchanged."""
|
||||
@@ -207,6 +207,7 @@ from .utils import (
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.worker.gpu.mm.encoder_cudagraph import EncoderCudaGraphManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -499,6 +500,9 @@ class GPUModelRunner(
|
||||
self.encoder_cache: dict[str, torch.Tensor] = {}
|
||||
self.late_interaction_runner = LateInteractionRunner()
|
||||
|
||||
# Encoder CUDA graph manager (initialized after model load if enabled)
|
||||
self.encoder_cudagraph_manager: EncoderCudaGraphManager | None = None
|
||||
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
# Set up speculative decoding.
|
||||
# NOTE(Jiayi): currently we put the entire draft model on
|
||||
@@ -2664,7 +2668,19 @@ class GPUModelRunner(
|
||||
with self.timed_encoder_operation(
|
||||
should_time, mm_lora_refs, current_item_idx, num_items
|
||||
):
|
||||
batch_outputs = model.embed_multimodal(**mm_kwargs_batch)
|
||||
cudagraph_output = None
|
||||
if (
|
||||
self.encoder_cudagraph_manager is not None
|
||||
and self.encoder_cudagraph_manager.supports_modality(modality)
|
||||
):
|
||||
cudagraph_output = self.encoder_cudagraph_manager.execute(
|
||||
mm_kwargs_batch,
|
||||
)
|
||||
|
||||
if cudagraph_output is not None:
|
||||
batch_outputs = cudagraph_output
|
||||
else:
|
||||
batch_outputs = model.embed_multimodal(**mm_kwargs_batch)
|
||||
|
||||
sanity_check_mm_encoder_outputs(batch_outputs, expected_num_items=num_items)
|
||||
encoder_outputs.extend(batch_outputs)
|
||||
@@ -5715,6 +5731,33 @@ class GPUModelRunner(
|
||||
)
|
||||
return 0
|
||||
|
||||
# Initialize encoder CUDA graph manager if enabled.
|
||||
# Use get_model() to unwrap CUDAGraphWrapper/UBatchWrapper,
|
||||
# because @runtime_checkable Protocol isinstance() checks do not
|
||||
# work through __getattr__ forwarding.
|
||||
if (
|
||||
self.compilation_config.cudagraph_mm_encoder
|
||||
and self.supports_mm_inputs
|
||||
and self.encoder_cudagraph_manager is None
|
||||
):
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
SupportsEncoderCudaGraph,
|
||||
supports_encoder_cudagraph,
|
||||
)
|
||||
from vllm.v1.worker.gpu.mm.encoder_cudagraph import (
|
||||
EncoderCudaGraphManager,
|
||||
)
|
||||
|
||||
raw_model = self.get_model()
|
||||
if supports_encoder_cudagraph(raw_model):
|
||||
self.encoder_cudagraph_manager = EncoderCudaGraphManager(
|
||||
vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
model=cast(SupportsEncoderCudaGraph, raw_model),
|
||||
)
|
||||
logger.info("Initialized EncoderCudaGraphManager for vision encoder")
|
||||
|
||||
compilation_counter.num_gpu_runner_capture_triggers += 1
|
||||
|
||||
start_time = time.perf_counter()
|
||||
@@ -5738,6 +5781,10 @@ class GPUModelRunner(
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
# Capture encoder CUDA graphs if enabled
|
||||
if self.encoder_cudagraph_manager is not None:
|
||||
self.encoder_cudagraph_manager.capture()
|
||||
|
||||
torch.accelerator.synchronize()
|
||||
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user