# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from collections import defaultdict
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from tqdm import tqdm
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.config.compilation import CUDAGraphMode
|
|
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
|
from vllm.forward_context import BatchDescriptor, set_forward_context
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.offloader.base import get_offloader
|
|
from vllm.platforms import current_platform
|
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
|
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
|
|
from vllm.v1.worker.gpu.block_table import BlockTables
|
|
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
|
|
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
|
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
|
from vllm.v1.worker.utils import AttentionGroup
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
class BatchExecutionDescriptor:
    """Describes the shape of the batch and CG mode to run; this is used to make shape
    matches between the capture and runtime.

    Frozen (hence hashable) so instances can key the captured-graph dict in
    CudaGraphManager.graphs.
    """

    # CUDA graph mode this batch runs under (NONE / PIECEWISE / FULL).
    cg_mode: CUDAGraphMode
    # Padded token count the graph is captured/replayed with.
    num_tokens: int
    num_reqs: int | None  # None means no request padding is needed (PIECEWISE graphs)
    # Tokens per request for uniform-decode graphs; None for mixed batches.
    uniform_token_count: int | None = None
|
|
|
def _is_compatible(
    desc: BatchExecutionDescriptor,
    num_reqs: int,
    num_tokens: int,
    uniform_token_count: int | None,
) -> bool:
    """Return True if a captured descriptor can serve the given batch shape.

    desc.uniform_token_count=None (PIECEWISE) can handle any
    uniform_token_count, and desc.num_reqs=None means no request padding is
    needed (PIECEWISE), so any request count fits.
    """
    # The captured graph must be at least as large as the runtime batch.
    if desc.num_tokens < num_tokens:
        return False
    if desc.num_reqs is not None and desc.num_reqs < num_reqs:
        return False
    if desc.uniform_token_count is not None:
        return desc.uniform_token_count == uniform_token_count
    return True
|
|
|
def get_uniform_token_count(
|
|
num_reqs: int,
|
|
num_tokens: int,
|
|
max_query_len: int,
|
|
) -> int | None:
|
|
"""
|
|
Return the uniform token count if batch is uniform, else None.
|
|
A batch is uniform if all requests have the same number of tokens.
|
|
"""
|
|
if (max_query_len == num_tokens // num_reqs) and (
|
|
num_tokens == max_query_len * num_reqs
|
|
):
|
|
return max_query_len
|
|
return None
|
|
|
|
|
|
class CudaGraphManager:
    """Captures and replays CUDA graphs for padded batch shapes.

    One FULL graph is captured per BatchExecutionDescriptor; `dispatch` maps a
    runtime batch shape (num_reqs, num_tokens, uniform_token_count) onto the
    highest-priority compatible captured descriptor, falling back to an
    eager (CUDAGraphMode.NONE) descriptor when none matches.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        cudagraph_mode: CUDAGraphMode,
        decode_query_len: int,
    ):
        self.vllm_config = vllm_config
        self.device = device
        self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        self.compilation_config = vllm_config.compilation_config
        assert self.compilation_config is not None
        self.cudagraph_mode = cudagraph_mode
        # Tokens per request in a uniform decode step.
        self.decode_query_len = decode_query_len
        self.dp_size = vllm_config.parallel_config.data_parallel_size

        # Captured FULL graphs, keyed by the frozen (hashable) descriptor.
        self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
        # Shared memory pool so captured graphs reuse each other's allocations.
        self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None

        self._graphs_captured = False
        # _candidates[t] holds the priority-ordered descriptors able to serve a
        # batch of t tokens; indexed directly by token count in dispatch().
        self._candidates: list[list[BatchExecutionDescriptor]] = []
        # Descriptors to capture, grouped by CUDA graph mode.
        self._capture_descs: dict[CUDAGraphMode, list[BatchExecutionDescriptor]] = {}
        self._init_candidates()

    def _init_candidates(self) -> None:
        """Build priority-ordered candidate lists for each token count."""
        capture_sizes = self.compilation_config.cudagraph_capture_sizes
        # Nothing to do when CUDA graphs are disabled or no sizes configured.
        if not (self.cudagraph_mode and capture_sizes):
            return

        capture_sizes = sorted(capture_sizes)
        max_decode_tokens = self.max_num_reqs * self.decode_query_len
        decode_mode = self.cudagraph_mode.decode_mode()
        mixed_mode = self.cudagraph_mode.mixed_mode()
        separate_decode_routine = self.cudagraph_mode.separate_routine()

        # Descriptors grouped by padded token count (for dispatch) and by
        # CG mode (for capture ordering).
        descs_by_token_count = defaultdict(list)
        descs_by_mode = defaultdict(list)

        for num_tokens in capture_sizes:
            # Capture uniform decode specific graphs if required
            # (i.e. separate decode routine)
            if (
                separate_decode_routine
                and decode_mode
                and self.decode_query_len <= num_tokens <= max_decode_tokens
            ):
                desc = BatchExecutionDescriptor(
                    cg_mode=decode_mode,
                    num_tokens=num_tokens,
                    num_reqs=num_tokens // self.decode_query_len,
                    uniform_token_count=self.decode_query_len,
                )
                # Decode descriptor is appended first, so dispatch prefers the
                # uniform-decode graph over the mixed graph at the same size.
                descs_by_mode[decode_mode].append(desc)
                descs_by_token_count[num_tokens].append(desc)

            if mixed_mode:
                # for PIECEWISE graphs there is no limit on requests when replaying
                # i.e. no request padding is needed
                # so we leave it as None
                num_reqs = (
                    min(num_tokens, self.max_num_reqs)
                    if mixed_mode == CUDAGraphMode.FULL
                    else None
                )
                desc = BatchExecutionDescriptor(
                    cg_mode=mixed_mode,
                    num_tokens=num_tokens,
                    num_reqs=num_reqs,
                )
                descs_by_mode[mixed_mode].append(desc)
                descs_by_token_count[num_tokens].append(desc)

        if not descs_by_token_count:
            return

        # Map every token count in [0, max capture size] to the candidate list
        # of the smallest capture size that covers it.
        sorted_padded = sorted(descs_by_token_count.keys())
        self._candidates = [[] for _ in range(sorted_padded[-1] + 1)]

        current_range_start = 0
        for cg_size in sorted_padded:
            for i in range(current_range_start, cg_size + 1):
                self._candidates[i] = descs_by_token_count[cg_size]
            current_range_start = cg_size + 1

        # Capture largest sizes first so the biggest activations are allocated
        # into the pool up front.
        for mode, descs in descs_by_mode.items():
            descs.sort(key=lambda d: d.num_tokens, reverse=True)
            self._capture_descs[mode] = descs

    def needs_capture(self) -> bool:
        """Return True if there are descriptors pending capture."""
        return len(self._capture_descs) > 0

    @torch.inference_mode()
    def capture(
        self,
        create_forward_fn: Callable[
            [BatchExecutionDescriptor], Callable[[CUDAGraphMode], None]
        ],
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture CUDA graphs.

        Args:
            create_forward_fn: Factory that prepares inputs (OUTSIDE graph) and
                returns a function that runs forward with a given CUDAGraphMode.
            progress_bar_desc: Label for the tqdm progress bar (first rank only).
        """
        with graph_capture(device=self.device):
            # Capture in order: PIECEWISE first, then FULL. PIECEWISE has larger
            # activations so FULL activations should fit in already allocated
            # buffers in the graph pool.
            for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
                if mode not in self._capture_descs:
                    continue

                descs = self._capture_descs[mode]
                # Only the first global rank shows a progress bar.
                if is_global_first_rank():
                    descs = tqdm(descs, desc=f"{progress_bar_desc} ({mode.name})")
                for desc in descs:
                    # Prepare inputs and get forward function
                    forward_fn = create_forward_fn(desc)

                    # Warmup
                    forward_fn(CUDAGraphMode.NONE)

                    # Capture
                    logger.debug(
                        "CG Capture: mode=%s, batch_desc=%s", desc.cg_mode.name, desc
                    )
                    if desc.cg_mode == CUDAGraphMode.PIECEWISE:
                        # PIECEWISE capture is handled inside the compiled
                        # model; just run forward in PIECEWISE mode.
                        forward_fn(CUDAGraphMode.PIECEWISE)
                    else:
                        assert desc not in self.graphs, (
                            f"Graph already captured for {desc}"
                        )
                        graph = torch.cuda.CUDAGraph()
                        # Sync offloader's copy stream before capture.
                        # Ensure any pre-capture prefetches from offloader are complete.
                        get_offloader().sync_prev_onload()
                        with torch.cuda.graph(graph, self.pool):
                            forward_fn(CUDAGraphMode.NONE)
                            # Join offloader's copy stream after forward to avoid
                            # unjoined stream error. The last layer's start_prefetch
                            # forks copy_stream, but wait_prefetch only happens in
                            # the next forward pass.
                            get_offloader().join_after_forward()
                        self.graphs[desc] = graph
        # Enables graph dispatch from now on (see dispatch()).
        self._graphs_captured = True

    def dispatch(
        self,
        num_reqs: int,
        num_tokens: int,
        uniform_token_count: int | None,
    ) -> BatchExecutionDescriptor:
        """Find matching cudagraph descriptor from priority-ordered candidates."""
        if self._graphs_captured and 0 < num_tokens < len(self._candidates):
            for desc in self._candidates[num_tokens]:
                if _is_compatible(desc, num_reqs, num_tokens, uniform_token_count):
                    return desc
        # No captured graph fits: run eagerly with the exact batch shape.
        return BatchExecutionDescriptor(
            cg_mode=CUDAGraphMode.NONE, num_tokens=num_tokens, num_reqs=num_reqs
        )

    def run_fullgraph(self, desc: BatchExecutionDescriptor):
        """Replay a captured FULL cudagraph."""
        assert desc.cg_mode == CUDAGraphMode.FULL, (
            f"Expected FULL mode, got {desc.cg_mode}"
        )
        assert desc in self.graphs, f"No cudagraph for {desc}"
        # Sync offloader before replay - needed when transitioning from
        # eager/piecewise to full cudagraph (e.g., prefill → decode).
        # The previous eager iteration's start_prefetch may have queued
        # H2D copies on copy_stream that the graph's captured events
        # cannot see. Without this, replay could overwrite static buffers
        # while those copies are still in flight.
        get_offloader().sync_prev_onload()
        self.graphs[desc].replay()
|
|
|
class ModelCudaGraphManager(CudaGraphManager):
    """CudaGraphManager with model-specific capture and hidden state management."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        cudagraph_mode: CUDAGraphMode,
        decode_query_len: int,
    ):
        super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)
        # Static output buffers written inside the captured graphs; lazily
        # allocated on the first forward (see forward_fn in capture()).
        self.hidden_states: torch.Tensor | None = None
        self.aux_hidden_states: list[torch.Tensor] = []
        self.use_aux_hidden_state_outputs = False

    # NOTE(review): overrides CudaGraphManager.capture with a different
    # signature; callers must use this subclass-specific form.
    def capture(
        self,
        model: nn.Module,
        model_state: ModelState,
        input_buffers: InputBuffers,
        block_tables: BlockTables,
        attn_groups: list[list[AttentionGroup]],
        kv_cache_config: KVCacheConfig,
        has_lora: bool = False,
        use_aux_hidden_state_outputs: bool = False,
        progress_bar_desc: str = "Capturing CUDA graphs",
    ) -> None:
        """Capture CUDA graphs for model forward pass.

        Builds a per-descriptor forward closure over dummy inputs and
        delegates the actual capture loop to the base class.
        (`has_lora` is not referenced in this body — presumably reserved;
        TODO confirm against callers.)
        """
        self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs

        def create_forward_fn(
            desc: BatchExecutionDescriptor,
        ) -> Callable[[CUDAGraphMode], None]:
            # Prepare dummy inputs OUTSIDE the graph so capture records only
            # the forward pass itself.
            num_tokens = desc.num_tokens
            # PIECEWISE descriptors carry num_reqs=None; pick a dummy count.
            num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
            num_tokens_across_dp = (
                torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
                if self.dp_size > 1
                else None
            )
            attn_metadata, slot_mappings = prepare_inputs_to_capture(
                num_reqs,
                num_tokens,
                model_state,
                input_buffers,
                block_tables,
                attn_groups,
                kv_cache_config,
            )

            def forward_fn(cg_mode: CUDAGraphMode) -> None:
                # PIECEWISE routes through the compiled model, which needs a
                # BatchDescriptor (and no eager attention metadata).
                batch_descriptor = (
                    BatchDescriptor(num_tokens=num_tokens)
                    if cg_mode == CUDAGraphMode.PIECEWISE
                    else None
                )
                with set_forward_context(
                    attn_metadata if cg_mode != CUDAGraphMode.PIECEWISE else None,
                    self.vllm_config,
                    num_tokens=num_tokens,
                    cudagraph_runtime_mode=cg_mode,
                    num_tokens_across_dp=num_tokens_across_dp,
                    slot_mapping=slot_mappings,
                    batch_descriptor=batch_descriptor,
                ):
                    model_inputs = {
                        "input_ids": input_buffers.input_ids[:num_tokens],
                        "positions": input_buffers.positions[:num_tokens],
                        **model_state.prepare_dummy_inputs(num_reqs, num_tokens),
                    }
                    model_output = model(**model_inputs)
                    if self.use_aux_hidden_state_outputs:
                        hidden_states, aux_hidden_states = model_output
                    else:
                        hidden_states = model_output
                        aux_hidden_states = []
                    # Lazily allocate the persistent output buffers the first
                    # time a forward runs (largest capture size runs first).
                    if self.hidden_states is None:
                        self.hidden_states = torch.empty_like(hidden_states)
                    if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
                        self.aux_hidden_states = [
                            torch.empty_like(x) for x in aux_hidden_states
                        ]
                    # Copy into the static buffers so replayed graphs expose
                    # their outputs at stable addresses.
                    self.hidden_states[:num_tokens] = hidden_states
                    for i, aux in enumerate(aux_hidden_states):
                        self.aux_hidden_states[i][:num_tokens] = aux

            return forward_fn

        super().capture(create_forward_fn, progress_bar_desc)

    def run_fullgraph(
        self, desc: BatchExecutionDescriptor
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Replay a captured FULL cudagraph and return hidden states."""
        super().run_fullgraph(desc)
        assert self.hidden_states is not None
        # Slice the static buffers down to the tokens this replay produced.
        hidden_states = self.hidden_states[: desc.num_tokens]
        if not self.use_aux_hidden_state_outputs:
            return hidden_states
        return hidden_states, [x[: desc.num_tokens] for x in self.aux_hidden_states]
|
|
def prepare_inputs_to_capture(
    num_reqs: int,
    num_tokens: int,
    model_state: ModelState,
    input_buffers: InputBuffers,
    block_tables: BlockTables,
    attn_groups: list[list[AttentionGroup]],
    kv_cache_config: KVCacheConfig,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
    """Build dummy batch inputs and attention metadata for CUDA graph capture.

    Returns:
        A tuple of (attn_metadata, slot_mappings_by_layer): the per-group
        attention metadata produced by ``model_state.prepare_attn`` and the
        per-layer slot-mapping tensors.
    """
    input_batch = InputBatch.make_dummy(num_reqs, num_tokens, input_buffers)
    input_block_tables = block_tables.get_dummy_block_tables(num_reqs)
    slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens)
    slot_mappings_by_layer = build_slot_mappings_by_layer(
        slot_mappings, kv_cache_config
    )

    # HACK(woosuk): Special handling for DCP.
    # Populate per-rank local sequence lengths when decode context
    # parallelism is enabled.
    if block_tables.cp_size > 1:
        prepare_dcp_local_seq_lens(
            input_buffers.dcp_local_seq_lens,
            input_batch.seq_lens,
            num_reqs,
            block_tables.cp_size,
            block_tables.cp_rank,
            block_tables.cp_interleave,
        )
        input_batch.dcp_local_seq_lens = input_buffers.dcp_local_seq_lens[:num_reqs]

    # Metadata is built in eager (NONE) mode; the graph mode is chosen by the
    # caller at forward time.
    attn_metadata = model_state.prepare_attn(
        input_batch,
        CUDAGraphMode.NONE,
        input_block_tables,
        slot_mappings,
        attn_groups,
        kv_cache_config,
    )
    return attn_metadata, slot_mappings_by_layer