[Model Runner V2] support piecewise & mixed cudagraph (#32771)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable, Iterable
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
@@ -11,7 +11,8 @@ from tqdm import tqdm
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
@@ -34,14 +35,27 @@ class CudaGraphManager:
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
|
||||
self.uniform_decode_query_len = 1
|
||||
spec_config = vllm_config.speculative_config
|
||||
if spec_config is not None:
|
||||
self.uniform_decode_query_len += spec_config.num_speculative_tokens
|
||||
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
self.cudagraph_sizes = get_cudagraph_sizes(
|
||||
|
||||
use_uniform_decode_cudagraph = (
|
||||
self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
|
||||
and self.cudagraph_mode.separate_routine()
|
||||
)
|
||||
self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes(
|
||||
self.compilation_config.cudagraph_capture_sizes,
|
||||
self.max_num_reqs,
|
||||
self.max_num_tokens,
|
||||
self.cudagraph_mode,
|
||||
self.uniform_decode_query_len,
|
||||
use_uniform_decode_cudagraph,
|
||||
)
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
@@ -54,20 +68,16 @@ class CudaGraphManager:
|
||||
return len(self.cudagraph_sizes) > 0
|
||||
|
||||
def get_cudagraph_size(
|
||||
self,
|
||||
num_tokens_after_padding: int,
|
||||
num_tokens_per_request: Iterable[int],
|
||||
self, num_tokens: int, uniform_decode: bool = False
|
||||
) -> int | None:
|
||||
return get_cudagraph_size(
|
||||
num_tokens_after_padding,
|
||||
num_tokens_per_request,
|
||||
self.cudagraph_sizes,
|
||||
self.cudagraph_mode,
|
||||
)
|
||||
if uniform_decode and self.uniform_decode_cudagraph_sizes:
|
||||
return self.uniform_decode_cudagraph_sizes.get(num_tokens)
|
||||
return self.cudagraph_sizes.get(num_tokens)
|
||||
|
||||
def capture_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
capture_cg_mode: CUDAGraphMode,
|
||||
model: nn.Module,
|
||||
input_buffers: InputBuffers,
|
||||
mrope_positions: torch.Tensor | None,
|
||||
@@ -75,8 +85,25 @@ class CudaGraphManager:
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
has_lora: bool = False,
|
||||
uniform_decode: bool = False,
|
||||
) -> None:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
# select and check capture function
|
||||
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
|
||||
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
|
||||
)
|
||||
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
|
||||
capture_fn = self._capture_piecewise_graph
|
||||
else:
|
||||
capture_fn = self._capture_full_graph
|
||||
# prepare inputs
|
||||
if uniform_decode:
|
||||
num_reqs = min(
|
||||
cdiv(num_tokens, self.uniform_decode_query_len),
|
||||
self.max_num_reqs,
|
||||
)
|
||||
else:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
input_ids = input_buffers.input_ids[:num_tokens]
|
||||
positions = input_buffers.positions[:num_tokens]
|
||||
if self.uses_mrope:
|
||||
@@ -92,6 +119,9 @@ class CudaGraphManager:
|
||||
attn_metadata_builders,
|
||||
self.max_model_len,
|
||||
kv_cache_config,
|
||||
uniform_decode_query_len=(
|
||||
self.uniform_decode_query_len if uniform_decode else 0
|
||||
),
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
|
||||
@@ -112,13 +142,40 @@ class CudaGraphManager:
|
||||
if self.hidden_states is None:
|
||||
self.hidden_states = torch.empty_like(hidden_states)
|
||||
|
||||
capture_fn(
|
||||
num_tokens=num_tokens,
|
||||
num_reqs=num_reqs,
|
||||
model=model,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings,
|
||||
has_lora=has_lora,
|
||||
)
|
||||
|
||||
def _capture_full_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
has_lora: bool = False,
|
||||
) -> None:
|
||||
assert attn_metadata is not None
|
||||
# Capture the graph.
|
||||
assert num_tokens not in self.graphs
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with (
|
||||
set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
attn_metadata=attn_metadata,
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
@@ -131,9 +188,44 @@ class CudaGraphManager:
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
assert self.hidden_states is not None
|
||||
self.hidden_states[:num_tokens] = hidden_states
|
||||
self.graphs[num_tokens] = graph
|
||||
|
||||
def _capture_piecewise_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor | None,
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
has_lora: bool = False,
|
||||
) -> None:
|
||||
# create batch descriptor for piecewise cudagraph dispatch key
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)
|
||||
|
||||
# Capture run - CUDAGraphWrapper inside torch.compile will auto capture.
|
||||
with set_forward_context(
|
||||
attn_metadata=None, # piecewise no need attn_metadata
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
batch_descriptor=batch_descriptor,
|
||||
slot_mapping=slot_mappings,
|
||||
):
|
||||
hidden_states = model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
assert self.hidden_states is not None
|
||||
self.hidden_states[:num_tokens] = hidden_states
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
@@ -144,11 +236,11 @@ class CudaGraphManager:
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
has_lora: bool = False,
|
||||
) -> None:
|
||||
capture_graphs(
|
||||
self.cudagraph_sizes,
|
||||
self.device,
|
||||
self.capture_graph,
|
||||
common_kwargs = dict(
|
||||
device=self.device,
|
||||
capture_fn=self.capture_graph,
|
||||
model=model,
|
||||
input_buffers=input_buffers,
|
||||
mrope_positions=mrope_positions,
|
||||
@@ -156,10 +248,50 @@ class CudaGraphManager:
|
||||
block_tables=block_tables,
|
||||
attn_metadata_builders=attn_metadata_builders,
|
||||
kv_cache_config=kv_cache_config,
|
||||
has_lora=has_lora,
|
||||
)
|
||||
|
||||
def run(self, num_tokens: int) -> torch.Tensor:
|
||||
assert num_tokens in self.graphs
|
||||
# Phase 1: Capture for mixed prefill-decode batches if needed.
|
||||
mixed_mode = self.cudagraph_mode.mixed_mode()
|
||||
if mixed_mode != CUDAGraphMode.NONE:
|
||||
capture_graphs(
|
||||
cudagraph_sizes=self.cudagraph_sizes,
|
||||
capture_cudagraph_mode=mixed_mode,
|
||||
desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
|
||||
uniform_decode=False,
|
||||
**common_kwargs,
|
||||
)
|
||||
|
||||
# Phase 2: Capture FULL graphs for uniform decode batches if needed.
|
||||
# This is only needed if we use a separate routine for decode batches
|
||||
# and the decode_mode is FULL.
|
||||
if self.uniform_decode_cudagraph_sizes:
|
||||
capture_graphs(
|
||||
cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
|
||||
capture_cudagraph_mode=CUDAGraphMode.FULL,
|
||||
desc="Capturing CUDA graphs (decode, FULL)",
|
||||
uniform_decode=True,
|
||||
**common_kwargs,
|
||||
)
|
||||
|
||||
def get_cudagraph_runtime_mode(
|
||||
self, num_reqs: int, num_tokens: int, max_query_len: int
|
||||
) -> tuple[CUDAGraphMode, int | None]:
|
||||
is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
|
||||
num_tokens == max_query_len * num_reqs
|
||||
)
|
||||
|
||||
cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
|
||||
if cudagraph_size is None:
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
elif is_uniform_decode:
|
||||
cudagraph_mode = self.cudagraph_mode.decode_mode()
|
||||
else:
|
||||
cudagraph_mode = self.cudagraph_mode.mixed_mode()
|
||||
return cudagraph_mode, cudagraph_size
|
||||
|
||||
def run_fullgraph(self, num_tokens: int) -> torch.Tensor:
|
||||
assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
|
||||
self.graphs[num_tokens].replay()
|
||||
assert self.hidden_states is not None
|
||||
return self.hidden_states[:num_tokens]
|
||||
@@ -170,22 +302,18 @@ def get_cudagraph_sizes(
|
||||
max_num_reqs: int,
|
||||
max_num_tokens: int,
|
||||
cudagraph_mode: CUDAGraphMode,
|
||||
) -> dict[int, int]:
|
||||
if not cudagraph_mode.has_full_cudagraphs():
|
||||
return {}
|
||||
uniform_decode_query_len: int = 1,
|
||||
uniform_decode_cudagraph: bool = False,
|
||||
) -> tuple[dict[int, int], dict[int, int]]:
|
||||
# Support both FULL and PIECEWISE cudagraph modes
|
||||
if cudagraph_mode == CUDAGraphMode.NONE:
|
||||
return {}, {}
|
||||
if not capture_sizes:
|
||||
return {}
|
||||
return {}, {}
|
||||
|
||||
capture_sizes = sorted(capture_sizes)
|
||||
# Limit the capture sizes to the max number of requests or tokens.
|
||||
upper_bound = (
|
||||
max_num_reqs
|
||||
if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
|
||||
else max_num_tokens
|
||||
)
|
||||
capture_sizes = [x for x in capture_sizes if x <= upper_bound]
|
||||
if not capture_sizes:
|
||||
return {}
|
||||
return {}, {}
|
||||
|
||||
cudagraph_sizes: dict[int, int] = {}
|
||||
for i in range(1, capture_sizes[-1] + 1):
|
||||
@@ -193,45 +321,34 @@ def get_cudagraph_sizes(
|
||||
if i <= x:
|
||||
cudagraph_sizes[i] = x
|
||||
break
|
||||
return cudagraph_sizes
|
||||
|
||||
|
||||
def get_cudagraph_size(
|
||||
num_tokens_after_dp_padding: int,
|
||||
num_tokens_per_request: Iterable[int],
|
||||
cudagraph_sizes: dict[int, int],
|
||||
cudagraph_mode: CUDAGraphMode,
|
||||
) -> int | None:
|
||||
if not cudagraph_mode.has_full_cudagraphs():
|
||||
# No full CUDA graph is used.
|
||||
return None
|
||||
|
||||
size = cudagraph_sizes.get(num_tokens_after_dp_padding)
|
||||
if size is None:
|
||||
# No CUDA graph for this size.
|
||||
return None
|
||||
|
||||
is_mixed = any(x > 1 for x in num_tokens_per_request)
|
||||
if is_mixed and cudagraph_mode.mixed_mode() != CUDAGraphMode.FULL:
|
||||
# Prefill is included, and this mode doesn't use CUDA graph for it.
|
||||
return None
|
||||
return size
|
||||
uniform_decode_cudagraph_sizes: dict[int, int] = {}
|
||||
if uniform_decode_cudagraph:
|
||||
max_num_tokens = max_num_reqs * uniform_decode_query_len
|
||||
uniform_decode_cudagraph_sizes = {
|
||||
k: v
|
||||
for k, v in cudagraph_sizes.items()
|
||||
if v <= max_num_tokens and v >= uniform_decode_query_len
|
||||
}
|
||||
return cudagraph_sizes, uniform_decode_cudagraph_sizes
|
||||
|
||||
|
||||
def capture_graphs(
|
||||
cudagraph_sizes: dict[int, int],
|
||||
device: torch.device,
|
||||
capture_fn: Callable,
|
||||
capture_cudagraph_mode: CUDAGraphMode,
|
||||
desc: str = "Capturing CUDA graphs",
|
||||
**capture_kwargs,
|
||||
) -> None:
|
||||
# Capture larger graphs first.
|
||||
sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
|
||||
if is_global_first_rank():
|
||||
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
|
||||
sizes_to_capture = tqdm(sizes_to_capture, desc=desc)
|
||||
|
||||
with graph_capture(device=device):
|
||||
for size in sizes_to_capture:
|
||||
capture_fn(size, **capture_kwargs)
|
||||
capture_fn(size, capture_cudagraph_mode, **capture_kwargs)
|
||||
|
||||
|
||||
def prepare_inputs_to_capture(
|
||||
@@ -242,8 +359,12 @@ def prepare_inputs_to_capture(
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
max_model_len: int,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
uniform_decode_query_len: int = 0,
|
||||
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
|
||||
num_tokens_per_req = num_tokens // num_reqs
|
||||
if uniform_decode_query_len > 0:
|
||||
num_tokens_per_req = uniform_decode_query_len
|
||||
else:
|
||||
num_tokens_per_req = num_tokens // num_reqs
|
||||
|
||||
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
|
||||
query_start_loc_np[-1] = num_tokens
|
||||
|
||||
@@ -13,48 +13,65 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N
|
||||
|
||||
|
||||
def get_batch_metadata_across_dp(
|
||||
num_tokens: int, cudagraph_size: int, dp_size: int, dp_rank: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
num_tokens: int,
|
||||
cudagraph_size: int,
|
||||
cudagraph_runtime_mode: int,
|
||||
dp_size: int,
|
||||
dp_rank: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert dp_size > 1
|
||||
# Use CPU group to avoid CPU-GPU synchronization.
|
||||
group = get_dp_group().cpu_group
|
||||
tensor = torch.zeros(2, dp_size, dtype=torch.int32, device="cpu")
|
||||
tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu")
|
||||
tensor[0][dp_rank] = num_tokens
|
||||
tensor[1][dp_rank] = cudagraph_size
|
||||
tensor[2][dp_rank] = cudagraph_runtime_mode
|
||||
dist.all_reduce(tensor, group=group)
|
||||
return tensor[0], tensor[1]
|
||||
return tensor[0], tensor[1], tensor[2]
|
||||
|
||||
|
||||
def get_cudagraph_and_dp_padding(
|
||||
num_tokens: int, cudagraph_size: int | None, dp_size: int, dp_rank: int
|
||||
) -> tuple[bool, int, torch.Tensor | None]:
|
||||
num_tokens: int,
|
||||
cudagraph_size: int | None,
|
||||
cudagraph_runtime_mode: int,
|
||||
dp_size: int,
|
||||
dp_rank: int,
|
||||
) -> tuple[int, torch.Tensor | None, int]:
|
||||
if dp_size == 1:
|
||||
if cudagraph_size is not None:
|
||||
return True, cudagraph_size, None
|
||||
return cudagraph_size, None, cudagraph_runtime_mode
|
||||
else:
|
||||
return False, num_tokens, None
|
||||
return num_tokens, None, cudagraph_runtime_mode
|
||||
|
||||
# Convert None to -1 for sync (indicates no cudagraph available)
|
||||
if num_tokens == 0:
|
||||
cudagraph_size = 0
|
||||
elif cudagraph_size is None:
|
||||
cudagraph_size = -1
|
||||
num_tokens_across_dp, cudagraph_size_across_dp = get_batch_metadata_across_dp(
|
||||
num_tokens, cudagraph_size, dp_size, dp_rank
|
||||
|
||||
num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = (
|
||||
get_batch_metadata_across_dp(
|
||||
num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank
|
||||
)
|
||||
)
|
||||
if torch.all(num_tokens_across_dp == 0).item():
|
||||
# All ranks have zero tokens to run.
|
||||
return False, 0, None
|
||||
return 0, None, 0
|
||||
|
||||
if torch.all(cudagraph_size_across_dp != -1).item():
|
||||
# All ranks use CUDA graph or have zero tokens.
|
||||
# Use CUDA graph for all ranks.
|
||||
# Pad all ranks to the maximum CUDA graph size.
|
||||
# Synchronize cudagraph_runtime_mode across ranks by taking the minimum.
|
||||
synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item())
|
||||
# Check if all ranks have valid cudagraph_size.
|
||||
all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item()
|
||||
|
||||
if synced_cudagraph_mode != 0 and all_have_cudagraph:
|
||||
# All ranks use cudagraph. Pad to max cudagraph_size.
|
||||
max_cudagraph_size = int(cudagraph_size_across_dp.max().item())
|
||||
num_tokens_across_dp[:] = max_cudagraph_size
|
||||
return True, max_cudagraph_size, num_tokens_across_dp
|
||||
return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode
|
||||
else:
|
||||
# Some ranks do not use CUDA graph. Use eager mode for all ranks.
|
||||
# No padding is needed except for ranks that have no tokens to run.
|
||||
# Fall back to eager mode (no cudagraph).
|
||||
# Either some rank doesn't have cudagraph size or mode is NONE.
|
||||
synced_cudagraph_mode = 0
|
||||
num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1)
|
||||
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item())
|
||||
return False, num_tokens_after_padding, num_tokens_across_dp
|
||||
return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode
|
||||
|
||||
@@ -15,7 +15,7 @@ from vllm.distributed.parallel_state import (
|
||||
get_pp_group,
|
||||
prepare_communication_buffer_for_model,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@@ -140,7 +140,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.do_spec_decode = False
|
||||
self.num_speculative_steps = 0
|
||||
self.speculator = None
|
||||
|
||||
self.req_states = RequestState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
@@ -458,6 +457,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
block_tables=self.block_tables,
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
has_lora=self.lora_config is not None,
|
||||
)
|
||||
if self.do_spec_decode:
|
||||
self.speculator.capture_model()
|
||||
@@ -884,19 +884,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
empty_output = self.kv_connector.no_forward(scheduler_output)
|
||||
return empty_output
|
||||
|
||||
# Get the CUDA graph size. None means no CUDA graph is used.
|
||||
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(
|
||||
scheduler_output.total_num_scheduled_tokens,
|
||||
scheduler_output.num_scheduled_tokens.values(),
|
||||
# Get local cudagraph mode and size.
|
||||
local_cudagraph_mode, local_cudagraph_size = (
|
||||
self.cudagraph_manager.get_cudagraph_runtime_mode(
|
||||
num_reqs=len(scheduler_output.num_scheduled_tokens),
|
||||
num_tokens=scheduler_output.total_num_scheduled_tokens,
|
||||
max_query_len=max(scheduler_output.num_scheduled_tokens.values()),
|
||||
)
|
||||
)
|
||||
use_cudagraph, num_tokens_after_padding, num_tokens_across_dp = (
|
||||
|
||||
# DP sync: num_tokens + cudagraph_size + cudagraph_mode
|
||||
num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = (
|
||||
get_cudagraph_and_dp_padding(
|
||||
scheduler_output.total_num_scheduled_tokens,
|
||||
cudagraph_size,
|
||||
local_cudagraph_size,
|
||||
local_cudagraph_mode.value,
|
||||
self.parallel_config.data_parallel_size,
|
||||
self.parallel_config.data_parallel_rank,
|
||||
)
|
||||
)
|
||||
cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode)
|
||||
if num_tokens_after_padding == 0:
|
||||
# All DP ranks have zero tokens to run.
|
||||
empty_output = self.kv_connector.no_forward(scheduler_output)
|
||||
@@ -946,16 +953,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# FIXME(woosuk): Fix warmup for LoRA.
|
||||
|
||||
# Run model.
|
||||
if use_cudagraph:
|
||||
# Run CUDA graph.
|
||||
if cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||
# Use explicit cudagraph replay for FULL mode.
|
||||
# NOTE(woosuk): Here, we don't need to pass the input tensors,
|
||||
# because they are already copied to the CUDA graph input buffers.
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
hidden_states = self.cudagraph_manager.run(
|
||||
hidden_states = self.cudagraph_manager.run_fullgraph(
|
||||
input_batch.num_tokens_after_padding
|
||||
)
|
||||
else:
|
||||
# Run PyTorch model in eager mode.
|
||||
# For piecewise and eager mode, just call model().
|
||||
positions = input_batch.positions
|
||||
if self.uses_mrope:
|
||||
assert input_batch.mrope_positions is not None
|
||||
@@ -970,13 +977,18 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
inputs_embeds = None
|
||||
assert intermediate_tensors is not None
|
||||
|
||||
batch_descriptor = BatchDescriptor(
|
||||
num_tokens=input_batch.num_tokens_after_padding,
|
||||
has_lora=self.lora_config is not None,
|
||||
)
|
||||
|
||||
with set_forward_context(
|
||||
input_batch.attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch.num_tokens_after_padding,
|
||||
# TODO(woosuk): Support piecewise CUDA graph.
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
batch_descriptor=batch_descriptor,
|
||||
slot_mapping=input_batch.slot_mappings,
|
||||
):
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.triton_utils import tl, triton
|
||||
@@ -103,14 +103,17 @@ class EagleSpeculator:
|
||||
attn_metadata: dict[str, Any] | None,
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
num_tokens_across_dp: torch.Tensor | None,
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_descriptor = BatchDescriptor(num_tokens=num_tokens)
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
slot_mapping=slot_mappings,
|
||||
batch_descriptor=batch_descriptor,
|
||||
):
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=self.input_buffers.input_ids[:num_tokens],
|
||||
@@ -127,9 +130,11 @@ class EagleSpeculator:
|
||||
def generate_draft(
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_tokens_padded: int,
|
||||
attn_metadata: dict[str, Any],
|
||||
slot_mappings: dict[str, torch.Tensor],
|
||||
num_tokens_across_dp: torch.Tensor | None,
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
) -> None:
|
||||
pos = self.input_buffers.positions[:num_reqs]
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
@@ -137,8 +142,14 @@ class EagleSpeculator:
|
||||
for step in range(1, self.num_speculative_steps):
|
||||
# Run the eagle model.
|
||||
last_hidden_states, hidden_states = self.run_model(
|
||||
num_reqs, attn_metadata, slot_mappings, num_tokens_across_dp
|
||||
num_tokens_padded,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp,
|
||||
cudagraph_runtime_mode,
|
||||
)
|
||||
last_hidden_states = last_hidden_states[:num_reqs]
|
||||
hidden_states = hidden_states[:num_reqs]
|
||||
logits = self.model.compute_logits(last_hidden_states)
|
||||
|
||||
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
|
||||
@@ -283,12 +294,14 @@ class EagleSpeculator:
|
||||
)
|
||||
|
||||
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
|
||||
if cudagraph_size is not None:
|
||||
# Run CUDA graph.
|
||||
self.cudagraph_manager.run(cudagraph_size)
|
||||
cudagraph_mode = self.cudagraph_manager.cudagraph_mode
|
||||
if cudagraph_size is not None and cudagraph_mode == CUDAGraphMode.FULL:
|
||||
# Run full CUDA graph.
|
||||
self.cudagraph_manager.run_fullgraph(cudagraph_size)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
# Run eager mode.
|
||||
# Run eager or piecewise CUDA graph.
|
||||
num_tokens_padded = cudagraph_size if cudagraph_size is not None else num_reqs
|
||||
query_start_loc_cpu = torch.arange(
|
||||
num_reqs + 1, dtype=torch.int32, device="cpu"
|
||||
)
|
||||
@@ -312,8 +325,13 @@ class EagleSpeculator:
|
||||
slot_mappings, self.kv_cache_config
|
||||
)
|
||||
self.generate_draft(
|
||||
num_reqs, attn_metadata, slot_mappings_by_layer, num_tokens_across_dp=None
|
||||
) # FIXME
|
||||
num_reqs,
|
||||
num_tokens_padded,
|
||||
attn_metadata,
|
||||
slot_mappings_by_layer,
|
||||
num_tokens_across_dp=None, # FIXME
|
||||
cudagraph_runtime_mode=cudagraph_mode,
|
||||
)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -31,16 +32,17 @@ class EagleCudaGraphManager:
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
if self.cudagraph_mode == CUDAGraphMode.FULL:
|
||||
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
|
||||
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode()
|
||||
|
||||
self.cudagraph_sizes = get_cudagraph_sizes(
|
||||
# only need to capture uniform decode cudagraph sizes (the 2nd return value)
|
||||
_, self.cudagraph_sizes = get_cudagraph_sizes(
|
||||
self.compilation_config.cudagraph_capture_sizes,
|
||||
self.max_num_reqs,
|
||||
self.max_num_tokens,
|
||||
self.cudagraph_mode,
|
||||
uniform_decode_query_len=1,
|
||||
uniform_decode_cudagraph=True,
|
||||
)
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
@@ -54,12 +56,21 @@ class EagleCudaGraphManager:
|
||||
def capture_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
capture_cg_mode: CUDAGraphMode,
|
||||
generate_fn: Callable,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
|
||||
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
|
||||
)
|
||||
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
|
||||
capture_fn = self._capture_piecewise_graph
|
||||
else:
|
||||
capture_fn = self._capture_full_graph
|
||||
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
attn_metadata, slot_mappings = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
@@ -69,19 +80,70 @@ class EagleCudaGraphManager:
|
||||
attn_metadata_builders,
|
||||
self.max_model_len,
|
||||
kv_cache_config,
|
||||
uniform_decode_query_len=1,
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
|
||||
# Warm up.
|
||||
generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp)
|
||||
generate_fn(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp,
|
||||
CUDAGraphMode.NONE,
|
||||
)
|
||||
|
||||
# Capture the graph.
|
||||
capture_fn(
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
generate_fn=generate_fn,
|
||||
attn_metadata=attn_metadata,
|
||||
slot_mappings=slot_mappings,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
|
||||
def _capture_full_graph(
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
generate_fn: Callable,
|
||||
attn_metadata: dict[str, Any],
|
||||
slot_mappings: dict[str, torch.Tensor],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
) -> None:
|
||||
assert num_tokens not in self.graphs
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, self.pool):
|
||||
generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp)
|
||||
generate_fn(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp,
|
||||
CUDAGraphMode.NONE,
|
||||
)
|
||||
self.graphs[num_tokens] = graph
|
||||
|
||||
def _capture_piecewise_graph(
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_tokens: int,
|
||||
generate_fn: Callable,
|
||||
attn_metadata: dict[str, Any],
|
||||
slot_mappings: dict[str, torch.Tensor],
|
||||
num_tokens_across_dp: torch.Tensor,
|
||||
) -> None:
|
||||
generate_fn(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
attn_metadata,
|
||||
slot_mappings,
|
||||
num_tokens_across_dp,
|
||||
CUDAGraphMode.PIECEWISE,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
@@ -91,10 +153,15 @@ class EagleCudaGraphManager:
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
if self.cudagraph_mode == CUDAGraphMode.NONE:
|
||||
return
|
||||
|
||||
capture_graphs(
|
||||
self.cudagraph_sizes,
|
||||
self.device,
|
||||
self.capture_graph,
|
||||
capture_cudagraph_mode=self.cudagraph_mode,
|
||||
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
|
||||
generate_fn=generate_fn,
|
||||
input_buffers=input_buffers,
|
||||
block_tables=block_tables,
|
||||
@@ -102,6 +169,6 @@ class EagleCudaGraphManager:
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
|
||||
def run(self, num_tokens: int) -> None:
|
||||
def run_fullgraph(self, num_tokens: int) -> None:
|
||||
assert num_tokens in self.graphs
|
||||
self.graphs[num_tokens].replay()
|
||||
|
||||
Reference in New Issue
Block a user