[MRV2] Extensible CG dispatch rework (#35959)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
@@ -97,6 +97,9 @@ class CUDAGraphMode(enum.Enum):
|
||||
def __str__(self) -> str:
|
||||
return self.name
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return self != CUDAGraphMode.NONE
|
||||
|
||||
|
||||
@config
|
||||
class PassConfig:
|
||||
|
||||
@@ -104,19 +104,24 @@ class BlockTables:
|
||||
self.num_blocks.copy_to_uva()
|
||||
|
||||
def gather_block_tables(
|
||||
self, idx_mapping: torch.Tensor
|
||||
self,
|
||||
idx_mapping: torch.Tensor,
|
||||
num_reqs_padded: int,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)](
|
||||
# Launch kernel with num_reqs_padded to fuse zeroing of padded rows.
|
||||
_gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs_padded)](
|
||||
idx_mapping,
|
||||
self.block_table_ptrs,
|
||||
self.input_block_table_ptrs,
|
||||
self.block_table_strides,
|
||||
self.num_blocks.gpu,
|
||||
self.num_blocks.gpu.stride(0),
|
||||
num_reqs,
|
||||
self.input_block_tables[0].shape[1], # max_num_blocks
|
||||
BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
|
||||
return tuple(bt[:num_reqs_padded] for bt in self.input_block_tables)
|
||||
|
||||
def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]:
|
||||
# NOTE(woosuk): The output may be used for CUDA graph capture.
|
||||
@@ -130,6 +135,7 @@ class BlockTables:
|
||||
idx_mapping: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
num_tokens_padded: int,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
num_tokens = positions.shape[0]
|
||||
@@ -151,7 +157,7 @@ class BlockTables:
|
||||
PAD_ID=PAD_SLOT_ID,
|
||||
TRITON_BLOCK_SIZE=1024, # type: ignore
|
||||
)
|
||||
return self.slot_mappings[:, :num_tokens]
|
||||
return self.slot_mappings[:, :num_tokens_padded]
|
||||
|
||||
def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor:
|
||||
# Fill the entire slot_mappings tensor, not just the first `num_tokens` entries.
|
||||
@@ -173,21 +179,31 @@ def _gather_block_tables_kernel(
|
||||
block_table_strides, # [num_kv_cache_groups]
|
||||
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
|
||||
num_blocks_stride,
|
||||
num_reqs, # actual number of requests (for padding)
|
||||
max_num_blocks, # stride for zeroing padded rows
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# kv cache group id
|
||||
group_id = tl.program_id(0)
|
||||
batch_idx = tl.program_id(1)
|
||||
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
|
||||
|
||||
stride = tl.load(block_table_strides + group_id)
|
||||
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
|
||||
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
|
||||
|
||||
if batch_idx >= num_reqs:
|
||||
# Zero out padded rows.
|
||||
for i in tl.range(0, max_num_blocks, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
tl.store(dst_row_ptr + offset, 0, mask=offset < max_num_blocks)
|
||||
return
|
||||
|
||||
req_idx = tl.load(batch_idx_to_req_idx + batch_idx)
|
||||
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
|
||||
num_blocks = tl.load(group_num_blocks_ptr + req_idx)
|
||||
|
||||
stride = tl.load(block_table_strides + group_id)
|
||||
src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32)
|
||||
src_row_ptr = src_block_table_ptr + req_idx * stride
|
||||
dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32)
|
||||
dst_row_ptr = dst_block_table_ptr + batch_idx * stride
|
||||
|
||||
for i in tl.range(0, num_blocks, BLOCK_SIZE):
|
||||
offset = i + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -11,235 +13,260 @@ from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BatchExecutionDescriptor:
|
||||
"""Describes the shape of the batch and CG mode to run; this is used to make shape
|
||||
matches between the capture and runtime."""
|
||||
|
||||
cg_mode: CUDAGraphMode
|
||||
num_tokens: int
|
||||
num_reqs: int | None # None means no request padding is needed (PIECEWISE graphs)
|
||||
uniform_token_count: int | None = None
|
||||
|
||||
|
||||
def _is_compatible(
|
||||
desc: BatchExecutionDescriptor,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
uniform_token_count: int | None,
|
||||
) -> bool:
|
||||
# desc.uniform_token_count=None (PIECEWISE) can handle any uniform_token_count
|
||||
# desc.num_reqs=None means no request padding needed (PIECEWISE)
|
||||
return (
|
||||
(
|
||||
desc.uniform_token_count is None
|
||||
or desc.uniform_token_count == uniform_token_count
|
||||
)
|
||||
and (desc.num_reqs is None or desc.num_reqs >= num_reqs)
|
||||
and desc.num_tokens >= num_tokens
|
||||
)
|
||||
|
||||
|
||||
def get_uniform_token_count(
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
max_query_len: int,
|
||||
) -> int | None:
|
||||
"""
|
||||
Return the uniform token count if batch is uniform, else None.
|
||||
A batch is uniform if all requests have the same number of tokens.
|
||||
"""
|
||||
if (max_query_len == num_tokens // num_reqs) and (
|
||||
num_tokens == max_query_len * num_reqs
|
||||
):
|
||||
return max_query_len
|
||||
return None
|
||||
|
||||
|
||||
class CudaGraphManager:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
use_aux_hidden_state_outputs: bool,
|
||||
device: torch.device,
|
||||
cudagraph_mode: CUDAGraphMode,
|
||||
decode_query_len: int,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
|
||||
self.device = device
|
||||
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
|
||||
self.uniform_decode_query_len = 1
|
||||
spec_config = vllm_config.speculative_config
|
||||
if spec_config is not None:
|
||||
self.uniform_decode_query_len += spec_config.num_speculative_tokens
|
||||
|
||||
self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
self.cudagraph_mode = cudagraph_mode
|
||||
self.decode_query_len = decode_query_len
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
|
||||
use_uniform_decode_cudagraph = (
|
||||
self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
|
||||
and self.cudagraph_mode.separate_routine()
|
||||
)
|
||||
self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes(
|
||||
self.compilation_config.cudagraph_capture_sizes,
|
||||
self.max_num_reqs,
|
||||
self.max_num_tokens,
|
||||
self.cudagraph_mode,
|
||||
self.uniform_decode_query_len,
|
||||
use_uniform_decode_cudagraph,
|
||||
)
|
||||
self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = None
|
||||
if self.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
self.hidden_states: torch.Tensor | None = None
|
||||
self.aux_hidden_states: list[torch.Tensor] = []
|
||||
self._graphs_captured = False
|
||||
self._candidates: list[list[BatchExecutionDescriptor]] = []
|
||||
self._capture_descs: dict[CUDAGraphMode, list[BatchExecutionDescriptor]] = {}
|
||||
self._init_candidates()
|
||||
|
||||
def _init_candidates(self) -> None:
|
||||
"""Build priority-ordered candidate lists for each token count."""
|
||||
capture_sizes = self.compilation_config.cudagraph_capture_sizes
|
||||
if not (self.cudagraph_mode and capture_sizes):
|
||||
return
|
||||
|
||||
capture_sizes = sorted(capture_sizes)
|
||||
max_decode_tokens = self.max_num_reqs * self.decode_query_len
|
||||
decode_mode = self.cudagraph_mode.decode_mode()
|
||||
mixed_mode = self.cudagraph_mode.mixed_mode()
|
||||
separate_decode_routine = self.cudagraph_mode.separate_routine()
|
||||
|
||||
descs_by_token_count = defaultdict(list)
|
||||
descs_by_mode = defaultdict(list)
|
||||
|
||||
for num_tokens in capture_sizes:
|
||||
# Capture uniform decode specfifc graphs if required
|
||||
# (i.e. separate decode routine)
|
||||
if (
|
||||
separate_decode_routine
|
||||
and decode_mode
|
||||
and self.decode_query_len <= num_tokens <= max_decode_tokens
|
||||
):
|
||||
desc = BatchExecutionDescriptor(
|
||||
cg_mode=decode_mode,
|
||||
num_tokens=num_tokens,
|
||||
num_reqs=num_tokens // self.decode_query_len,
|
||||
uniform_token_count=self.decode_query_len,
|
||||
)
|
||||
descs_by_mode[decode_mode].append(desc)
|
||||
descs_by_token_count[num_tokens].append(desc)
|
||||
|
||||
if mixed_mode:
|
||||
# for PIECEWISE graphs there is no limit on requests when replaying
|
||||
# i.e. no request padding is needed
|
||||
# so we leave it as None
|
||||
num_reqs = (
|
||||
min(num_tokens, self.max_num_reqs)
|
||||
if mixed_mode == CUDAGraphMode.FULL
|
||||
else None
|
||||
)
|
||||
desc = BatchExecutionDescriptor(
|
||||
cg_mode=mixed_mode,
|
||||
num_tokens=num_tokens,
|
||||
num_reqs=num_reqs,
|
||||
)
|
||||
descs_by_mode[mixed_mode].append(desc)
|
||||
descs_by_token_count[num_tokens].append(desc)
|
||||
|
||||
if not descs_by_token_count:
|
||||
return
|
||||
|
||||
sorted_padded = sorted(descs_by_token_count.keys())
|
||||
self._candidates = [[] for _ in range(sorted_padded[-1] + 1)]
|
||||
|
||||
current_range_start = 0
|
||||
for cg_size in sorted_padded:
|
||||
for i in range(current_range_start, cg_size + 1):
|
||||
self._candidates[i] = descs_by_token_count[cg_size]
|
||||
current_range_start = cg_size + 1
|
||||
|
||||
for mode, descs in descs_by_mode.items():
|
||||
descs.sort(key=lambda d: d.num_tokens, reverse=True)
|
||||
self._capture_descs[mode] = descs
|
||||
|
||||
def needs_capture(self) -> bool:
|
||||
return len(self.cudagraph_sizes) > 0
|
||||
|
||||
def get_cudagraph_size(
|
||||
self, num_tokens: int, uniform_decode: bool = False
|
||||
) -> int | None:
|
||||
if uniform_decode and self.uniform_decode_cudagraph_sizes:
|
||||
return self.uniform_decode_cudagraph_sizes.get(num_tokens)
|
||||
return self.cudagraph_sizes.get(num_tokens)
|
||||
|
||||
def capture_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
capture_cg_mode: CUDAGraphMode,
|
||||
model: nn.Module,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
has_lora: bool = False,
|
||||
uniform_decode: bool = False,
|
||||
) -> None:
|
||||
# select and check capture function
|
||||
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
|
||||
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
|
||||
)
|
||||
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
|
||||
capture_fn = self._capture_piecewise_graph
|
||||
else:
|
||||
capture_fn = self._capture_full_graph
|
||||
# prepare inputs
|
||||
if uniform_decode:
|
||||
num_reqs = min(
|
||||
cdiv(num_tokens, self.uniform_decode_query_len),
|
||||
self.max_num_reqs,
|
||||
)
|
||||
else:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
|
||||
model_inputs = {
|
||||
"input_ids": input_buffers.input_ids[:num_tokens],
|
||||
"positions": input_buffers.positions[:num_tokens],
|
||||
# NOTE: Values returned by `prepare_dummy_inputs` will override the
|
||||
# default values above.
|
||||
**model_state.prepare_dummy_inputs(num_reqs, num_tokens),
|
||||
}
|
||||
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
model_state,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_groups,
|
||||
kv_cache_config,
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
|
||||
# Warm up.
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings,
|
||||
):
|
||||
model_output = model(**model_inputs)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
|
||||
# Allocate output buffers if not already done.
|
||||
if self.hidden_states is None:
|
||||
self.hidden_states = torch.empty_like(hidden_states)
|
||||
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
|
||||
self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states]
|
||||
|
||||
capture_fn(
|
||||
num_tokens=num_tokens,
|
||||
num_reqs=num_reqs,
|
||||
model=model,
|
||||
model_inputs=model_inputs,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings,
|
||||
has_lora=has_lora,
|
||||
)
|
||||
|
||||
def _capture_full_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
model: nn.Module,
|
||||
model_inputs: dict[str, torch.Tensor | None],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
has_lora: bool = False,
|
||||
) -> None:
|
||||
assert attn_metadata is not None
|
||||
# Capture the graph.
|
||||
assert num_tokens not in self.graphs
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
# Sync offloader's copy stream before capture.
|
||||
# Ensure any pre-capture prefetches from offloader are complete.
|
||||
get_offloader().sync_prev_onload()
|
||||
|
||||
with (
|
||||
set_forward_context(
|
||||
attn_metadata=attn_metadata,
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings,
|
||||
),
|
||||
torch.cuda.graph(graph, self.pool),
|
||||
):
|
||||
model_output = model(**model_inputs)
|
||||
|
||||
# Join offloader's copy stream after forward to avoid unjoined
|
||||
# stream error. The last layer's start_prefetch forks copy_stream,
|
||||
# but wait_prefetch only happens in the next forward pass.
|
||||
get_offloader().join_after_forward()
|
||||
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
|
||||
# Copy outputs to the output buffers.
|
||||
assert self.hidden_states is not None
|
||||
self.hidden_states[:num_tokens] = hidden_states
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
for i, aux_hidden in enumerate(aux_hidden_states):
|
||||
self.aux_hidden_states[i][:num_tokens] = aux_hidden
|
||||
self.graphs[num_tokens] = graph
|
||||
|
||||
def _capture_piecewise_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
model: nn.Module,
|
||||
model_inputs: dict[str, torch.Tensor | None],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
has_lora: bool = False,
|
||||
) -> None:
|
||||
# create batch descriptor for piecewise cudagraph dispatch key
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)
|
||||
|
||||
# Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
|
||||
with set_forward_context(
|
||||
attn_metadata=None, # piecewise no need attn_metadata
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
batch_descriptor=batch_descriptor,
|
||||
slot_mapping=slot_mappings,
|
||||
):
|
||||
model(**model_inputs)
|
||||
return len(self._capture_descs) > 0
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
create_forward_fn: Callable[
|
||||
[BatchExecutionDescriptor], Callable[[CUDAGraphMode], None]
|
||||
],
|
||||
progress_bar_desc: str = "Capturing CUDA graphs",
|
||||
) -> None:
|
||||
"""Capture CUDA graphs.
|
||||
|
||||
Args:
|
||||
create_forward_fn: Factory that prepares inputs (OUTSIDE graph) and
|
||||
returns a function that runs forward with a given CUDAGraphMode.
|
||||
"""
|
||||
with graph_capture(device=self.device):
|
||||
# Capture in order: PIECEWISE first, then FULL. PIECEWISE has larger
|
||||
# activations so FULL activations should fit in already allocated
|
||||
# buffers in the graph pool.
|
||||
for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
|
||||
if mode not in self._capture_descs:
|
||||
continue
|
||||
|
||||
descs = self._capture_descs[mode]
|
||||
if is_global_first_rank():
|
||||
descs = tqdm(descs, desc=f"{progress_bar_desc} ({mode.name})")
|
||||
for desc in descs:
|
||||
# Prepare inputs and get forward function
|
||||
forward_fn = create_forward_fn(desc)
|
||||
|
||||
# Warmup
|
||||
forward_fn(CUDAGraphMode.NONE)
|
||||
|
||||
# Capture
|
||||
logger.debug(
|
||||
"CG Capture: mode=%s, batch_desc=%s", desc.cg_mode.name, desc
|
||||
)
|
||||
if desc.cg_mode == CUDAGraphMode.PIECEWISE:
|
||||
forward_fn(CUDAGraphMode.PIECEWISE)
|
||||
else:
|
||||
assert desc not in self.graphs, (
|
||||
f"Graph already captured for {desc}"
|
||||
)
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
# Sync offloader's copy stream before capture.
|
||||
# Ensure any pre-capture prefetches from offloader are complete.
|
||||
get_offloader().sync_prev_onload()
|
||||
with torch.cuda.graph(graph, self.pool):
|
||||
forward_fn(CUDAGraphMode.NONE)
|
||||
# Join offloader's copy stream after forward to avoid
|
||||
# unjoined stream error. The last layer's start_prefetch
|
||||
# forks copy_stream, but wait_prefetch only happens in
|
||||
# the next forward pass.
|
||||
get_offloader().join_after_forward()
|
||||
self.graphs[desc] = graph
|
||||
self._graphs_captured = True
|
||||
|
||||
def dispatch(
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
uniform_token_count: int | None,
|
||||
) -> BatchExecutionDescriptor:
|
||||
"""Find matching cudagraph descriptor from priority-ordered candidates."""
|
||||
if self._graphs_captured and 0 < num_tokens < len(self._candidates):
|
||||
for desc in self._candidates[num_tokens]:
|
||||
if _is_compatible(desc, num_reqs, num_tokens, uniform_token_count):
|
||||
return desc
|
||||
return BatchExecutionDescriptor(
|
||||
cg_mode=CUDAGraphMode.NONE, num_tokens=num_tokens, num_reqs=num_reqs
|
||||
)
|
||||
|
||||
def run_fullgraph(self, desc: BatchExecutionDescriptor):
|
||||
"""Replay a captured FULL cudagraph."""
|
||||
assert desc.cg_mode == CUDAGraphMode.FULL, (
|
||||
f"Expected FULL mode, got {desc.cg_mode}"
|
||||
)
|
||||
assert desc in self.graphs, f"No cudagraph for {desc}"
|
||||
# Sync offloader before replay - needed when transitioning from
|
||||
# eager/piecewise to full cudagraph (e.g., prefill → decode).
|
||||
# The previous eager iteration's start_prefetch may have queued
|
||||
# H2D copies on copy_stream that the graph's captured events
|
||||
# cannot see. Without this, replay could overwrite static buffers
|
||||
# while those copies are still in flight.
|
||||
get_offloader().sync_prev_onload()
|
||||
self.graphs[desc].replay()
|
||||
|
||||
|
||||
class ModelCudaGraphManager(CudaGraphManager):
|
||||
"""CudaGraphManager with model-specific capture and hidden state management."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
cudagraph_mode: CUDAGraphMode,
|
||||
decode_query_len: int,
|
||||
):
|
||||
super().__init__(vllm_config, device, cudagraph_mode, decode_query_len)
|
||||
self.hidden_states: torch.Tensor | None = None
|
||||
self.aux_hidden_states: list[torch.Tensor] = []
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
|
||||
def capture(
|
||||
self,
|
||||
model: nn.Module,
|
||||
@@ -249,139 +276,81 @@ class CudaGraphManager:
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
has_lora: bool = False,
|
||||
use_aux_hidden_state_outputs: bool = False,
|
||||
progress_bar_desc: str = "Capturing CUDA graphs",
|
||||
) -> None:
|
||||
common_kwargs = dict(
|
||||
device=self.device,
|
||||
capture_fn=self.capture_graph,
|
||||
model=model,
|
||||
model_state=model_state,
|
||||
input_buffers=input_buffers,
|
||||
block_tables=block_tables,
|
||||
attn_groups=attn_groups,
|
||||
kv_cache_config=kv_cache_config,
|
||||
has_lora=has_lora,
|
||||
)
|
||||
"""Capture CUDA graphs for model forward pass."""
|
||||
self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs
|
||||
|
||||
# Phase 1: Capture for mixed prefill-decode batches if needed.
|
||||
mixed_mode = self.cudagraph_mode.mixed_mode()
|
||||
if mixed_mode != CUDAGraphMode.NONE:
|
||||
capture_graphs(
|
||||
cudagraph_sizes=self.cudagraph_sizes,
|
||||
capture_cudagraph_mode=mixed_mode,
|
||||
desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
|
||||
uniform_decode=False,
|
||||
**common_kwargs,
|
||||
def create_forward_fn(
|
||||
desc: BatchExecutionDescriptor,
|
||||
) -> Callable[[CUDAGraphMode], None]:
|
||||
num_tokens = desc.num_tokens
|
||||
num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
|
||||
num_tokens_across_dp = (
|
||||
torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
|
||||
if self.dp_size > 1
|
||||
else None
|
||||
)
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
model_state,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_groups,
|
||||
kv_cache_config,
|
||||
)
|
||||
|
||||
# Phase 2: Capture FULL graphs for uniform decode batches if needed.
|
||||
# This is only needed if we use a separate routine for decode batches
|
||||
# and the decode_mode is FULL.
|
||||
if self.uniform_decode_cudagraph_sizes:
|
||||
capture_graphs(
|
||||
cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
|
||||
capture_cudagraph_mode=CUDAGraphMode.FULL,
|
||||
desc="Capturing CUDA graphs (decode, FULL)",
|
||||
uniform_decode=True,
|
||||
**common_kwargs,
|
||||
)
|
||||
def forward_fn(cg_mode: CUDAGraphMode) -> None:
|
||||
batch_descriptor = (
|
||||
BatchDescriptor(num_tokens=num_tokens)
|
||||
if cg_mode == CUDAGraphMode.PIECEWISE
|
||||
else None
|
||||
)
|
||||
with set_forward_context(
|
||||
attn_metadata if cg_mode != CUDAGraphMode.PIECEWISE else None,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=cg_mode,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings,
|
||||
batch_descriptor=batch_descriptor,
|
||||
):
|
||||
model_inputs = {
|
||||
"input_ids": input_buffers.input_ids[:num_tokens],
|
||||
"positions": input_buffers.positions[:num_tokens],
|
||||
}
|
||||
model_output = model(**model_inputs)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = []
|
||||
if self.hidden_states is None:
|
||||
self.hidden_states = torch.empty_like(hidden_states)
|
||||
if self.use_aux_hidden_state_outputs and not self.aux_hidden_states:
|
||||
self.aux_hidden_states = [
|
||||
torch.empty_like(x) for x in aux_hidden_states
|
||||
]
|
||||
self.hidden_states[:num_tokens] = hidden_states
|
||||
for i, aux in enumerate(aux_hidden_states):
|
||||
self.aux_hidden_states[i][:num_tokens] = aux
|
||||
|
||||
def get_cudagraph_runtime_mode(
|
||||
self, num_reqs: int, num_tokens: int, max_query_len: int
|
||||
) -> tuple[CUDAGraphMode, int | None]:
|
||||
is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
|
||||
num_tokens == max_query_len * num_reqs
|
||||
)
|
||||
return forward_fn
|
||||
|
||||
cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
|
||||
if cudagraph_size is None:
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
elif is_uniform_decode:
|
||||
cudagraph_mode = self.cudagraph_mode.decode_mode()
|
||||
else:
|
||||
cudagraph_mode = self.cudagraph_mode.mixed_mode()
|
||||
|
||||
if (
|
||||
cudagraph_mode == CUDAGraphMode.FULL
|
||||
and cudagraph_size is not None
|
||||
and cudagraph_size not in self.graphs
|
||||
):
|
||||
# If graph wasn't captured yet, fall back to eager.
|
||||
# This might happen when the dummy run is called before capture.
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
cudagraph_size = None
|
||||
return cudagraph_mode, cudagraph_size
|
||||
super().capture(create_forward_fn, progress_bar_desc)
|
||||
|
||||
def run_fullgraph(
|
||||
self, num_tokens: int
|
||||
self, desc: BatchExecutionDescriptor
|
||||
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
|
||||
# Sync offloader before replay - needed when transitioning from
|
||||
# eager/piecewise to full cudagraph (e.g., prefill → decode).
|
||||
# The previous eager iteration's start_prefetch may have queued
|
||||
# H2D copies on copy_stream that the graph's captured events
|
||||
# cannot see. Without this, replay could overwrite static buffers
|
||||
# while those copies are still in flight.
|
||||
get_offloader().sync_prev_onload()
|
||||
self.graphs[num_tokens].replay()
|
||||
"""Replay a captured FULL cudagraph and return hidden states."""
|
||||
super().run_fullgraph(desc)
|
||||
assert self.hidden_states is not None
|
||||
hidden_states = self.hidden_states[:num_tokens]
|
||||
hidden_states = self.hidden_states[: desc.num_tokens]
|
||||
if not self.use_aux_hidden_state_outputs:
|
||||
return hidden_states
|
||||
return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states]
|
||||
|
||||
|
||||
def get_cudagraph_sizes(
|
||||
capture_sizes: list[int] | None,
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
cudagraph_mode: CUDAGraphMode,
|
||||
uniform_decode_query_len: int = 1,
|
||||
uniform_decode_cudagraph: bool = False,
|
||||
) -> tuple[dict[int, int], dict[int, int]]:
|
||||
# Support both FULL and PIECEWISE cudagraph modes
|
||||
if cudagraph_mode == CUDAGraphMode.NONE:
|
||||
return {}, {}
|
||||
if not capture_sizes:
|
||||
return {}, {}
|
||||
|
||||
capture_sizes = sorted(capture_sizes)
|
||||
if not capture_sizes:
|
||||
return {}, {}
|
||||
|
||||
cudagraph_sizes: dict[int, int] = {}
|
||||
for i in range(1, capture_sizes[-1] + 1):
|
||||
for x in capture_sizes:
|
||||
if i <= x:
|
||||
cudagraph_sizes[i] = x
|
||||
break
|
||||
|
||||
uniform_decode_cudagraph_sizes: dict[int, int] = {}
|
||||
if uniform_decode_cudagraph:
|
||||
max_num_tokens = max_num_reqs * uniform_decode_query_len
|
||||
uniform_decode_cudagraph_sizes = {
|
||||
k: v
|
||||
for k, v in cudagraph_sizes.items()
|
||||
if v <= max_num_tokens and v >= uniform_decode_query_len
|
||||
}
|
||||
return cudagraph_sizes, uniform_decode_cudagraph_sizes
|
||||
|
||||
|
||||
def capture_graphs(
|
||||
cudagraph_sizes: dict[int, int],
|
||||
device: torch.device,
|
||||
capture_fn: Callable,
|
||||
capture_cudagraph_mode: CUDAGraphMode,
|
||||
desc: str = "Capturing CUDA graphs",
|
||||
**capture_kwargs,
|
||||
) -> None:
|
||||
# Capture larger graphs first.
|
||||
sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
|
||||
if is_global_first_rank():
|
||||
sizes_to_capture = tqdm(sizes_to_capture, desc=desc)
|
||||
|
||||
with graph_capture(device=device):
|
||||
for size in sizes_to_capture:
|
||||
capture_fn(size, capture_cudagraph_mode, **capture_kwargs)
|
||||
return hidden_states, [x[: desc.num_tokens] for x in self.aux_hidden_states]
|
||||
|
||||
|
||||
def prepare_inputs_to_capture(
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import (
|
||||
BatchExecutionDescriptor,
|
||||
CudaGraphManager,
|
||||
)
|
||||
|
||||
|
||||
def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None:
|
||||
@@ -12,66 +19,63 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
|
||||
return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
|
||||
|
||||
|
||||
def get_batch_metadata_across_dp(
|
||||
def sync_cudagraph_and_dp_padding(
|
||||
cudagraph_manager: CudaGraphManager,
|
||||
desired_batch_desc: BatchExecutionDescriptor,
|
||||
num_tokens: int,
|
||||
cudagraph_size: int,
|
||||
cudagraph_runtime_mode: int,
|
||||
num_reqs: int,
|
||||
uniform_token_count: int | None,
|
||||
dp_size: int,
|
||||
dp_rank: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert dp_size > 1
|
||||
# Use CPU group to avoid CPU-GPU synchronization.
|
||||
) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]:
|
||||
"""
|
||||
Coordinates the batch descriptor and DP padding across all ranks.
|
||||
|
||||
Returns (synced_batch_desc, num_tokens_across_dp).
|
||||
"""
|
||||
assert dp_size > 1, "DP size must be greater than 1"
|
||||
group = get_dp_group().cpu_group
|
||||
tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
|
||||
tensor[0][dp_rank] = num_tokens
|
||||
tensor[1][dp_rank] = cudagraph_size
|
||||
tensor[2][dp_rank] = cudagraph_runtime_mode
|
||||
tensor[1][dp_rank] = desired_batch_desc.cg_mode.value
|
||||
tensor[2][dp_rank] = uniform_token_count or 0 # (0 means None)
|
||||
dist.all_reduce(tensor, group=group)
|
||||
return tensor[0], tensor[1], tensor[2]
|
||||
|
||||
num_tokens_across_dp = tensor[0]
|
||||
cg_mode_across_dp = tensor[1]
|
||||
uniform_token_counts_across_dp = tensor[2]
|
||||
|
||||
def get_cudagraph_and_dp_padding(
|
||||
num_tokens: int,
|
||||
cudagraph_size: int | None,
|
||||
cudagraph_runtime_mode: int,
|
||||
dp_size: int,
|
||||
dp_rank: int,
|
||||
) -> tuple[int, torch.Tensor | None, int]:
|
||||
if dp_size == 1:
|
||||
if cudagraph_size is not None:
|
||||
return cudagraph_size, None, cudagraph_runtime_mode
|
||||
else:
|
||||
return num_tokens, None, cudagraph_runtime_mode
|
||||
|
||||
# Convert None to -1 for sync (indicates no cudagraph available)
|
||||
if num_tokens == 0:
|
||||
cudagraph_size = 0
|
||||
elif cudagraph_size is None:
|
||||
cudagraph_size = -1
|
||||
|
||||
num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = (
|
||||
get_batch_metadata_across_dp(
|
||||
num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank
|
||||
)
|
||||
)
|
||||
if torch.all(num_tokens_across_dp == 0).item():
|
||||
# All ranks have zero tokens to run.
|
||||
return 0, None, 0
|
||||
synced_desc = BatchExecutionDescriptor(
|
||||
cg_mode=CUDAGraphMode.NONE, num_tokens=0, num_reqs=0
|
||||
)
|
||||
return synced_desc, None
|
||||
|
||||
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
|
||||
synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item())
|
||||
# Check if all ranks have valid cudagraph_size.
|
||||
all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
|
||||
synced_cg_mode = CUDAGraphMode(int(cg_mode_across_dp.min().item()))
|
||||
|
||||
if synced_cudagraph_mode != 0 and all_have_cudagraph:
|
||||
# All ranks use cudagraph. Pad to max cudagraph_size.
|
||||
max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
|
||||
num_tokens_across_dp[:] = max_cudagraph_size
|
||||
return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
|
||||
else:
|
||||
# Fall back to eager mode (no cudagraph).
|
||||
# Either some rank doesn't have cudagraph size or mode is NONE.
|
||||
synced_cudagraph_mode = 0
|
||||
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
|
||||
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
|
||||
return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode
|
||||
# If any rank wants to run eager, all ranks run eager
|
||||
if synced_cg_mode == CUDAGraphMode.NONE:
|
||||
return BatchExecutionDescriptor(
|
||||
cg_mode=CUDAGraphMode.NONE,
|
||||
num_tokens=num_tokens,
|
||||
num_reqs=num_reqs,
|
||||
), num_tokens_across_dp
|
||||
|
||||
synced_num_tokens = int(num_tokens_across_dp.max().item())
|
||||
synced_uniform_token_count = uniform_token_counts_across_dp[0]
|
||||
# If ranks disagree on the uniform token count, or its 0 (means None) set to None
|
||||
if synced_uniform_token_count == 0 or not torch.all(
|
||||
uniform_token_counts_across_dp == synced_uniform_token_count
|
||||
):
|
||||
synced_uniform_token_count = None
|
||||
|
||||
# Dispatch for the final synced values, use num_reqs instead of synced_num_reqs
|
||||
# so we don't perform request padding for PIECEWISE graphs
|
||||
synced_desc = cudagraph_manager.dispatch(
|
||||
num_reqs, synced_num_tokens, synced_uniform_token_count
|
||||
)
|
||||
|
||||
# Update num_tokens_across_dp to reflect padded size.
|
||||
num_tokens_across_dp[:] = synced_desc.num_tokens
|
||||
|
||||
return synced_desc, num_tokens_across_dp
|
||||
|
||||
@@ -37,6 +37,7 @@ class InputBatch:
|
||||
# batch_idx -> req_id
|
||||
req_ids: list[str]
|
||||
num_reqs: int
|
||||
num_reqs_after_padding: int
|
||||
|
||||
# batch_idx -> req_state_idx
|
||||
idx_mapping: torch.Tensor
|
||||
@@ -123,6 +124,7 @@ class InputBatch:
|
||||
return cls(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
num_reqs_after_padding=num_reqs,
|
||||
idx_mapping=idx_mapping,
|
||||
idx_mapping_np=idx_mapping_np,
|
||||
expanded_idx_mapping=expanded_idx_mapping,
|
||||
@@ -330,7 +332,8 @@ def combine_sampled_and_draft_tokens(
|
||||
cu_num_logits: torch.Tensor,
|
||||
num_logits: int,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = seq_lens.shape[0]
|
||||
# use idx_mapping.shape[0] for actual request count
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
num_speculative_steps = draft_tokens.shape[-1]
|
||||
|
||||
logits_indices = torch.empty(
|
||||
|
||||
@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
@@ -57,8 +56,12 @@ from vllm.v1.worker.gpu.attn_utils import (
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu
|
||||
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
|
||||
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import (
|
||||
BatchExecutionDescriptor,
|
||||
ModelCudaGraphManager,
|
||||
get_uniform_token_count,
|
||||
)
|
||||
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
|
||||
from vllm.v1.worker.gpu.input_batch import (
|
||||
InputBatch,
|
||||
InputBuffers,
|
||||
@@ -137,6 +140,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.is_first_pp_rank = True
|
||||
self.is_last_pp_rank = True
|
||||
|
||||
# Data parallelism.
|
||||
self.dp_size = self.parallel_config.data_parallel_size
|
||||
self.dp_rank = self.parallel_config.data_parallel_rank
|
||||
|
||||
# Decode context parallelism.
|
||||
self.dcp_size = self.parallel_config.decode_context_parallel_size
|
||||
self.use_dcp = self.dcp_size > 1
|
||||
@@ -193,10 +200,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
|
||||
|
||||
# CUDA graphs.
|
||||
self.cudagraph_manager = CudaGraphManager(
|
||||
self.decode_query_len = self.num_speculative_steps + 1
|
||||
self.cudagraph_manager = ModelCudaGraphManager(
|
||||
self.vllm_config,
|
||||
self.use_aux_hidden_state_outputs,
|
||||
self.device,
|
||||
self.compilation_config.cudagraph_mode,
|
||||
decode_query_len=self.decode_query_len,
|
||||
)
|
||||
# Structured outputs worker.
|
||||
self.structured_outputs_worker = StructuredOutputsWorker(
|
||||
@@ -331,17 +340,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
# Create a dummy scheduler output.
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
if uniform_decode:
|
||||
# Align tokens to uniform_decode_query_len for cudagraph
|
||||
# compatibility across DP ranks.
|
||||
query_len = self.cudagraph_manager.uniform_decode_query_len
|
||||
num_reqs = min(cdiv(num_tokens, query_len), self.max_num_reqs)
|
||||
num_tokens = num_reqs * query_len
|
||||
num_tokens_per_request = [query_len] * num_reqs
|
||||
else:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
|
||||
num_tokens_per_request[-1] += num_tokens % num_reqs
|
||||
# HACK(lucas): for now since the worker is shared between MRV1 and MRV2,
|
||||
# and for spec-decode with MTP we want to make sure the dummy runs use
|
||||
# 1+num_speculative_tokens we use max here, this will likely be eventually
|
||||
# changed in the worker: https://github.com/vllm-project/vllm/pull/35243
|
||||
num_tokens = max(num_tokens, self.decode_query_len)
|
||||
num_reqs = num_tokens // self.decode_query_len
|
||||
assert num_tokens % self.decode_query_len == 0
|
||||
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
|
||||
num_tokens_per_request[-1] += num_tokens % num_reqs
|
||||
|
||||
assert sum(num_tokens_per_request) == num_tokens
|
||||
num_scheduled_tokens = {
|
||||
f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
|
||||
@@ -498,13 +508,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
with self.maybe_setup_dummy_loras(self.lora_config):
|
||||
self.cudagraph_manager.capture(
|
||||
model=self.model,
|
||||
model_state=self.model_state,
|
||||
input_buffers=self.input_buffers,
|
||||
block_tables=self.block_tables,
|
||||
attn_groups=self.attn_groups,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
self.model,
|
||||
self.model_state,
|
||||
self.input_buffers,
|
||||
self.block_tables,
|
||||
self.attn_groups,
|
||||
self.kv_cache_config,
|
||||
has_lora=self.lora_config is not None,
|
||||
use_aux_hidden_state_outputs=self.use_aux_hidden_state_outputs,
|
||||
)
|
||||
if self.speculator is not None:
|
||||
self.speculator.capture_model()
|
||||
@@ -592,9 +603,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
|
||||
def prepare_inputs(
|
||||
self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int
|
||||
self, scheduler_output: SchedulerOutput, batch_desc: BatchExecutionDescriptor
|
||||
) -> InputBatch:
|
||||
num_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
num_tokens_after_padding = batch_desc.num_tokens
|
||||
assert num_tokens > 0
|
||||
num_tokens_per_req = scheduler_output.num_scheduled_tokens
|
||||
num_reqs = len(num_tokens_per_req)
|
||||
@@ -644,6 +656,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
|
||||
# Get query_start_loc.
|
||||
# num_reqs_padded is None for PIECEWISE graphs (no request padding needed)
|
||||
num_reqs_padded = batch_desc.num_reqs or num_reqs
|
||||
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
|
||||
query_start_loc_np[0] = 0
|
||||
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
|
||||
@@ -651,8 +665,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
|
||||
query_start_loc_np[num_reqs + 1 :] = num_tokens
|
||||
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
|
||||
query_start_loc_np = query_start_loc_np[: num_reqs + 1]
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1]
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs_padded + 1]
|
||||
|
||||
# Get prefill tokens if any.
|
||||
if self.req_states.any_prefills(idx_mapping_np):
|
||||
@@ -674,7 +688,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.input_buffers.positions,
|
||||
self.input_buffers.seq_lens,
|
||||
)
|
||||
seq_lens = self.input_buffers.seq_lens[:num_reqs]
|
||||
seq_lens = self.input_buffers.seq_lens[:num_reqs_padded]
|
||||
|
||||
dcp_local_seq_lens = None
|
||||
if self.use_dcp:
|
||||
@@ -687,7 +701,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.dcp_rank,
|
||||
self.cp_interleave,
|
||||
)
|
||||
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
|
||||
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs_padded]
|
||||
|
||||
# Some input token ids are directly read from the last sampled tokens
|
||||
# and draft tokens. Also, get the logits indices to sample tokens from.
|
||||
@@ -706,6 +720,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return InputBatch(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
num_reqs_after_padding=num_reqs_padded,
|
||||
idx_mapping=idx_mapping,
|
||||
idx_mapping_np=idx_mapping_np,
|
||||
expanded_idx_mapping=expanded_idx_mapping,
|
||||
@@ -729,13 +744,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
def prepare_attn(
|
||||
self, input_batch: InputBatch
|
||||
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
|
||||
block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping)
|
||||
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
# Block tables: num_kv_cache_groups x [num_reqs_padded, max_num_blocks].
|
||||
block_tables = self.block_tables.gather_block_tables(
|
||||
input_batch.idx_mapping,
|
||||
num_reqs_padded=input_batch.num_reqs_after_padding,
|
||||
)
|
||||
# Slot mappings: [num_kv_cache_groups, num_tokens_padded].
|
||||
# Kernel pads beyond num_tokens with PAD_SLOT_ID.
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
input_batch.idx_mapping,
|
||||
input_batch.query_start_loc,
|
||||
input_batch.positions,
|
||||
num_tokens_padded=input_batch.num_tokens_after_padding,
|
||||
)
|
||||
return block_tables, slot_mappings
|
||||
|
||||
@@ -851,27 +871,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
empty_output = self.kv_connector.no_forward(scheduler_output)
|
||||
return empty_output
|
||||
|
||||
# Get local cudagraph mode and size.
|
||||
local_cudagraph_mode, local_cudagraph_size = (
|
||||
self.cudagraph_manager.get_cudagraph_runtime_mode(
|
||||
num_reqs=len(scheduler_output.num_scheduled_tokens),
|
||||
num_tokens=scheduler_output.total_num_scheduled_tokens,
|
||||
max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
|
||||
)
|
||||
)
|
||||
# Get batch descriptor and sync across DP ranks.
|
||||
num_reqs = len(scheduler_output.num_scheduled_tokens)
|
||||
num_toks = scheduler_output.total_num_scheduled_tokens
|
||||
max_query_len = max(scheduler_output.num_scheduled_tokens.values())
|
||||
uniform_tok_count = get_uniform_token_count(num_reqs, num_toks, max_query_len)
|
||||
|
||||
# DP sync: num_tokens + cudagraph_size + cudagraph_mode
|
||||
num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
|
||||
get_cudagraph_and_dp_padding(
|
||||
scheduler_output.total_num_scheduled_tokens,
|
||||
local_cudagraph_size,
|
||||
local_cudagraph_mode.value,
|
||||
self.parallel_config.data_parallel_size,
|
||||
self.parallel_config.data_parallel_rank,
|
||||
)
|
||||
batch_desc = self.cudagraph_manager.dispatch(
|
||||
num_reqs, num_toks, uniform_tok_count
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
|
||||
if num_tokens_after_padding == 0:
|
||||
num_tokens_across_dp = None
|
||||
|
||||
if self.dp_size > 1:
|
||||
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
|
||||
self.cudagraph_manager,
|
||||
batch_desc,
|
||||
num_toks,
|
||||
num_reqs,
|
||||
uniform_tok_count,
|
||||
self.dp_size,
|
||||
self.dp_rank,
|
||||
)
|
||||
|
||||
if batch_desc.num_tokens == 0:
|
||||
# All DP ranks have zero tokens to run.
|
||||
empty_output = self.kv_connector.no_forward(scheduler_output)
|
||||
return empty_output
|
||||
@@ -879,9 +901,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if not dummy_run:
|
||||
# Common case.
|
||||
# Prepare all the inputs and copy to the input buffers.
|
||||
input_batch = self.prepare_inputs(
|
||||
scheduler_output, num_tokens_after_padding
|
||||
)
|
||||
input_batch = self.prepare_inputs(scheduler_output, batch_desc)
|
||||
block_tables, slot_mappings = self.prepare_attn(input_batch)
|
||||
|
||||
if self.lora_config:
|
||||
@@ -894,9 +914,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self._set_active_loras(*lora_inputs)
|
||||
else:
|
||||
# No actual tokens to run. A dummy run for DP or memory profiling.
|
||||
num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs, num_tokens_after_padding, self.input_buffers
|
||||
batch_desc.num_reqs or num_reqs,
|
||||
batch_desc.num_tokens,
|
||||
self.input_buffers,
|
||||
)
|
||||
if not skip_attn_for_dummy_run:
|
||||
block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
|
||||
@@ -948,14 +969,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
model_inputs["intermediate_tensors"] = intermediate_tensors
|
||||
|
||||
# Run model.
|
||||
if cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
if batch_desc.cg_mode == CUDAGraphMode.FULL:
|
||||
# Use explicit cudagraph replay for FULL mode.
|
||||
# NOTE(woosuk): Here, we don't need to pass the input tensors,
|
||||
# because they are already copied to the CUDA graph input buffers.
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
model_output = self.cudagraph_manager.run_fullgraph(
|
||||
input_batch.num_tokens_after_padding
|
||||
)
|
||||
model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
@@ -972,7 +991,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch.num_tokens_after_padding,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
cudagraph_runtime_mode=batch_desc.cg_mode,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
batch_descriptor=batch_descriptor,
|
||||
slot_mapping=slot_mappings_by_layer,
|
||||
|
||||
@@ -142,12 +142,15 @@ class DefaultModelState(ModelState):
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> dict[str, Any]:
|
||||
# Use padded sizes - padding is handled by model_runner.prepare_attn.
|
||||
num_reqs = input_batch.num_reqs_after_padding
|
||||
num_tokens = input_batch.num_tokens_after_padding
|
||||
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
|
||||
max_query_len = input_batch.num_scheduled_tokens.max().item()
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_groups=attn_groups,
|
||||
num_reqs=input_batch.num_reqs,
|
||||
num_tokens=input_batch.num_tokens,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
query_start_loc_gpu=input_batch.query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=max_query_len,
|
||||
|
||||
@@ -1,181 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.model_executor.offloader.base import get_offloader
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import (
|
||||
capture_graphs,
|
||||
get_cudagraph_sizes,
|
||||
BatchExecutionDescriptor,
|
||||
CudaGraphManager,
|
||||
prepare_inputs_to_capture,
|
||||
)
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
||||
from vllm.v1.worker.utils import AttentionGroup
|
||||
|
||||
|
||||
class EagleCudaGraphManager:
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device = device
|
||||
class EagleCudaGraphManager(CudaGraphManager):
|
||||
"""CudaGraphManager for Eagle speculative decoding (FULL mode only)."""
|
||||
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
|
||||
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode()
|
||||
|
||||
# only need to capture uniform decode cudagraph sizes (the 2nd return value)
|
||||
_, self.cudagraph_sizes = get_cudagraph_sizes(
|
||||
self.compilation_config.cudagraph_capture_sizes,
|
||||
self.max_num_reqs,
|
||||
self.max_num_tokens,
|
||||
self.cudagraph_mode,
|
||||
uniform_decode_query_len=1,
|
||||
uniform_decode_cudagraph=True,
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
cudagraph_mode: CUDAGraphMode,
|
||||
draft_tokens: torch.Tensor,
|
||||
):
|
||||
assert not cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE), (
|
||||
"EagleCudaGraphManager does not support PIECEWISE mode yet"
|
||||
)
|
||||
# Eagle always uses uniform decode with query_len=1
|
||||
super().__init__(vllm_config, device, cudagraph_mode, decode_query_len=1)
|
||||
self.draft_tokens = draft_tokens
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = None
|
||||
if self.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
# Use a dedicated pool for Eagle to avoid memory overlap with the main
|
||||
# model's cudagraph. The base class uses a shared global pool, but Eagle's
|
||||
# internal allocations (e.g., gumbel_sample temporaries) can conflict with
|
||||
# the main model's allocations when sharing the same pool.
|
||||
if cudagraph_mode:
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
|
||||
def get_cudagraph_size(self, num_tokens: int) -> int | None:
|
||||
return self.cudagraph_sizes.get(num_tokens)
|
||||
|
||||
def get_cudagraph_runtime_mode(
|
||||
self, num_tokens: int
|
||||
) -> tuple[CUDAGraphMode, int | None]:
|
||||
cudagraph_size = self.get_cudagraph_size(num_tokens)
|
||||
if cudagraph_size is None:
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
else:
|
||||
cudagraph_mode = self.cudagraph_mode
|
||||
|
||||
if (
|
||||
cudagraph_mode == CUDAGraphMode.FULL
|
||||
and cudagraph_size is not None
|
||||
and cudagraph_size not in self.graphs
|
||||
):
|
||||
# If graph wasn't captured yet, fall back to eager.
|
||||
# This might happen when the dummy run is called before capture.
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
cudagraph_size = None
|
||||
return cudagraph_mode, cudagraph_size
|
||||
|
||||
def capture_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
capture_cg_mode: CUDAGraphMode,
|
||||
generate_fn: Callable,
|
||||
model_state: ModelState,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
|
||||
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
|
||||
)
|
||||
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
|
||||
capture_fn = self._capture_piecewise_graph
|
||||
else:
|
||||
capture_fn = self._capture_full_graph
|
||||
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
model_state,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_groups,
|
||||
kv_cache_config,
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
|
||||
# Warm up.
|
||||
generate_fn(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp,
|
||||
CUDAGraphMode.NONE,
|
||||
)
|
||||
|
||||
# Capture the graph.
|
||||
capture_fn(
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
generate_fn=generate_fn,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
|
||||
def _capture_full_graph(
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
generate_fn: Callable,
|
||||
attn_metadata: dict[str, Any],
|
||||
slot_mappings: dict[str, torch.Tensor],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
) -> None:
|
||||
assert num_tokens not in self.graphs
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
|
||||
# Sync offloader's copy stream before capture.
|
||||
# Ensure any pre-capture prefetches from offloader are complete.
|
||||
get_offloader().sync_prev_onload()
|
||||
|
||||
with torch.cuda.graph(graph, self.pool):
|
||||
generate_fn(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp,
|
||||
CUDAGraphMode.NONE,
|
||||
)
|
||||
# Join offloader's copy stream after forward to avoid unjoined
|
||||
# stream error. The last layer's start_prefetch forks copy_stream,
|
||||
# but wait_prefetch only happens in the next forward pass.
|
||||
get_offloader().join_after_forward()
|
||||
self.graphs[num_tokens] = graph
|
||||
|
||||
def _capture_piecewise_graph(
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
generate_fn: Callable,
|
||||
attn_metadata: dict[str, Any],
|
||||
slot_mappings: dict[str, torch.Tensor],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
) -> None:
|
||||
generate_fn(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp,
|
||||
CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
generate_fn: Callable,
|
||||
@@ -184,31 +50,42 @@ class EagleCudaGraphManager:
|
||||
block_tables: BlockTables,
|
||||
attn_groups: list[list[AttentionGroup]],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
progress_bar_desc: str = "Capturing CUDA graphs",
|
||||
) -> None:
|
||||
if self.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
return
|
||||
"""Capture CUDA graphs for Eagle speculative decoding (FULL mode only)."""
|
||||
|
||||
capture_graphs(
|
||||
self.cudagraph_sizes,
|
||||
self.device,
|
||||
self.capture_graph,
|
||||
capture_cudagraph_mode=self.cudagraph_mode,
|
||||
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
|
||||
generate_fn=generate_fn,
|
||||
model_state=model_state,
|
||||
input_buffers=input_buffers,
|
||||
block_tables=block_tables,
|
||||
attn_groups=attn_groups,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
def create_forward_fn(
|
||||
desc: BatchExecutionDescriptor,
|
||||
) -> Callable[[CUDAGraphMode], None]:
|
||||
num_tokens = desc.num_tokens
|
||||
num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs)
|
||||
num_tokens_across_dp = (
|
||||
torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu")
|
||||
if self.dp_size > 1
|
||||
else None
|
||||
)
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
model_state,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_groups,
|
||||
kv_cache_config,
|
||||
)
|
||||
|
||||
def run_fullgraph(self, num_tokens: int) -> None:
|
||||
assert num_tokens in self.graphs
|
||||
# Sync offloader before replay - needed when transitioning from
|
||||
# eager/piecewise to full cudagraph (e.g., prefill → decode).
|
||||
# The previous eager iteration's start_prefetch may have queued
|
||||
# H2D copies on copy_stream that the graph's captured events
|
||||
# cannot see. Without this, replay could overwrite static buffers
|
||||
# while those copies are still in flight.
|
||||
get_offloader().sync_prev_onload()
|
||||
self.graphs[num_tokens].replay()
|
||||
return lambda cg_mode: generate_fn(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp,
|
||||
cg_mode,
|
||||
)
|
||||
|
||||
super().capture(create_forward_fn, progress_bar_desc)
|
||||
|
||||
def run_fullgraph(self, desc: BatchExecutionDescriptor) -> torch.Tensor:
|
||||
"""Replay a captured FULL cudagraph and return draft tokens."""
|
||||
super().run_fullgraph(desc)
|
||||
return self.draft_tokens
|
||||
|
||||
@@ -16,7 +16,7 @@ from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_slot_mappings_by_layer,
|
||||
)
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding
|
||||
from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.model_states.interface import ModelState
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||
@@ -75,7 +75,16 @@ class EagleSpeculator:
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
|
||||
# currently we don't support PIECEWISE for Eagle.
|
||||
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
|
||||
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
|
||||
cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
|
||||
else:
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
self.cudagraph_manager = EagleCudaGraphManager(
|
||||
vllm_config, device, cudagraph_mode, self.draft_tokens
|
||||
)
|
||||
|
||||
def load_model(self, target_model: nn.Module) -> None:
|
||||
self.model = load_eagle_model(target_model, self.vllm_config)
|
||||
@@ -171,7 +180,7 @@ class EagleSpeculator:
|
||||
)
|
||||
if attn_metadata is not None:
|
||||
self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
idx_mapping, query_start_loc, pos, num_tokens_padded
|
||||
)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
@@ -185,6 +194,7 @@ class EagleSpeculator:
|
||||
self.block_tables,
|
||||
self.attn_groups,
|
||||
self.kv_cache_config,
|
||||
progress_bar_desc="Capturing eagle CUDA graphs",
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -251,6 +261,7 @@ class EagleSpeculator:
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
num_reqs = input_batch.num_reqs
|
||||
num_reqs_padded = input_batch.num_reqs_after_padding
|
||||
# NOTE(woosuk): For draft sampling, we only consider the temperature
|
||||
# and ignore the other sampling parameters such as top_k and top_p,
|
||||
# for simplicity and performance.
|
||||
@@ -292,48 +303,52 @@ class EagleSpeculator:
|
||||
self.max_num_reqs,
|
||||
)
|
||||
|
||||
if not (dummy_run and skip_attn_for_dummy_run):
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos
|
||||
)
|
||||
# Get batch descriptor and sync across DP ranks.
|
||||
# Eagle uses FULL-only mode, dispatch with uniform_token_count=1 for decode
|
||||
|
||||
cudagraph_mode, cudagraph_size = (
|
||||
self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs)
|
||||
)
|
||||
num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = (
|
||||
get_cudagraph_and_dp_padding(
|
||||
batch_desc = self.cudagraph_manager.dispatch(num_reqs, num_reqs, 1)
|
||||
num_tokens_across_dp = None
|
||||
|
||||
if self.dp_size > 1:
|
||||
batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding(
|
||||
self.cudagraph_manager,
|
||||
batch_desc,
|
||||
num_reqs,
|
||||
cudagraph_size,
|
||||
cudagraph_mode.value,
|
||||
num_reqs,
|
||||
1, # uniform_token_count
|
||||
self.dp_size,
|
||||
self.dp_rank,
|
||||
)
|
||||
)
|
||||
cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode)
|
||||
if cudagraph_mode == CUDAGraphMode.FULL:
|
||||
# Run full CUDA graph.
|
||||
self.cudagraph_manager.run_fullgraph(num_tokens_padded)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
if not (dummy_run and skip_attn_for_dummy_run):
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
idx_mapping, query_start_loc, pos, batch_desc.num_tokens
|
||||
)
|
||||
|
||||
if batch_desc.cg_mode == CUDAGraphMode.FULL:
|
||||
return self.cudagraph_manager.run_fullgraph(batch_desc)[:num_reqs]
|
||||
|
||||
# Run eager or piecewise CUDA graph.
|
||||
attn_metadata_updated = None
|
||||
slot_mappings_updated = None
|
||||
if not (dummy_run and skip_attn_for_dummy_run):
|
||||
query_start_loc_cpu = torch.arange(
|
||||
num_reqs + 1, dtype=torch.int32, device="cpu"
|
||||
num_reqs_padded + 1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
|
||||
block_tables = [
|
||||
x[:num_reqs_padded] for x in self.block_tables.input_block_tables
|
||||
]
|
||||
|
||||
# FIXME(woosuk): This is UNSAFE!!
|
||||
attn_metadata_updated = build_attn_metadata(
|
||||
attn_groups=self.attn_groups,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_reqs,
|
||||
num_reqs=num_reqs_padded,
|
||||
num_tokens=num_reqs_padded,
|
||||
query_start_loc_gpu=query_start_loc,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
max_query_len=1,
|
||||
seq_lens=self.input_buffers.seq_lens[:num_reqs],
|
||||
seq_lens=self.input_buffers.seq_lens[:num_reqs_padded],
|
||||
max_seq_len=self.max_model_len,
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
@@ -345,11 +360,11 @@ class EagleSpeculator:
|
||||
|
||||
self.generate_draft(
|
||||
num_reqs,
|
||||
num_tokens_padded,
|
||||
batch_desc.num_tokens,
|
||||
attn_metadata_updated,
|
||||
slot_mappings_updated,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_mode,
|
||||
cudagraph_runtime_mode=batch_desc.cg_mode,
|
||||
)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user