577 lines
21 KiB
Python
577 lines
21 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""CUDA graph manager for vision encoder budget-batch execution."""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import torch
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed import (
|
|
get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
tensor_model_parallel_all_gather,
|
|
)
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.models.interfaces import SupportsEncoderCudaGraph
|
|
from vllm.model_executor.models.vision import get_load_balance_assignment
|
|
from vllm.v1.worker.gpu.mm.encoder_cudagraph_defs import (
|
|
EncoderCudaGraphConfig,
|
|
)
|
|
|
|
# Module-level logger for this file, created through vLLM's init_logger
# and keyed by the module name.
logger = init_logger(__name__)
|
|
|
|
|
|
@dataclass
class BudgetGraphMetadata:
    """State kept for one captured token-budget CUDA graph.

    Replaying the graph follows four steps:
      1. Write the new batch's input (e.g. pixel_values) into input_buffer.
      2. Write the precomputed per-batch values into metadata_buffers.
      3. Replay the graph.
      4. Read the encoder result back from output_buffer.
    """

    # Token budget this graph was captured for.
    token_budget: int
    # Upper bound on the number of images/videos in a single batch.
    max_batch_size: int
    # The captured graph object that is replayed.
    graph: torch.cuda.CUDAGraph
    # Input tensor (e.g. pixel_values) that is overwritten before each replay.
    input_buffer: torch.Tensor
    # Named tensors recorded into the graph (e.g. embeddings, sequence
    # metadata). The manager zeroes them and slice-copies fresh values in
    # before replaying.
    metadata_buffers: dict[str, torch.Tensor]
    # Tensor the graph writes its result to; read after replay.
    output_buffer: torch.Tensor
|
|
|
|
|
|
class EncoderCudaGraphManager:
    """Budget-based CUDA graph capture/replay for vision encoders.

    One CUDA graph is captured per token budget. At execution time the items
    of a request are packed into batches, each batch is replayed on the
    smallest budget that fits it, and batches that cannot be served by a
    captured graph run through the model's eager forward instead.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        dtype: torch.dtype,
        model: SupportsEncoderCudaGraph,
    ):
        """Resolve token budgets and max batch size, and reset statistics.

        Args:
            vllm_config: Global config; the compilation config supplies the
                optional user budgets / max images per batch, and the
                multimodal + parallel configs decide whether encoder data
                parallelism is used.
            device: Device on which gather padding tensors are allocated and
                which is passed to the model when building capture inputs.
            dtype: Dtype used for the same tensors.
            model: Model exposing the encoder CUDA graph hooks.

        Raises:
            ValueError: If budgets must be generated and the model reports a
                non-positive minimum budget.
        """
        self.vllm_config = vllm_config
        self.device = device
        self.dtype = dtype
        self.model = model
        self.config: EncoderCudaGraphConfig = model.get_encoder_cudagraph_config()

        comp_config = vllm_config.compilation_config
        user_budgets = comp_config.encoder_cudagraph_token_budgets
        user_max_images = comp_config.encoder_cudagraph_max_images_per_batch

        if user_budgets and user_max_images > 0:
            # Fully user-specified: the model's budget range is not queried.
            self.token_budgets = sorted(user_budgets)
            self.max_batch_size = user_max_images
        else:
            # Fill in whichever value the user left unset from the model.
            min_budget, max_budget = model.get_encoder_cudagraph_budget_range(
                vllm_config
            )
            if user_budgets:
                self.token_budgets = sorted(user_budgets)
            else:
                self.token_budgets = self._generate_budgets(min_budget, max_budget)
            if user_max_images > 0:
                self.max_batch_size = user_max_images
            else:
                self.max_batch_size = max_budget // min_budget

        mm_config = vllm_config.model_config.multimodal_config
        # Encoder data parallelism needs the "data" TP mode and more than
        # one tensor-parallel rank.
        self.use_dp = (
            mm_config is not None
            and mm_config.mm_encoder_tp_mode == "data"
            and vllm_config.parallel_config.tensor_parallel_size > 1
        )

        self.budget_graphs: dict[int, BudgetGraphMetadata] = {}
        self.graph_hits = 0
        self.graph_misses = 0
        self.log_stats_interval = 100
        # Index of the last log_stats_interval-sized window of processed
        # items for which cumulative stats were logged (see execute()).
        self._last_stats_log_bucket = 0

        logger.info(
            "EncoderCudaGraphManager initialized with "
            "budgets=%s, max_batch_size=%d, use_dp=%s",
            self.token_budgets,
            self.max_batch_size,
            self.use_dp,
        )

    @staticmethod
    def _generate_budgets(min_budget: int, max_budget: int) -> list[int]:
        """Generate token budgets by doubling min_budget up to max_budget.

        max_budget is always the last entry, even when it is not reached
        exactly by doubling (or when it is smaller than min_budget).

        Raises:
            ValueError: If min_budget is not positive. Doubling a
                non-positive value never exceeds max_budget, so the
                generation loop would not terminate.
        """
        if min_budget <= 0:
            raise ValueError(
                f"min_budget must be a positive integer, got {min_budget}"
            )

        budgets: list[int] = []
        budget = min_budget
        while budget <= max_budget:
            budgets.append(budget)
            budget *= 2
        # Make sure the upper bound itself is covered.
        if not budgets or budgets[-1] < max_budget:
            budgets.append(max_budget)
        return budgets

    def supports_modality(self, modality: str) -> bool:
        """Return True if `modality` is listed in the encoder graph config."""
        return modality in self.config.modalities

    def capture(self):
        """Capture one CUDA graph for every configured token budget."""
        for token_budget in self.token_budgets:
            self._capture_budget_graph(token_budget)

        logger.info(
            "Encoder CUDA graph capture complete. Captured %d budget graphs.",
            len(self.budget_graphs),
        )

    def _capture_budget_graph(self, token_budget: int):
        """Capture the CUDA graph for a single token budget.

        The model supplies capture inputs sized for `token_budget` and
        `max_batch_size`. One forward pass is run outside the graph; its
        output determines the shape/dtype of the static output buffer. The
        same forward pass is then recorded into the graph, ending with a copy
        into that buffer. The result is stored in `self.budget_graphs`.
        """
        logger.debug(
            "Capturing encoder cudagraph for budget=%d, max_batch_size=%d",
            token_budget,
            self.max_batch_size,
        )

        capture_inputs = self.model.prepare_encoder_cudagraph_capture_inputs(
            token_budget, self.max_batch_size, self.device, self.dtype
        )
        mm_kwargs = capture_inputs.mm_kwargs
        buffers = capture_inputs.buffers

        # Un-captured forward pass; its output sizes the output buffer.
        with torch.inference_mode():
            output = self.model.encoder_cudagraph_forward(mm_kwargs, buffers)
            output_buffer = torch.empty_like(output)

        graph = torch.cuda.CUDAGraph()
        with torch.inference_mode(), torch.cuda.graph(graph):
            output = self.model.encoder_cudagraph_forward(mm_kwargs, buffers)
            output_buffer.copy_(output)

        self.budget_graphs[token_budget] = BudgetGraphMetadata(
            token_budget=token_budget,
            max_batch_size=self.max_batch_size,
            graph=graph,
            input_buffer=mm_kwargs[self.config.input_key],
            metadata_buffers=buffers,
            output_buffer=output_buffer,
        )

    def _find_smallest_fitting_budget_given_tokens(
        self, total_tokens: int
    ) -> int | None:
        """Find the smallest budget >= total_tokens.

        Relies on `self.token_budgets` being sorted ascending (as set up in
        `__init__`).

        Returns:
            The token budget, or None if no budget is large enough.
        """
        return next(
            (budget for budget in self.token_budgets if budget >= total_tokens),
            None,
        )

    def _get_per_item_out_tokens(self, mm_kwargs: dict[str, Any]) -> list[int]:
        """Get the model's per-item output token counts as plain ints."""
        counts = self.model.get_encoder_cudagraph_per_item_output_tokens(mm_kwargs)
        return [int(count) for count in counts]

    @staticmethod
    def _scatter_output_slices(
        output: torch.Tensor,
        indices: list[int],
        per_item_out_tokens: list[int],
        dest: dict[int, torch.Tensor] | list[torch.Tensor | None],
        clone: bool = False,
    ) -> None:
        """Slice a concatenated output and store each slice in dest by index.

        Args:
            output: Per-item outputs concatenated along dim 0, in the order
                given by `indices`.
            indices: Item indices, in the order their outputs appear.
            per_item_out_tokens: Output length of every item, indexed by
                item index.
            dest: Mapping/list that receives `dest[idx] = slice`.
            clone: If True, store a copy of each slice rather than a view
                of `output`.
        """
        start = 0
        for idx in indices:
            end = start + per_item_out_tokens[idx]
            piece = output[start:end]
            dest[idx] = piece.clone() if clone else piece
            start = end

    def _run_budget_graph(
        self,
        mm_kwargs: dict[str, Any],
        token_budget: int,
        replay_buffers: dict[str, torch.Tensor | None],
    ) -> torch.Tensor | None:
        """Replay the captured graph for `token_budget` on one batch.

        Args:
            mm_kwargs: Multimodal inputs for the batch.
            token_budget: Token budget whose graph should be replayed.
            replay_buffers: Values to copy into the captured metadata
                buffers. None (or missing) values leave the corresponding
                buffer unchanged.

        Returns:
            The graph's shared output buffer (overwritten by the next replay
            of the same budget, so callers must copy what they keep), or
            None if no graph was captured for this budget. Hits/misses are
            counted per item in either case.
        """
        num_items = self.model.get_encoder_cudagraph_num_items(mm_kwargs)

        graph_meta = self.budget_graphs.get(token_budget)
        if graph_meta is None:
            self.graph_misses += num_items
            return None

        # The input buffer is sized for the full budget; the actual input
        # may be smaller. Zero it, then copy the real rows to the front so
        # the padded tail holds zeros (per the original note, cu_seqlens
        # masks those positions out of attention).
        batch_input = mm_kwargs[self.config.input_key]
        graph_meta.input_buffer.zero_()
        graph_meta.input_buffer[: batch_input.shape[0]].copy_(batch_input)

        # Refresh the metadata buffers named in config.buffer_keys.
        for key in self.config.buffer_keys:
            value = replay_buffers.get(key)
            if value is None:
                continue
            target = graph_meta.metadata_buffers[key]
            if value.ndim == 0:
                # Scalars are copied whole; there is nothing to pad.
                target.copy_(value)
            else:
                target.zero_()
                target[: value.shape[0]].copy_(value)

        graph_meta.graph.replay()

        self.graph_hits += num_items
        return graph_meta.output_buffer

    def _pack_batches(
        self, per_item_out_tokens: list[int], num_items: int
    ) -> list[tuple[list[int], int | None]]:
        """Greedily pack items (smallest first) into budget-sized batches.

        Items are visited in ascending output-token order. An item joins the
        current batch while the batch stays within the largest token budget
        and below `max_batch_size`; otherwise the current batch is closed and
        the item starts a new one.

        Returns:
            A list of (item indices, token budget) pairs, where the budget is
            the smallest one fitting the batch's total tokens, or None when
            none fits.
        """
        max_budget = self.token_budgets[-1]
        order = sorted(range(num_items), key=lambda i: per_item_out_tokens[i])

        batches: list[tuple[list[int], int | None]] = []
        current: list[int] = []
        current_tokens = 0

        for idx in order:
            item_tokens = per_item_out_tokens[idx]
            fits = (
                current_tokens + item_tokens <= max_budget
                and len(current) < self.max_batch_size
            )
            if not fits:
                # Close the running batch (if any) and start a fresh one.
                if current:
                    batches.append(
                        (
                            current,
                            self._find_smallest_fitting_budget_given_tokens(
                                current_tokens
                            ),
                        )
                    )
                current = []
                current_tokens = 0
            current.append(idx)
            current_tokens += item_tokens

        if current:
            batches.append(
                (
                    current,
                    self._find_smallest_fitting_budget_given_tokens(current_tokens),
                )
            )
        return batches

    def _execute_local(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[torch.Tensor]:
        """Execute the encoder on local inputs using greedy-packed CUDA graphs.

        Items are packed by `_pack_batches`; each batch is replayed on its
        budget's graph. A batch runs through the eager forward instead when
        no budget fits it (only possible for a single item larger than the
        largest budget, since multi-item batches never exceed it) or when the
        graph for its budget was never captured.

        Stats note:
          graph_hits   -- counted inside _run_budget_graph after replay.
          graph_misses -- counted here for batches with no fitting budget,
                          and inside _run_budget_graph for budgets without a
                          captured graph.

        Returns:
            One output tensor per item, in the original item order.
        """
        num_items = self.model.get_encoder_cudagraph_num_items(mm_kwargs)
        per_item_out_tokens = self._get_per_item_out_tokens(mm_kwargs)
        batches = self._pack_batches(per_item_out_tokens, num_items)

        # Packing reorders items, so outputs are keyed by original index and
        # put back in order at the end.
        outputs_by_orig_idx: dict[int, torch.Tensor] = {}

        for batch_orig_indices, token_budget in batches:
            batch_mm_kwargs = self.model.select_encoder_cudagraph_items(
                mm_kwargs, batch_orig_indices
            )
            batch_out_tokens = sum(per_item_out_tokens[i] for i in batch_orig_indices)

            graph_output: torch.Tensor | None = None
            if token_budget is None:
                logger.debug(
                    "Encoder CUDA graph fallback to eager: no budget for "
                    "%d tokens from %d images",
                    batch_out_tokens,
                    len(batch_orig_indices),
                )
                self.graph_misses += len(batch_orig_indices)
            else:
                logger.debug(
                    "Encoder CUDA graph: batch_size=%d, tokens=%d, "
                    "budget=%d, waste=%.1f%%",
                    len(batch_orig_indices),
                    batch_out_tokens,
                    token_budget,
                    (token_budget - batch_out_tokens) / token_budget * 100,
                )
                replay = self.model.prepare_encoder_cudagraph_replay_buffers(
                    batch_mm_kwargs, self.max_batch_size
                )
                graph_output = self._run_budget_graph(
                    batch_mm_kwargs, token_budget, replay.buffers
                )
                if graph_output is None:
                    # _run_budget_graph already counted the miss.
                    logger.debug(
                        "Encoder CUDA graph fallback to eager: "
                        "budget=%d not captured",
                        token_budget,
                    )

            if graph_output is None:
                with torch.inference_mode():
                    raw = self.model.encoder_eager_forward(batch_mm_kwargs)
                self._scatter_output_slices(
                    raw,
                    batch_orig_indices,
                    per_item_out_tokens,
                    outputs_by_orig_idx,
                )
            else:
                # The graph's output buffer is reused by later replays, so
                # each slice must be cloned before it is stored.
                self._scatter_output_slices(
                    graph_output,
                    batch_orig_indices,
                    per_item_out_tokens,
                    outputs_by_orig_idx,
                    clone=True,
                )

        # Return in original batch order (caller maps outputs to positions).
        return [outputs_by_orig_idx[i] for i in range(num_items)]

    def _dp_shard(
        self,
        mm_kwargs: dict[str, Any],
        per_item_out_tokens: list[int],
    ) -> tuple[dict[str, Any], list[int], list[int], int]:
        """Distribute items across TP ranks for data-parallel execution.

        Uses get_load_balance_assignment() on the model's per-item input
        sizes, then select_encoder_cudagraph_items() to extract this rank's
        inputs (called with an empty index list when the rank has no items).

        Returns:
            local_mm_kwargs: Inputs for this rank.
            image_rank_assignment: Flattened assignment order across all
                ranks.
            images_per_rank: Number of items per rank.
            max_output_tokens_per_rank: Largest per-rank output token total
                (used as the padded length for all_gather); 0 when there are
                no items.
        """
        tp_size = get_tensor_model_parallel_world_size()
        current_rank = get_tensor_model_parallel_rank()

        per_item_input_sizes = self.model.get_encoder_cudagraph_per_item_input_sizes(
            mm_kwargs
        )
        image_rank_assignment, images_per_rank, _ = get_load_balance_assignment(
            per_item_input_sizes, tp_size
        )

        # rank_offsets[r] : rank_offsets[r + 1] is rank r's span inside
        # image_rank_assignment.
        rank_offsets = [0]
        for count in images_per_rank:
            rank_offsets.append(rank_offsets[-1] + count)

        local_indices = image_rank_assignment[
            rank_offsets[current_rank] : rank_offsets[current_rank + 1]
        ]
        local_mm_kwargs = self.model.select_encoder_cudagraph_items(
            mm_kwargs, local_indices
        )

        if per_item_out_tokens:
            max_output_tokens_per_rank = max(
                sum(
                    per_item_out_tokens[i]
                    for i in image_rank_assignment[
                        rank_offsets[r] : rank_offsets[r + 1]
                    ]
                )
                for r in range(tp_size)
            )
        else:
            max_output_tokens_per_rank = 0

        return (
            local_mm_kwargs,
            image_rank_assignment,
            images_per_rank,
            max_output_tokens_per_rank,
        )

    def _dp_gather(
        self,
        local_outputs: list[torch.Tensor],
        per_item_out_tokens: list[int],
        image_rank_assignment: list[int],
        images_per_rank: list[int],
        max_output_tokens_per_rank: int,
    ) -> list[torch.Tensor]:
        """Gather outputs from all TP ranks and reorder to original sequence.

        Assumes 2D output tensors [tokens, hidden]. Steps: concatenate the
        local outputs, pad to `max_output_tokens_per_rank` rows, all_gather
        along dim 0, then cut each rank's real rows back out and scatter them
        to their original item positions. Per the original note this mirrors
        run_dp_sharded_mrope_vision_model() in the eager path.
        """
        hidden_size = self.config.out_hidden_size

        if local_outputs:
            local_concat = torch.cat(local_outputs, dim=0)
        else:
            local_concat = torch.empty(
                (0, hidden_size), device=self.device, dtype=self.dtype
            )

        # Pad with uninitialised rows so every rank contributes the same
        # shape; those rows are sliced away again below.
        missing_rows = max_output_tokens_per_rank - local_concat.shape[0]
        if missing_rows > 0:
            padding = torch.empty(
                (missing_rows, hidden_size),
                dtype=self.dtype,
                device=self.device,
            )
            local_padded = torch.cat([local_concat, padding], dim=0)
        else:
            local_padded = local_concat

        gathered = tensor_model_parallel_all_gather(local_padded, dim=0)

        # Unpad each rank's contribution and place its items at their
        # original indices.
        result: list[torch.Tensor | None] = [None] * len(per_item_out_tokens)
        assignment_pos = 0
        for rank, count in enumerate(images_per_rank):
            if count > 0:
                rank_items = image_rank_assignment[
                    assignment_pos : assignment_pos + count
                ]
                rank_tokens = sum(per_item_out_tokens[i] for i in rank_items)
                start = rank * max_output_tokens_per_rank
                self._scatter_output_slices(
                    gathered[start : start + rank_tokens],
                    rank_items,
                    per_item_out_tokens,
                    result,
                )
            assignment_pos += count

        return [t for t in result if t is not None]

    def execute(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[torch.Tensor]:
        """Execute the encoder using CUDA graphs, with optional DP.

        With DP enabled the items are sharded across TP ranks, each rank
        runs its share locally, and the outputs are gathered and reordered.
        Cumulative hit/miss stats are logged at debug level each time the
        number of processed items enters a new `log_stats_interval` window.

        Args:
            mm_kwargs: Multimodal keyword arguments containing the
                input tensor and grid dimensions.

        Returns:
            List of encoder outputs (one per item).
        """
        if self.use_dp:
            per_item_out_tokens = self._get_per_item_out_tokens(mm_kwargs)
            (
                local_mm_kwargs,
                image_rank_assignment,
                images_per_rank,
                max_output_tokens_per_rank,
            ) = self._dp_shard(mm_kwargs, per_item_out_tokens)

            local_outputs = self._execute_local(local_mm_kwargs)

            result = self._dp_gather(
                local_outputs,
                per_item_out_tokens,
                image_rank_assignment,
                images_per_rank,
                max_output_tokens_per_rank,
            )
        else:
            result = self._execute_local(mm_kwargs)

        # The counters advance by whole batches, so an exact-multiple test
        # would skip most intervals; compare interval windows instead.
        total_requests = self.graph_hits + self.graph_misses
        if total_requests > 0:
            bucket = total_requests // self.log_stats_interval
            if bucket > self._last_stats_log_bucket:
                self._last_stats_log_bucket = bucket
                stats = self.get_cumulative_stats()
                logger.debug(
                    "Encoder CUDA graph cumulative stats: "
                    "hits=%d, misses=%d, hit_rate=%.1f%%",
                    stats["graph_hits"],
                    stats["graph_misses"],
                    stats["hit_rate"] * 100,
                )

        return result

    def get_cumulative_stats(self) -> dict[str, Any]:
        """Get cumulative CUDA graph statistics.

        Returns:
            Dict with `graph_hits`, `graph_misses`, `hit_rate` (0.0 when
            nothing has been processed), `num_budgets` (captured graphs) and
            `token_budgets` (the configured budget list).
        """
        total_requests = self.graph_hits + self.graph_misses
        hit_rate = self.graph_hits / total_requests if total_requests > 0 else 0.0

        return {
            "graph_hits": self.graph_hits,
            "graph_misses": self.graph_misses,
            "hit_rate": hit_rate,
            "num_budgets": len(self.budget_graphs),
            "token_budgets": self.token_budgets,
        }
|