[Core] NGram GPU Implementation compatible with Async Scheduler (#29184)

This commit is contained in:
PatchyTIS
2026-03-08 05:51:37 +08:00
committed by GitHub
parent ee54f9cdb9
commit a6be75dbd2
9 changed files with 940 additions and 12 deletions

View File

@@ -98,7 +98,7 @@ def test_without_spec_decoding(
@single_gpu_only
@large_gpu_mark(min_gb=16)
def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
"""Test consistency and acceptance rates with some different combos of
preemption, executor, async scheduling, prefill chunking,
spec decoding model length.
@@ -154,6 +154,42 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch)
)
def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch):
    """Exercise ngram_gpu speculative decoding across scheduler/executor combos.

    Covers:
    - ngram_gpu drafting with k=3 and a [2, 3] prompt-lookup window
    - preemption on and off, async scheduling on and off
    - "mp" and "uni" executors, with and without prefill chunking
    """
    spec_cfg = {
        "method": "ngram_gpu",
        "num_speculative_tokens": 3,
        "prompt_lookup_max": 3,
        "prompt_lookup_min": 2,
    }
    # Each tuple is:
    # (test_preemption, executor, async_scheduling, spec_config, test_prefill_chunking)
    test_configs = [
        (False, "mp", False, None, False),
        (False, "mp", False, spec_cfg, False),
        (True, "mp", False, spec_cfg, True),
        (False, "mp", True, spec_cfg, False),
        (True, "mp", True, spec_cfg, False),
        (True, "uni", True, spec_cfg, False),
        (True, "mp", True, spec_cfg, True),
    ]
    # MODEL (Qwen) is light-weight and ngram_gpu needs no separate draft model.
    run_tests(monkeypatch, MODEL, test_configs, [{}])
@dynamo_config.patch(cache_size_limit=16)
def run_tests(
monkeypatch: pytest.MonkeyPatch,
@@ -282,11 +318,12 @@ def run_test(
else dict(gpu_memory_utilization=0.9)
)
spec_mml = (spec_config or {}).get("max_model_len")
spec_method = (spec_config or {}).get("method", "none")
test_config = (
f"executor={executor}, preemption={test_preemption}, "
f"async_sched={async_scheduling}, "
f"chunk_prefill={test_prefill_chunking}, "
f"spec_decoding={spec_decoding}, spec_mml={spec_mml}"
f"spec_decoding={spec_decoding}, spec_method={spec_method}, spec_mml={spec_mml}"
)
print("-" * 80)
print(f"---- TESTING {test_str}: {test_config}")
@@ -294,7 +331,7 @@ def run_test(
with VllmRunner(
model,
max_model_len=512,
max_model_len=4096,
enable_chunked_prefill=test_prefill_chunking,
# Force prefill chunking
max_num_batched_tokens=48 if test_prefill_chunking else None,

View File

@@ -183,6 +183,34 @@ def test_ngram_and_suffix_correctness(
cleanup_dist_env_and_memory()
@pytest.mark.parametrize("async_scheduling", [True], ids=["async"])
@single_gpu_only
@large_gpu_mark(min_gb=20)
def test_ngram_gpu_default_with_async_scheduling(
    async_scheduling: bool,
):
    """
    Validate ngram_gpu speculative decoding (k=2, lookup window [2, 3])
    under async scheduling via GSM8K accuracy.
    Uses Qwen/Qwen3-8B (ref GSM8K accuracy: 87%-92%); requires >= 80%.
    """
    spec_llm = LLM(
        model="Qwen/Qwen3-8B",
        speculative_config={
            "method": "ngram_gpu",
            "prompt_lookup_max": 3,
            "prompt_lookup_min": 2,
            "num_speculative_tokens": 2,
        },
        max_model_len=4096,
        async_scheduling=async_scheduling,
    )
    evaluate_llm_for_gsm8k(spec_llm, expected_accuracy_threshold=0.8)
    # Free GPU memory and distributed state before the next test.
    del spec_llm
    cleanup_dist_env_and_memory()
@single_gpu_only
@large_gpu_mark(min_gb=20)
def test_suffix_decoding_acceptance(

View File

@@ -907,6 +907,13 @@ class VllmBackend:
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
disable_cache = not is_compile_cache_enabled(self.inductor_config)
# TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors.
is_ngram_gpu_enabled = (
vllm_config.speculative_config is not None
and vllm_config.speculative_config.use_ngram_gpu()
)
disable_cache = disable_cache or is_ngram_gpu_enabled
if disable_cache:
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
else:

View File

@@ -47,6 +47,7 @@ MTPModelTypes = Literal[
"step3p5_mtp",
]
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
NgramGPUTypes = Literal["ngram_gpu"]
SpeculativeMethod = Literal[
"ngram",
"medusa",
@@ -54,6 +55,7 @@ SpeculativeMethod = Literal[
"draft_model",
"suffix",
EagleModelTypes,
NgramGPUTypes,
]
@@ -364,6 +366,8 @@ class SpeculativeConfig:
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
elif self.method == "ngram_gpu":
self.model = "ngram_gpu"
elif self.method == "suffix":
self.model = "suffix"
elif self.method == "extract_hidden_states":
@@ -374,8 +378,9 @@ class SpeculativeConfig:
)
if self.method in ("ngram", "[ngram]"):
# Unified to "ngram" internally
self.method = "ngram"
if self.method in ("ngram", "ngram_gpu"):
# Set default values if not provided
if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
@@ -832,6 +837,9 @@ class SpeculativeConfig:
def uses_extract_hidden_states(self) -> bool:
return self.method == "extract_hidden_states"
def use_ngram_gpu(self) -> bool:
return self.method == "ngram_gpu"
def __repr__(self) -> str:
method = self.method
model = (

View File

@@ -41,7 +41,7 @@ from .offload import OffloadConfig
from .parallel import ParallelConfig
from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig
from .speculative import EagleModelTypes, SpeculativeConfig
from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig
from .structured_outputs import StructuredOutputsConfig
from .utils import SupportsHash, config, replace
from .weight_transfer import WeightTransferConfig
@@ -696,11 +696,13 @@ class VllmConfig:
if self.speculative_config is not None:
if (
self.speculative_config.method not in get_args(EagleModelTypes)
and self.speculative_config.method not in get_args(NgramGPUTypes)
and self.speculative_config.method != "draft_model"
):
raise ValueError(
"Currently, async scheduling is only supported "
"with EAGLE/MTP/Draft Model kind of speculative decoding."
"with EAGLE/MTP/Draft Model/NGram GPU kind of "
"speculative decoding"
)
if self.speculative_config.disable_padded_drafter_batch:
raise ValueError(
@@ -718,6 +720,7 @@ class VllmConfig:
if (
self.speculative_config is not None
and self.speculative_config.method not in get_args(EagleModelTypes)
and self.speculative_config.method not in get_args(NgramGPUTypes)
):
logger.warning_once(
"Async scheduling not supported with %s-based "

View File

@@ -385,6 +385,7 @@ class Hermes2ProToolParser(ToolParser):
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
assert current_tool_call is not None
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
@@ -489,6 +490,7 @@ class Hermes2ProToolParser(ToolParser):
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
assert isinstance(current_tool_call, dict)
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
else:

View File

@@ -0,0 +1,660 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GPU-accelerated N-gram proposer using fully async PyTorch tensor operations.
This version uses a fully vectorized approach with unfold and argmax for
finding the first match across all sequences in parallel.
"""
import torch
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CompilationConfig,
CompilationMode,
CUDAGraphMode,
VllmConfig,
)
from vllm.forward_context import set_forward_context
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@support_torch_compile()
class NgramGPUKernel(nn.Module):
    """GPU-accelerated N-gram proposer using fully async tensor operations.

    The whole proposal step is expressed as vectorized tensor ops (no
    host-side loop over sequences), so it can run under torch.compile and
    never forces a CPU-GPU sync.
    """

    def __init__(
        self, vllm_config: VllmConfig, prefix: str = "", device: torch.device = "cuda"
    ):
        super().__init__()
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None
        # Suffix n-gram search window [min_n, max_n] and draft length k.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        self.k = vllm_config.speculative_config.num_speculative_tokens
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.device = device

    def _find_first_and_extract_all_n_parallel(
        self,
        token_ids: torch.Tensor,
        seq_lengths: torch.Tensor,
        min_ngram_len: int,
        max_ngram_len: int,
        num_draft_tokens: int,
    ) -> torch.Tensor:
        """
        Find suffix n-gram matches and extract following tokens.

        Searches for the earliest prior occurrence of the trailing n-gram,
        tries multiple lengths, and picks the longest valid match.

        Args:
            token_ids: Token IDs for each sequence [batch_size, max_seq_len]
            seq_lengths: Actual length of each sequence (excluding padding)
            min_ngram_len: Minimum n-gram size to search for (e.g., 2)
            max_ngram_len: Maximum n-gram size to search for (e.g., 5)
            num_draft_tokens: Number of tokens to extract after match (k)

        Returns:
            Draft token predictions [batch_size, num_draft_tokens];
            -1 means invalid/no match.
        """
        batch_size = token_ids.shape[0]
        max_seq_len = token_ids.shape[1]
        device = token_ids.device
        num_ngram_sizes = max_ngram_len - min_ngram_len + 1
        # All n-gram sizes to try.
        ngram_lengths = torch.arange(min_ngram_len, max_ngram_len + 1, device=device)
        batch_indices = torch.arange(batch_size, device=device)
        # Earliest match per (sequence, ngram_len); -1 means no match.
        first_match_positions = torch.full(
            (batch_size, num_ngram_sizes), -1, dtype=torch.long, device=device
        )
        for i, ngram_len in enumerate(range(min_ngram_len, max_ngram_len + 1)):
            # Sliding windows of size ngram_len; unfold is O(1) view.
            search_windows = token_ids.unfold(1, ngram_len, 1)
            num_windows = search_windows.shape[1]
            # Trailing suffix (last ngram_len tokens) for each sequence.
            suffix_starts = seq_lengths - ngram_len
            suffix_indices = suffix_starts.unsqueeze(1) + torch.arange(
                ngram_len, device=device
            )
            # clamp(min=0) guards sequences shorter than ngram_len; their
            # "matches" are later rejected by valid_mask below.
            suffix = torch.gather(token_ids, 1, suffix_indices.clamp(min=0))
            # Window matches for each sequence.
            matches = (search_windows == suffix.unsqueeze(1)).all(dim=-1)
            # Match must leave room for at least one draft token.
            max_valid_suffix_start = seq_lengths - ngram_len - 1
            window_positions = torch.arange(num_windows, device=device)
            valid_mask = window_positions <= max_valid_suffix_start.unsqueeze(1)
            final_matches = matches & valid_mask
            # Find earliest match (argmax=0 when empty; verify with has_match).
            first_match_idx = torch.argmax(final_matches.int(), dim=1)
            has_match = final_matches[batch_indices, first_match_idx]
            # Store valid match positions (window index = position).
            first_match_positions[:, i] = torch.where(has_match, first_match_idx, -1)
        # Select the longest n-gram with a match (flip + argmax finds the
        # last column with a non-negative position).
        best_ngram_idx = (first_match_positions >= 0).int().flip(dims=[1]).argmax(dim=1)
        best_ngram_idx = num_ngram_sizes - 1 - best_ngram_idx  # Flip back
        # Match position for the best n-gram.
        best_match_pos = first_match_positions[batch_indices, best_ngram_idx]
        # Avoid data-dependent branching.
        has_any_match = best_match_pos >= 0
        # Length of the best matching n-gram.
        best_ngram_lengths = ngram_lengths[best_ngram_idx]
        # Start position right after the matched suffix.
        draft_start = torch.where(
            has_any_match,
            best_match_pos + best_ngram_lengths,
            torch.zeros_like(best_match_pos),
        )
        tokens_available = seq_lengths - draft_start
        # Gather indices for draft tokens.
        draft_indices = draft_start.unsqueeze(1) + torch.arange(
            num_draft_tokens, device=device
        )
        draft_indices = draft_indices.clamp(min=0, max=max_seq_len - 1)
        # Extract draft tokens; gather always runs (no data-dependent branch).
        draft_tokens = torch.gather(token_ids, 1, draft_indices)
        # Mask positions beyond available tokens.
        position_indices = torch.arange(num_draft_tokens, device=device).unsqueeze(0)
        valid_positions = position_indices < tokens_available.unsqueeze(1)
        draft_tokens = torch.where(
            valid_positions,
            draft_tokens,
            torch.full_like(draft_tokens, -1),
        )
        # If no match, mask all positions.
        draft_tokens = torch.where(
            has_any_match.unsqueeze(1),
            draft_tokens,
            torch.full_like(draft_tokens, -1),
        )
        return draft_tokens

    def forward(
        self,
        num_tokens_no_spec: torch.Tensor,
        token_ids_gpu: torch.Tensor,
        combined_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for N-gram proposal using GPU tensor operations.

        Args:
            num_tokens_no_spec: Number of tokens for each sequence [batch_size]
            token_ids_gpu: Token IDs [batch_size, max_len]
            combined_mask: Whether each sequence is valid for spec decode [batch_size]

        Returns:
            draft_tokens: [batch_size, k] on GPU
            num_valid_draft_tokens: [batch_size] int32 on GPU, count of
                leading valid (non -1) tokens per request.
        """
        device = token_ids_gpu.device
        # Infer batch size to preserve dynamic shape.
        actual_batch_size = token_ids_gpu.shape[0]
        # Allocate in forward so torch.compile can optimize.
        # NOTE(patchy): Do NOT pre-allocate this as a buffer
        # it breaks torch.compile
        # NOTE(review): this placeholder is rebound by the torch.where below
        # before any read — confirm the initial full() is still required.
        draft_tokens = torch.full(
            (actual_batch_size, self.k), -1, dtype=torch.int32, device=device
        )
        results = self._find_first_and_extract_all_n_parallel(
            token_ids_gpu,
            num_tokens_no_spec,
            min_ngram_len=self.min_n,
            max_ngram_len=self.max_n,
            num_draft_tokens=self.k,
        )
        # Invalidate all drafts for requests not eligible for spec decode.
        draft_tokens = torch.where(combined_mask.unsqueeze(1), results, -1)
        # Count leading contiguous valid (non -1) tokens per request.
        is_valid = draft_tokens != -1  # [batch, k]
        cum_valid = is_valid.int().cumsum(dim=1)  # [batch, k]
        positions = torch.arange(1, self.k + 1, device=device).unsqueeze(0)
        # Position j counts only while every prefix slot up to j is valid.
        num_valid_draft_tokens = (cum_valid == positions).int().sum(dim=1)
        return draft_tokens, num_valid_draft_tokens

    def load_model(self, *args, **kwargs):
        """No model to load for N-gram proposer."""
        pass
class NgramProposerGPU:
    """Drafter that owns a compiled :class:`NgramGPUKernel` plus the
    GPU-side token/length bookkeeping needed to propose draft tokens
    without synchronizing with the CPU."""

    def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None
        # Dedicated compilation options for the kernel, independent of the
        # main model's compilation config.
        compilation_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["none"],
            splitting_ops=[],
            compile_sizes=[],
            inductor_compile_config={
                "enable_auto_functionalized_v2": False,
                "max_autotune": True,
                "aggressive_fusion": True,
                "triton.autotune_pointwise": True,
                "coordinate_descent_tuning": True,
                "use_mixed_mm": False,
            },
            cudagraph_mode=CUDAGraphMode.NONE,
        )
        model_config = vllm_config.model_config
        speculative_config = vllm_config.speculative_config
        scheduler_config = vllm_config.scheduler_config
        # Narrowed VllmConfig so the kernel compiles with its own options.
        self.vllm_config = VllmConfig(
            compilation_config=compilation_config,
            model_config=model_config,
            speculative_config=speculative_config,
            scheduler_config=scheduler_config,
        )
        # Mirror the kernel's parameters for local use.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        self.k = vllm_config.speculative_config.num_speculative_tokens
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.device = device
        self.kernel = NgramGPUKernel(
            vllm_config=self.vllm_config, prefix="ngram_gpu_kernel", device=device
        )
        self.kernel.to(device)
        self.kernel.eval()
        # Warm up (and trigger compilation) before serving traffic.
        self._dummy_run()

    def _dummy_run(self):
        # Run the kernel a few times on dummy inputs at the maximum batch
        # size / sequence length so compilation happens up front.
        token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data(
            batch_size=self.max_num_seqs,
            max_seq_len=self.max_model_len,
            pattern_len=self.k,
            device=self.device,
        )
        combined_mask = sampled_flags & valid_mask & (num_tokens >= self.min_n)
        for _ in range(3):
            with set_forward_context(None, self.vllm_config):
                _, _ = self.kernel(num_tokens, token_ids, combined_mask)

    def _generate_dummy_data(
        self,
        batch_size: int,
        max_seq_len: int,
        pattern_len: int,
        device: str = "cuda",
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate dummy warm-up data: all-zero token ids with random
        per-sequence lengths in [pattern_len, max_seq_len).

        Args:
            batch_size: Number of sequences in the batch
            max_seq_len: Maximum sequence length
            pattern_len: Lower bound for the random sequence lengths
            device: Device to place tensors on

        Returns:
            token_ids: [batch_size, max_seq_len] tensor (all zeros)
            num_tokens: [batch_size] tensor of random lengths
            sampled_flags: [batch_size] bool tensor (all True)
            valid_mask: [batch_size] bool tensor (all True)
        """
        token_ids = torch.zeros(
            batch_size,
            max_seq_len,
            dtype=torch.int32,
            device=device,
        )
        num_tokens = torch.randint(
            pattern_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
        )
        sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device)
        valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device)
        return token_ids, num_tokens, sampled_flags, valid_mask

    def propose(
        self,
        num_tokens_no_spec: torch.Tensor,  # [batch_size]
        token_ids_gpu: torch.Tensor,  # [batch_size, max_len]
        valid_sampled_token_ids_gpu: torch.Tensor,  # [batch_size, num_spec_tokens + 1]
        valid_sampled_tokens_count: torch.Tensor,  # [batch_size]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Propose draft tokens using GPU-accelerated n-gram matching.

        Scatter sampled tokens into `token_ids_gpu`, compute temporary
        updated lengths, then run the kernel.

        Args:
            num_tokens_no_spec: Number of tokens per sequence (read-only)
            token_ids_gpu: Token IDs tensor (modified in-place with new tokens)
            valid_sampled_token_ids_gpu: Newly sampled tokens to scatter
            valid_sampled_tokens_count: Count of valid tokens per sequence

        Returns:
            draft_tokens: Proposed draft token IDs [batch_size, k]
            num_valid_draft_tokens: Count of leading valid draft tokens
                per request [batch_size]
        """
        assert token_ids_gpu.device == self.device
        assert num_tokens_no_spec.device == self.device
        batch_size = num_tokens_no_spec.shape[0]
        max_seq_len = token_ids_gpu.shape[1]
        max_new_tokens = valid_sampled_token_ids_gpu.shape[1]  # num_spec_tokens + 1
        # Scatter newly sampled tokens into token_ids_gpu.
        offsets = torch.arange(max_new_tokens, device=self.device)
        write_positions = num_tokens_no_spec.unsqueeze(1) + offsets.unsqueeze(0)
        valid_write_mask = offsets.unsqueeze(0) < valid_sampled_tokens_count.unsqueeze(
            1
        )
        in_bounds = write_positions < max_seq_len
        # Only scatter lanes that are valid, non -1, and inside the buffer.
        scatter_mask = (
            valid_write_mask & (valid_sampled_token_ids_gpu != -1) & in_bounds
        )
        # Clamp keeps gather/scatter in bounds; masked lanes rewrite the
        # existing value, so clamped positions are unchanged.
        write_positions_long = write_positions.clamp(max=max_seq_len - 1).long()
        existing_values = token_ids_gpu.gather(1, write_positions_long)
        tokens_cast = valid_sampled_token_ids_gpu.to(token_ids_gpu.dtype)
        tokens_to_scatter = torch.where(
            scatter_mask,
            tokens_cast,
            existing_values,
        )
        token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter)
        # Temporary lengths including the just-scattered tokens.
        num_tokens_tmp = num_tokens_no_spec + valid_sampled_tokens_count
        # Compute validity masks.
        sampled_flags = valid_sampled_tokens_count > 0
        valid_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)
        with set_forward_context(None, self.vllm_config):
            combined_mask = sampled_flags & valid_mask & (num_tokens_tmp >= self.min_n)
            with record_function_or_nullcontext("ngram_proposer_gpu: kernel"):
                draft_tokens, num_valid_draft_tokens = self.kernel(
                    num_tokens_tmp,
                    token_ids_gpu,
                    combined_mask,
                )
        return draft_tokens, num_valid_draft_tokens

    def update_token_ids_ngram(
        self,
        sampled_token_ids: torch.Tensor | list[list[int]],
        gpu_input_batch: InputBatch,
        token_ids_gpu: torch.Tensor,
        num_tokens_no_spec: torch.Tensor,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare speculative decoding inputs on device:
        compute next token ids and valid counts, honoring discarded requests
        and rejected tokens, without CPU-GPU sync.
        """
        num_reqs = gpu_input_batch.num_reqs
        if isinstance(sampled_token_ids, list):
            # When disable_padded_drafter_batch=True, sampled_token_ids is
            # an irregular list[list[int]] where sublists may have different
            # lengths (including empty lists for discarded requests).
            # Pad all sublists to the same length with -1 before converting
            # to tensor.
            max_len = max(
                (len(sublist) for sublist in sampled_token_ids),
                default=0,
            )
            # Ensure at least length 1 for tensor creation
            max_len = max(max_len, 1)
            padded_list = [
                sublist + [-1] * (max_len - len(sublist))
                for sublist in sampled_token_ids
            ]
            sampled_token_ids = torch.tensor(
                padded_list, dtype=torch.int32, device=self.device
            )
        assert isinstance(sampled_token_ids, torch.Tensor), (
            "sampled_token_ids should be a torch.Tensor for ngram_gpu"
        )
        # Backup last valid token before speculative tokens.
        backup_indices = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long()
        backup_next_token_ids = torch.gather(
            token_ids_gpu[:num_reqs], dim=1, index=backup_indices.unsqueeze(1)
        ).squeeze(1)
        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        # Invalidate sampled tokens for discarded requests.
        discard_mask_expanded = discard_request_mask[:num_reqs].unsqueeze(1)
        valid_sampled_token_ids_gpu.masked_fill_(discard_mask_expanded, -1)
        # Mask valid tokens within each request.
        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
        )
        # Count valid tokens per request.
        # NOTE(review): the index math below assumes valid tokens form a
        # leading prefix of each row — confirm against the sampler contract.
        valid_sampled_tokens_count = valid_mask.sum(dim=1)
        # Rightmost valid index per row.
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
        # Last valid token from each row; undefined if none.
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
        ).squeeze(1)
        # Use last token if valid; otherwise fallback to backup.
        next_token_ids = torch.where(
            last_valid_indices != -1,
            selected_tokens,
            backup_next_token_ids,
        )
        return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu

    def load_model(self, *args, **kwargs):
        # Delegates to the kernel, which has no weights to load.
        self.kernel.load_model(*args, **kwargs)
def update_scheduler_for_invalid_drafts(
    num_valid_draft_tokens_event: torch.cuda.Event,
    num_valid_draft_tokens_cpu: torch.Tensor,
    scheduler_output: "SchedulerOutput",
    req_id_to_index: dict[str, int],
) -> None:
    """Trim invalid speculative slots using per-request valid draft counts.

    Args:
        num_valid_draft_tokens_event: Event for async D2H completion.
        num_valid_draft_tokens_cpu: CPU buffer of valid draft counts.
        scheduler_output: Scheduler metadata to update in-place.
        req_id_to_index: Request-id to batch-index mapping.
    """
    spec_tokens_map = scheduler_output.scheduled_spec_decode_tokens
    # Make sure the async D2H copy of the counts has completed.
    num_valid_draft_tokens_event.synchronize()
    for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
        batch_idx = req_id_to_index.get(req_id)
        if batch_idx is None:
            continue
        drafts = spec_tokens_map.get(req_id)
        if drafts is None:
            continue
        num_scheduled = len(drafts)
        # Clamp the reported count into [0, num_scheduled].
        num_valid = int(num_valid_draft_tokens_cpu[batch_idx].item())
        num_valid = min(max(num_valid, 0), num_scheduled)
        trimmed = num_scheduled - num_valid
        scheduler_output.total_num_scheduled_tokens -= trimmed
        scheduler_output.num_scheduled_tokens[req_id] -= trimmed
        if num_valid:
            spec_tokens_map[req_id] = drafts[:num_valid]
        else:
            # No valid drafts left: drop the entry entirely.
            spec_tokens_map.pop(req_id, None)
def update_ngram_gpu_tensors_incremental(
    input_batch: InputBatch,
    token_ids_gpu_tensor: torch.Tensor,
    num_tokens_no_spec_gpu: torch.Tensor,
    new_reqs: list[CachedRequestState],
    device: torch.device,
    _pinned_idx_buf: torch.Tensor,
    _pinned_val_buf: torch.Tensor,
) -> None:
    """Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu
    for the ngram GPU proposer, copying only what changed since last step.
    """
    prev_index_map = input_batch.prev_req_id_to_index
    curr_index_map = input_batch.req_id_to_index
    if not curr_index_map:
        return

    active_rows = list(curr_index_map.values())
    n_active = len(active_rows)
    # Stage indices through the resident pinned buffer (no per-call alloc).
    staged_idx = _pinned_idx_buf[:n_active]
    staged_idx.copy_(torch.as_tensor(active_rows, dtype=torch.long))
    staged_idx_gpu = staged_idx.to(device=device, non_blocking=True)

    def _copy_full_row(row: int) -> None:
        # H2D copy of one request's full token history.
        n_tok = input_batch.num_tokens_no_spec[row]
        if n_tok > 0:
            token_ids_gpu_tensor[row, :n_tok].copy_(
                input_batch.token_ids_cpu_tensor[row, :n_tok],
                non_blocking=True,
            )

    def _sync_lengths() -> None:
        # Batch-sync sequence lengths from the CPU source of truth.
        _sync_num_tokens(
            input_batch,
            num_tokens_no_spec_gpu,
            staged_idx,
            staged_idx_gpu,
            n_active,
            device,
            _pinned_val_buf,
        )

    # First run, no previous state: populate every active row from scratch.
    if prev_index_map is None:
        for row in active_rows:
            _copy_full_row(row)
        _sync_lengths()
        return

    fresh_ids = {req.req_id for req in new_reqs}
    # Collect (old_row, new_row) pairs for requests that moved in the batch.
    moves = [
        (prev_index_map[req_id], row)
        for req_id, row in curr_index_map.items()
        if req_id not in fresh_ids
        and prev_index_map.get(req_id) is not None
        and prev_index_map[req_id] != row
    ]
    if moves:
        src_rows, dst_rows = zip(*moves)
        src = torch.tensor(src_rows, dtype=torch.long, device=device)
        dst = torch.tensor(dst_rows, dtype=torch.long, device=device)
        # Clone first so overlapping src/dst rows don't corrupt each other.
        token_ids_gpu_tensor[dst] = token_ids_gpu_tensor[src].clone()
        num_tokens_no_spec_gpu[dst] = num_tokens_no_spec_gpu[src].clone()
    # New/resumed requests get a full H2D row copy.
    for req_state in new_reqs:
        row = curr_index_map.get(req_state.req_id)
        if row is not None:
            _copy_full_row(row)
    # Always batch-sync sequence lengths for ALL active requests.
    _sync_lengths()
def _sync_num_tokens(
    input_batch: InputBatch,
    num_tokens_no_spec_gpu: torch.Tensor,
    active_idx_cpu: torch.Tensor,
    active_idx_gpu: torch.Tensor,
    n_active: int,
    device: torch.device,
    _pinned_val_buf: torch.Tensor,
) -> None:
    """Batch-sync GPU sequence lengths from the CPU source of truth.

    Inputs:
        input_batch: Batch container with CPU length tensor.
        num_tokens_no_spec_gpu: Destination GPU length tensor.
        active_idx_cpu: Active request indices on CPU.
        active_idx_gpu: Active request indices on GPU.
        n_active: Number of active requests.
        device: Target CUDA device.
        _pinned_val_buf: Resident pinned int32 staging buffer.

    Outputs:
        None (updates num_tokens_no_spec_gpu in-place).
    """
    # Gather the active lengths into the resident pinned staging buffer,
    # then scatter them to the destination rows in a single async H2D op.
    staged = _pinned_val_buf[:n_active]
    staged.copy_(
        input_batch.num_tokens_no_spec_cpu_tensor.index_select(0, active_idx_cpu)
    )
    num_tokens_no_spec_gpu.index_copy_(
        0,
        active_idx_gpu,
        staged.to(device=device, non_blocking=True),
    )
def copy_num_valid_draft_tokens(
    num_valid_draft_tokens_cpu: torch.Tensor,
    num_valid_draft_tokens_copy_stream: torch.cuda.Stream,
    num_valid_draft_tokens_event: torch.cuda.Event,
    num_valid_draft_tokens: torch.Tensor | None,
    batch_size: int,
) -> None:
    """
    Kick off an async D2H copy of per-request valid draft counts on a side
    stream; completion is signalled via num_valid_draft_tokens_event.
    """
    if num_valid_draft_tokens is None:
        return
    n = min(batch_size, num_valid_draft_tokens.shape[0])
    if n <= 0:
        return
    compute_stream = torch.cuda.current_stream()
    with torch.cuda.stream(num_valid_draft_tokens_copy_stream):
        # Order the side-stream copy after all work queued so far on the
        # compute stream, then record the event for later synchronize().
        num_valid_draft_tokens_copy_stream.wait_stream(compute_stream)
        num_valid_draft_tokens_cpu[:n].copy_(
            num_valid_draft_tokens[:n], non_blocking=True
        )
        num_valid_draft_tokens_event.record()

View File

@@ -127,7 +127,13 @@ class InputBatch:
# allocation if max_model_len is big.
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec_cpu_tensor = torch.zeros(
(max_num_reqs,),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy()
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs,),

View File

@@ -10,7 +10,7 @@ from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
from dataclasses import dataclass, replace
from functools import reduce
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
@@ -164,6 +164,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
from vllm.v1.spec_decode.medusa import MedusaProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer_gpu import (
NgramProposerGPU,
copy_num_valid_draft_tokens,
update_ngram_gpu_tensors_incremental,
update_scheduler_for_invalid_drafts,
)
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
@@ -424,7 +430,7 @@ class GPUModelRunner(
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches
# TODO: Support overlapping micro-batches
# https://github.com/vllm-project/vllm/issues/18019
self.broadcast_pp_output = (
self.parallel_config.distributed_executor_backend == "external_launcher"
@@ -493,6 +499,7 @@ class GPUModelRunner(
if self.speculative_config and get_pp_group().is_last_rank:
self.drafter: (
NgramProposer # noqa: F823
| NgramProposerGPU
| SuffixDecodingProposer
| EagleProposer
| DraftModelProposer
@@ -509,6 +516,23 @@ class GPUModelRunner(
device=self.device,
runner=self,
)
elif self.speculative_config.use_ngram_gpu():
self.drafter = NgramProposerGPU(self.vllm_config, self.device, self)
self.num_tokens_no_spec_gpu = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
self.token_ids_gpu_tensor = torch.zeros(
self.max_num_reqs,
self.max_model_len,
dtype=torch.int32,
device=device,
)
self._ngram_pinned_idx_buf = torch.zeros(
self.max_num_reqs, dtype=torch.long, pin_memory=True
)
self._ngram_pinned_val_buf = torch.zeros(
self.max_num_reqs, dtype=torch.int32, pin_memory=True
)
elif self.speculative_config.method == "suffix":
self.drafter = SuffixDecodingProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
@@ -564,7 +588,7 @@ class GPUModelRunner(
)
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
# We need to use the encoder length for encoder-decoer
# We need to use the encoder length for encoder-decoder
# because of KV cache for cross-attention.
max_model_len=max(self.max_model_len, self.max_encoder_len),
max_num_batched_tokens=self.max_num_tokens,
@@ -721,6 +745,21 @@ class GPUModelRunner(
# Cached outputs.
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
# N-gram GPU path: async D2H buffer/event for per-request valid draft counts.
self._num_valid_draft_tokens: torch.Tensor | None = None
self._num_valid_draft_tokens_cpu: torch.Tensor | None = None
self._num_valid_draft_tokens_event: torch.cuda.Event | None = None
self._num_valid_draft_tokens_copy_stream: torch.cuda.Stream | None = None
if (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
):
self._num_valid_draft_tokens_cpu = torch.empty(
self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory
)
self._num_valid_draft_tokens_event = torch.cuda.Event()
self._num_valid_draft_tokens_copy_stream = torch.cuda.Stream()
self._draft_token_req_ids: list[str] | None = None
self.transfer_event = torch.Event()
self.sampled_token_ids_pinned_cpu = torch.empty(
@@ -992,6 +1031,13 @@ class GPUModelRunner(
for req_id in unscheduled_req_ids:
self.input_batch.remove_request(req_id)
is_ngram_gpu = (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
)
if is_ngram_gpu:
ngram_gpu_new_reqs: list[CachedRequestState] = []
reqs_to_add: list[CachedRequestState] = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
@@ -1054,12 +1100,31 @@ class GPUModelRunner(
self._init_xdrope_positions(req_state)
reqs_to_add.append(req_state)
# Track new requests for ngram_gpu full tensor copy
if is_ngram_gpu:
ngram_gpu_new_reqs.append(req_state)
# Update the states of the running/resumed requests.
is_last_rank = get_pp_group().is_last_rank
req_data = scheduler_output.scheduled_cached_reqs
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
# Save scheduler-allocated spec lengths before trimming so
# prev_num_draft_len keeps the optimistic count for rejection correction.
original_num_spec_per_req: dict[str, int] = {}
if (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
):
for req_id, toks in scheduled_spec_tokens.items():
original_num_spec_per_req[req_id] = len(toks)
update_scheduler_for_invalid_drafts(
self._num_valid_draft_tokens_event,
self._num_valid_draft_tokens_cpu,
scheduler_output,
self.input_batch.req_id_to_index,
)
# Wait until valid_sampled_tokens_count is copied to cpu,
# then use it to update actual num_computed_tokens of each request.
valid_sampled_token_count = self._get_valid_sampled_token_count()
@@ -1076,13 +1141,13 @@ class GPUModelRunner(
# prev_num_draft_len is used in async scheduling mode with
# spec decode. it indicates if need to update num_computed_tokens
# of the request. for example:
# fist step: num_computed_tokens = 0, spec_tokens = [],
# first step: num_computed_tokens = 0, spec_tokens = [],
# prev_num_draft_len = 0.
# second step: num_computed_tokens = 100(prompt length),
# spec_tokens = [a,b], prev_num_draft_len = 0.
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
# prev_num_draft_len = 2.
# num_computed_tokens in first step and second step does't contain
# num_computed_tokens in first step and second step doesn't contain
# the spec tokens length, but in third step it contains the
# spec tokens length. we only need to update num_computed_tokens
# when prev_num_draft_len > 0.
@@ -1096,6 +1161,9 @@ class GPUModelRunner(
num_computed_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted)
if is_ngram_gpu and num_accepted > 0 and req_index is not None:
self.input_batch.num_tokens_no_spec[req_index] += num_accepted
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
@@ -1156,6 +1224,9 @@ class GPUModelRunner(
req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
reqs_to_add.append(req_state)
# Track resumed requests for ngram_gpu full tensor copy
if is_ngram_gpu:
ngram_gpu_new_reqs.append(req_state)
continue
# Update the persistent batch.
@@ -1176,6 +1247,11 @@ class GPUModelRunner(
# Add spec_token_ids to token_ids_cpu.
self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
# Restore scheduler-side draft count after ngram trimming.
if original_num_spec_per_req:
orig = original_num_spec_per_req.get(req_id, 0)
if orig != req_state.prev_num_draft_len:
req_state.prev_num_draft_len = orig
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
@@ -1190,6 +1266,18 @@ class GPUModelRunner(
# Refresh batch metadata with any pending updates.
self.input_batch.refresh_metadata()
# Incrementally update ngram_gpu tensors after batch is stable
if is_ngram_gpu:
update_ngram_gpu_tensors_incremental(
self.input_batch,
self.token_ids_gpu_tensor,
self.num_tokens_no_spec_gpu,
ngram_gpu_new_reqs,
self.device,
_pinned_idx_buf=self._ngram_pinned_idx_buf,
_pinned_val_buf=self._ngram_pinned_val_buf,
)
def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
) -> None:
@@ -3412,6 +3500,23 @@ class GPUModelRunner(
else:
logger.error("RoutedExpertsCapturer not initialized.")
# If ngram_gpu is used, we need to copy the scheduler_output so that
# local modifications do not affect the scheduler_output held by the engine core process.
# Using `replace` is much faster than a deepcopy.
if (
self.speculative_config is not None
and self.speculative_config.use_ngram_gpu()
):
num_scheduled_tokens_copy = scheduler_output.num_scheduled_tokens.copy()
spec_decode_tokens_copy = (
scheduler_output.scheduled_spec_decode_tokens.copy()
)
scheduler_output = replace(
scheduler_output,
num_scheduled_tokens=num_scheduled_tokens_copy,
scheduled_spec_decode_tokens=spec_decode_tokens_copy,
)
if scheduler_output.preempted_req_ids and has_kv_transfer_group():
get_kv_transfer_group().handle_preemptions(
scheduler_output.preempted_req_ids
@@ -3825,6 +3930,32 @@ class GPUModelRunner(
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
self._draft_token_ids = torch.zeros(
1, device=self.device, dtype=torch.int32
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
elif (
spec_config.use_ngram_gpu()
and not spec_config.disable_padded_drafter_batch
):
assert isinstance(self.drafter, NgramProposerGPU)
sampled_token_ids = sampler_output.sampled_token_ids
if input_fits_in_drafter:
propose_draft_token_ids(sampled_token_ids)
elif self.valid_sampled_token_count_event is not None:
assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count, _ = (
self.drafter.update_token_ids_ngram(
sampled_token_ids,
self.input_batch,
self.token_ids_gpu_tensor,
self.num_tokens_no_spec_gpu,
self.discard_request_mask.gpu,
)
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
# Since we couldn't run the drafter,
# just use zeros for the draft tokens.
self._draft_token_ids = torch.zeros(
@@ -4064,6 +4195,52 @@ class GPUModelRunner(
self.input_batch.token_ids_cpu,
slot_mappings=slot_mappings,
)
if isinstance(self.drafter, NgramProposer):
assert isinstance(sampled_token_ids, list), (
"sampled_token_ids should be a python list when ngram is used."
)
draft_token_ids = self.drafter.propose(
sampled_token_ids,
self.input_batch.num_tokens_no_spec,
self.input_batch.token_ids_cpu,
)
elif spec_config.use_ngram_gpu():
assert isinstance(self.drafter, NgramProposerGPU)
(
next_token_ids,
valid_sampled_tokens_count,
valid_sampled_token_ids_gpu,
) = self.drafter.update_token_ids_ngram(
sampled_token_ids,
self.input_batch,
self.token_ids_gpu_tensor,
self.num_tokens_no_spec_gpu,
self.discard_request_mask.gpu,
)
self._copy_valid_sampled_token_count(
next_token_ids, valid_sampled_tokens_count
)
batch_size = next_token_ids.shape[0]
draft_token_ids, num_valid_draft_tokens = self.drafter.propose(
self.num_tokens_no_spec_gpu[:batch_size],
self.token_ids_gpu_tensor[:batch_size],
valid_sampled_token_ids_gpu,
valid_sampled_tokens_count,
)
# Cache valid draft counts for scheduler-side trimming.
self._num_valid_draft_tokens = num_valid_draft_tokens
# Async D2H copy on a dedicated stream.
copy_num_valid_draft_tokens(
self._num_valid_draft_tokens_cpu,
self._num_valid_draft_tokens_copy_stream,
self._num_valid_draft_tokens_event,
self._num_valid_draft_tokens,
self.input_batch.num_reqs,
)
elif spec_config.method == "suffix":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, SuffixDecodingProposer)