# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import patch

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBuffers


class CudaGraphManager:
    """Manages capture and replay of full CUDA graphs for padded batch sizes."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        self.vllm_config = vllm_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device = device
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_reqs = self.scheduler_config.max_num_seqs
        self.dp_size = vllm_config.parallel_config.data_parallel_size

        self.compilation_config = vllm_config.compilation_config
        assert self.compilation_config is not None
        if self.compilation_config.cudagraph_mode is None:
            self.cudagraph_mode = CUDAGraphMode.NONE
        else:
            self.cudagraph_mode = self.compilation_config.cudagraph_mode
        if self.compilation_config.cudagraph_capture_sizes is not None:
            cudagraph_sizes = sorted(self.compilation_config.cudagraph_capture_sizes)
            # Limit the cudagraph sizes to the max decode batch size.
            self.cudagraph_sizes = [
                x for x in cudagraph_sizes if x <= self.max_num_reqs
            ]
        else:
            self.cudagraph_sizes = []
        self.padded_sizes = self._init_padded_sizes()

        self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
        self.pool = torch.cuda.graph_pool_handle()
        self.hidden_states: torch.Tensor | None = None

    def _init_padded_sizes(self) -> dict[int, int]:
        """Map each batch size to the smallest capture size that can hold it."""
        if not self.cudagraph_mode.has_full_cudagraphs():
            # Full cuda graphs are not used.
            return {}
        if not self.cudagraph_sizes:
            return {}

        padded_sizes: dict[int, int] = {}
        for i in range(1, self.cudagraph_sizes[-1] + 1):
            for x in self.cudagraph_sizes:
                if i <= x:
                    padded_sizes[i] = x
                    break
        return padded_sizes
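    # For example, with capture sizes [1, 2, 4, 8], _init_padded_sizes() returns
    # {1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}: every batch size is
    # padded up to the next captured size, and anything larger runs eagerly.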

    def needs_capture(self) -> bool:
        return len(self.padded_sizes) > 0

    def get_cudagraph_size(
        self,
        scheduler_output: SchedulerOutput,
        num_tokens_after_padding: int,
    ) -> int | None:
        """Return the graph size to replay, or None to run eagerly."""
        if not self.cudagraph_mode.has_full_cudagraphs():
            return None
        if self.cudagraph_mode != CUDAGraphMode.FULL:
            # TODO(woosuk): Support uniform decode with multiple tokens (spec decoding).
            all_decode = all(
                x == 1 for x in scheduler_output.num_scheduled_tokens.values()
            )
            if not all_decode:
                # Prefill is included.
                return None
        return self.padded_sizes.get(num_tokens_after_padding)
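    # For instance, with capture sizes [1, 2, 4, 8], a step of 3 single-token
    # decode requests (num_tokens_after_padding == 3) maps to the size-4 graph,
    # while a step that includes a prefill returns None in decode-only modes
    # and the model falls back to eager execution.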

    def capture_graph(
        self,
        batch_size: int,
        model: nn.Module,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_metadata_builders: list[AttentionMetadataBuilder],
        kv_cache_config: KVCacheConfig,
    ) -> None:
        """Warm up and capture a full CUDA graph for `batch_size` tokens."""
        assert batch_size not in self.graphs

        # Prepare dummy inputs.
        input_ids = input_buffers.input_ids.gpu[:batch_size]
        positions = input_buffers.positions[:batch_size]
        input_buffers.query_start_loc.np[: batch_size + 1] = np.arange(batch_size + 1)
        input_buffers.query_start_loc.np[batch_size:] = batch_size
        input_buffers.query_start_loc.copy_to_gpu()
        # HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
        # for seq_lens. This leads to a mismatch between seq_lens (GPU) and
        # seq_lens_np (CPU), which might cause issues in some attention backends.
        input_buffers.seq_lens[:batch_size] = 1
        input_buffers.seq_lens[batch_size:] = 0
        input_block_tables = [x[:batch_size] for x in block_tables.input_block_tables]
        slot_mappings = block_tables.slot_mappings[:, :batch_size]
        attn_metadata = build_attn_metadata(
            attn_metadata_builders=attn_metadata_builders,
            num_reqs=batch_size,
            num_tokens=batch_size,
            query_start_loc=input_buffers.query_start_loc,
            seq_lens=input_buffers.seq_lens,
            seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
            num_computed_tokens_cpu=None,  # FIXME
            block_tables=input_block_tables,
            slot_mappings=slot_mappings,
            kv_cache_config=kv_cache_config,
        )

        if self.dp_size > 1:
            num_tokens_across_dp = torch.full(
                (self.dp_size,),
                batch_size,
                dtype=torch.int32,
                device="cpu",
            )
        else:
            num_tokens_across_dp = None

        # Warm up.
        with set_forward_context(
            attn_metadata,
            self.vllm_config,
            num_tokens=batch_size,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            num_tokens_across_dp=num_tokens_across_dp,
        ):
            hidden_states = model(
                input_ids=input_ids,
                positions=positions,
            )
        if self.hidden_states is None:
            self.hidden_states = torch.empty_like(hidden_states)

        # Capture the graph.
        graph = torch.cuda.CUDAGraph()
        with (
            patch("torch.cuda.empty_cache", lambda: None),
            set_forward_context(
                attn_metadata,
                self.vllm_config,
                num_tokens=batch_size,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                num_tokens_across_dp=num_tokens_across_dp,
            ),
            torch.cuda.graph(graph, self.pool),
        ):
            hidden_states = model(
                input_ids=input_ids,
                positions=positions,
            )
            self.hidden_states[:batch_size] = hidden_states
        self.graphs[batch_size] = graph

    @torch.inference_mode()
    def capture(
        self,
        model: nn.Module,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_metadata_builders: list[AttentionMetadataBuilder],
        kv_cache_config: KVCacheConfig,
    ) -> None:
        assert self.needs_capture()
        # Capture larger graphs first.
        sizes_to_capture = sorted(self.cudagraph_sizes, reverse=True)
        if is_global_first_rank():
            sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
        with graph_capture(device=self.device):
            for batch_size in sizes_to_capture:
                self.capture_graph(
                    batch_size,
                    model,
                    input_buffers,
                    block_tables,
                    attn_metadata_builders,
                    kv_cache_config,
                )

    def run(self, batch_size: int) -> torch.Tensor:
        """Replay the graph for `batch_size` and return its hidden states."""
        assert batch_size in self.graphs
        self.graphs[batch_size].replay()
        assert self.hidden_states is not None
        return self.hidden_states[:batch_size]
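

# Illustrative usage sketch (not part of this module). The names `model`,
# `input_buffers`, `scheduler_output`, etc. stand in for the calling worker's
# state and are assumptions made only to clarify the capture/replay flow:
#
#   manager = CudaGraphManager(vllm_config, device)
#   if manager.needs_capture():
#       manager.capture(model, input_buffers, block_tables,
#                       attn_metadata_builders, kv_cache_config)
#
#   # Per step, after the caller has filled `input_buffers` for the padded batch:
#   size = manager.get_cudagraph_size(scheduler_output, num_tokens_after_padding)
#   if size is not None:
#       hidden_states = manager.run(size)  # replay the captured graph
#   else:
#       hidden_states = model(...)  # fall back to eager execution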