[Core] NGram GPU Implementation compatible with Async Scheduler (#29184)
This commit is contained in:
@@ -98,7 +98,7 @@ def test_without_spec_decoding(
|
||||
|
||||
@single_gpu_only
|
||||
@large_gpu_mark(min_gb=16)
|
||||
def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
|
||||
def test_with_eagle3_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test consistency and acceptance rates with some different combos of
|
||||
preemption, executor, async scheduling, prefill chunking,
|
||||
spec decoding model length.
|
||||
@@ -154,6 +154,42 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch)
|
||||
)
|
||||
|
||||
|
||||
def test_with_ngram_gpu_spec_decoding(monkeypatch: pytest.MonkeyPatch):
    """Exercise ngram_gpu speculative decoding across scheduler/executor combos.

    Specifically validates ngram_gpu behavior with various:
    - Number of speculative tokens (2-6)
    - Prompt lookup window sizes (min/max)
    - Async scheduling enabled (as in production)
    - Different executors and chunking settings
    """
    # Speculative-decoding settings shared by every spec-enabled scenario,
    # using a slightly larger speculation window.
    spec_cfg = {
        "method": "ngram_gpu",
        "num_speculative_tokens": 3,
        "prompt_lookup_max": 3,
        "prompt_lookup_min": 2,
    }

    # Each tuple: (test_preemption, executor, async_scheduling,
    #              spec_config, test_prefill_chunking)
    scenarios = [
        (False, "mp", False, None, False),  # baseline without spec decoding
        (False, "mp", False, spec_cfg, False),
        (True, "mp", False, spec_cfg, True),
        (False, "mp", True, spec_cfg, False),
        (True, "mp", True, spec_cfg, False),
        (True, "uni", True, spec_cfg, False),
        (True, "mp", True, spec_cfg, True),
    ]

    # MODEL (Qwen) is light-weight and ngram_gpu needs no separate draft
    # model, so it is a good fit for these runs.
    run_tests(monkeypatch, MODEL, scenarios, [{}])
|
||||
|
||||
|
||||
@dynamo_config.patch(cache_size_limit=16)
|
||||
def run_tests(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
@@ -282,11 +318,12 @@ def run_test(
|
||||
else dict(gpu_memory_utilization=0.9)
|
||||
)
|
||||
spec_mml = (spec_config or {}).get("max_model_len")
|
||||
spec_method = (spec_config or {}).get("method", "none")
|
||||
test_config = (
|
||||
f"executor={executor}, preemption={test_preemption}, "
|
||||
f"async_sched={async_scheduling}, "
|
||||
f"chunk_prefill={test_prefill_chunking}, "
|
||||
f"spec_decoding={spec_decoding}, spec_mml={spec_mml}"
|
||||
f"spec_decoding={spec_decoding}, spec_method={spec_method}, spec_mml={spec_mml}"
|
||||
)
|
||||
print("-" * 80)
|
||||
print(f"---- TESTING {test_str}: {test_config}")
|
||||
@@ -294,7 +331,7 @@ def run_test(
|
||||
|
||||
with VllmRunner(
|
||||
model,
|
||||
max_model_len=512,
|
||||
max_model_len=4096,
|
||||
enable_chunked_prefill=test_prefill_chunking,
|
||||
# Force prefill chunking
|
||||
max_num_batched_tokens=48 if test_prefill_chunking else None,
|
||||
|
||||
@@ -183,6 +183,34 @@ def test_ngram_and_suffix_correctness(
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("async_scheduling", [True], ids=["async"])
@single_gpu_only
@large_gpu_mark(min_gb=20)
def test_ngram_gpu_default_with_async_scheduling(
    async_scheduling: bool,
):
    """
    GSM8K-accuracy check for ngram_gpu speculative decoding (k=2,
    lookup window 2-3) with async scheduling.
    Uses Qwen/Qwen3-8B (ref GSM8K accuracy: 87%-92%).
    """
    llm = LLM(
        model="Qwen/Qwen3-8B",
        speculative_config={
            "method": "ngram_gpu",
            "prompt_lookup_max": 3,
            "prompt_lookup_min": 2,
            "num_speculative_tokens": 2,
        },
        max_model_len=4096,
        async_scheduling=async_scheduling,
    )
    evaluate_llm_for_gsm8k(llm, expected_accuracy_threshold=0.8)
    # Release the engine and tear down the distributed environment so later
    # tests start from a clean GPU.
    del llm
    cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@single_gpu_only
|
||||
@large_gpu_mark(min_gb=20)
|
||||
def test_suffix_decoding_acceptance(
|
||||
|
||||
@@ -907,6 +907,13 @@ class VllmBackend:
|
||||
# Honors opt-outs such as CompilationMode.NONE or VLLM_DISABLE_COMPILE_CACHE.
|
||||
disable_cache = not is_compile_cache_enabled(self.inductor_config)
|
||||
|
||||
# TODO(patchy): ngram gpu kernel will cause vllm torch compile cache errors.
|
||||
is_ngram_gpu_enabled = (
|
||||
vllm_config.speculative_config is not None
|
||||
and vllm_config.speculative_config.use_ngram_gpu()
|
||||
)
|
||||
disable_cache = disable_cache or is_ngram_gpu_enabled
|
||||
|
||||
if disable_cache:
|
||||
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
|
||||
else:
|
||||
|
||||
@@ -47,6 +47,7 @@ MTPModelTypes = Literal[
|
||||
"step3p5_mtp",
|
||||
]
|
||||
EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
|
||||
NgramGPUTypes = Literal["ngram_gpu"]
|
||||
SpeculativeMethod = Literal[
|
||||
"ngram",
|
||||
"medusa",
|
||||
@@ -54,6 +55,7 @@ SpeculativeMethod = Literal[
|
||||
"draft_model",
|
||||
"suffix",
|
||||
EagleModelTypes,
|
||||
NgramGPUTypes,
|
||||
]
|
||||
|
||||
|
||||
@@ -364,6 +366,8 @@ class SpeculativeConfig:
|
||||
self.quantization = self.target_model_config.quantization
|
||||
elif self.method in ("ngram", "[ngram]"):
|
||||
self.model = "ngram"
|
||||
elif self.method == "ngram_gpu":
|
||||
self.model = "ngram_gpu"
|
||||
elif self.method == "suffix":
|
||||
self.model = "suffix"
|
||||
elif self.method == "extract_hidden_states":
|
||||
@@ -374,8 +378,9 @@ class SpeculativeConfig:
|
||||
)
|
||||
|
||||
if self.method in ("ngram", "[ngram]"):
|
||||
# Unified to "ngram" internally
|
||||
self.method = "ngram"
|
||||
|
||||
if self.method in ("ngram", "ngram_gpu"):
|
||||
# Set default values if not provided
|
||||
if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
|
||||
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
|
||||
@@ -832,6 +837,9 @@ class SpeculativeConfig:
|
||||
def uses_extract_hidden_states(self) -> bool:
|
||||
return self.method == "extract_hidden_states"
|
||||
|
||||
def use_ngram_gpu(self) -> bool:
    """Return True when the configured drafting method is "ngram_gpu"."""
    configured_method = self.method
    return configured_method == "ngram_gpu"
||||
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
model = (
|
||||
|
||||
@@ -41,7 +41,7 @@ from .offload import OffloadConfig
|
||||
from .parallel import ParallelConfig
|
||||
from .profiler import ProfilerConfig
|
||||
from .scheduler import SchedulerConfig
|
||||
from .speculative import EagleModelTypes, SpeculativeConfig
|
||||
from .speculative import EagleModelTypes, NgramGPUTypes, SpeculativeConfig
|
||||
from .structured_outputs import StructuredOutputsConfig
|
||||
from .utils import SupportsHash, config, replace
|
||||
from .weight_transfer import WeightTransferConfig
|
||||
@@ -696,11 +696,13 @@ class VllmConfig:
|
||||
if self.speculative_config is not None:
|
||||
if (
|
||||
self.speculative_config.method not in get_args(EagleModelTypes)
|
||||
and self.speculative_config.method not in get_args(NgramGPUTypes)
|
||||
and self.speculative_config.method != "draft_model"
|
||||
):
|
||||
raise ValueError(
|
||||
"Currently, async scheduling is only supported "
|
||||
"with EAGLE/MTP/Draft Model kind of speculative decoding."
|
||||
"with EAGLE/MTP/Draft Model/NGram GPU kind of "
|
||||
"speculative decoding"
|
||||
)
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
raise ValueError(
|
||||
@@ -718,6 +720,7 @@ class VllmConfig:
|
||||
if (
|
||||
self.speculative_config is not None
|
||||
and self.speculative_config.method not in get_args(EagleModelTypes)
|
||||
and self.speculative_config.method not in get_args(NgramGPUTypes)
|
||||
):
|
||||
logger.warning_once(
|
||||
"Async scheduling not supported with %s-based "
|
||||
|
||||
@@ -385,6 +385,7 @@ class Hermes2ProToolParser(ToolParser):
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments"
|
||||
)
|
||||
assert current_tool_call is not None
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
@@ -489,6 +490,7 @@ class Hermes2ProToolParser(ToolParser):
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
assert isinstance(current_tool_call, dict)
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
|
||||
else:
|
||||
|
||||
660
vllm/v1/spec_decode/ngram_proposer_gpu.py
Normal file
660
vllm/v1/spec_decode/ngram_proposer_gpu.py
Normal file
@@ -0,0 +1,660 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
GPU-accelerated N-gram proposer using fully async PyTorch tensor operations.
|
||||
|
||||
This version uses a fully vectorized approach with unfold and argmax for
|
||||
finding the first match across all sequences in parallel.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
CompilationMode,
|
||||
CUDAGraphMode,
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.utils import record_function_or_nullcontext
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
|
||||
@support_torch_compile()
class NgramGPUKernel(nn.Module):
    """GPU-accelerated N-gram proposer using fully async tensor operations.

    All work happens in vectorized tensor ops (unfold + argmax) so the whole
    batch is matched in parallel without any CPU-GPU synchronization. The
    class is compiled via @support_torch_compile, so the statement order and
    allocation pattern below are deliberate — see the NOTE in forward().
    """

    def __init__(
        self, vllm_config: VllmConfig, prefix: str = "", device: torch.device = "cuda"
    ):
        # prefix is accepted for compatibility with the standard model
        # constructor signature; it is not used by this kernel.
        super().__init__()

        # The speculative config must fully specify the lookup window.
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None

        # min/max n-gram length to match, and k draft tokens to emit.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        self.k = vllm_config.speculative_config.num_speculative_tokens
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.device = device

    def _find_first_and_extract_all_n_parallel(
        self,
        token_ids: torch.Tensor,
        seq_lengths: torch.Tensor,
        min_ngram_len: int,
        max_ngram_len: int,
        num_draft_tokens: int,
    ) -> torch.Tensor:
        """
        Find suffix n-gram matches and extract following tokens.
        Searches for the earliest prior occurrence of the trailing n-gram,
        tries multiple lengths, and picks the longest valid match.

        Args:
            token_ids: Token IDs for each sequence [batch_size, max_seq_len]
            seq_lengths: Actual length of each sequence (excluding padding)
            min_ngram_len: Minimum n-gram size to search for (e.g., 2)
            max_ngram_len: Maximum n-gram size to search for (e.g., 5)
            num_draft_tokens: Number of tokens to extract after match (k)

        Returns:
            Draft token predictions [batch_size, num_draft_tokens];
            -1 means invalid/no match.
        """
        batch_size = token_ids.shape[0]
        max_seq_len = token_ids.shape[1]
        device = token_ids.device
        num_ngram_sizes = max_ngram_len - min_ngram_len + 1

        # All n-gram sizes to try.
        ngram_lengths = torch.arange(min_ngram_len, max_ngram_len + 1, device=device)
        batch_indices = torch.arange(batch_size, device=device)

        # Earliest match per (sequence, ngram_len); -1 means no match.
        first_match_positions = torch.full(
            (batch_size, num_ngram_sizes), -1, dtype=torch.long, device=device
        )

        for i, ngram_len in enumerate(range(min_ngram_len, max_ngram_len + 1)):
            # Sliding windows of size ngram_len; unfold is O(1) view.
            search_windows = token_ids.unfold(1, ngram_len, 1)
            num_windows = search_windows.shape[1]

            # Trailing suffix (last ngram_len tokens) for each sequence.
            # clamp(min=0) guards sequences shorter than ngram_len; such
            # rows produce bogus suffixes but are masked out by the caller
            # via combined_mask / the validity checks below.
            suffix_starts = seq_lengths - ngram_len
            suffix_indices = suffix_starts.unsqueeze(1) + torch.arange(
                ngram_len, device=device
            )
            suffix = torch.gather(token_ids, 1, suffix_indices.clamp(min=0))

            # Window matches for each sequence.
            matches = (search_windows == suffix.unsqueeze(1)).all(dim=-1)

            # Match must leave room for at least one draft token.
            max_valid_suffix_start = seq_lengths - ngram_len - 1
            window_positions = torch.arange(num_windows, device=device)
            valid_mask = window_positions <= max_valid_suffix_start.unsqueeze(1)
            final_matches = matches & valid_mask

            # Find earliest match (argmax=0 when empty; verify with has_match).
            first_match_idx = torch.argmax(final_matches.int(), dim=1)
            has_match = final_matches[batch_indices, first_match_idx]

            # Store valid match positions (window index = position).
            first_match_positions[:, i] = torch.where(has_match, first_match_idx, -1)

        # Select the longest n-gram with a match. flip + argmax picks the
        # LAST column (largest n) with a non-negative position.
        best_ngram_idx = (first_match_positions >= 0).int().flip(dims=[1]).argmax(dim=1)
        best_ngram_idx = num_ngram_sizes - 1 - best_ngram_idx  # Flip back

        # Match position for the best n-gram.
        best_match_pos = first_match_positions[batch_indices, best_ngram_idx]

        # Avoid data-dependent branching.
        has_any_match = best_match_pos >= 0

        # Length of the best matching n-gram.
        best_ngram_lengths = ngram_lengths[best_ngram_idx]

        # Start position right after the matched suffix.
        draft_start = torch.where(
            has_any_match,
            best_match_pos + best_ngram_lengths,
            torch.zeros_like(best_match_pos),
        )
        tokens_available = seq_lengths - draft_start

        # Gather indices for draft tokens.
        draft_indices = draft_start.unsqueeze(1) + torch.arange(
            num_draft_tokens, device=device
        )
        draft_indices = draft_indices.clamp(min=0, max=max_seq_len - 1)

        # Extract draft tokens; gather always runs (no data-dependent skip).
        draft_tokens = torch.gather(token_ids, 1, draft_indices)

        # Mask positions beyond available tokens.
        position_indices = torch.arange(num_draft_tokens, device=device).unsqueeze(0)
        valid_positions = position_indices < tokens_available.unsqueeze(1)

        draft_tokens = torch.where(
            valid_positions,
            draft_tokens,
            torch.full_like(draft_tokens, -1),
        )

        # If no match, mask all positions.
        draft_tokens = torch.where(
            has_any_match.unsqueeze(1),
            draft_tokens,
            torch.full_like(draft_tokens, -1),
        )

        return draft_tokens

    def forward(
        self,
        num_tokens_no_spec: torch.Tensor,
        token_ids_gpu: torch.Tensor,
        combined_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for N-gram proposal using GPU tensor operations.

        Args:
            num_tokens_no_spec: Number of tokens for each sequence [batch_size]
            token_ids_gpu: Token IDs [batch_size, max_len]
            combined_mask: Whether each sequence is valid for spec decode [batch_size]

        Returns:
            draft_tokens: [batch_size, k] on GPU
            num_valid_draft_tokens: [batch_size] int32 on GPU, count of
                leading valid (non -1) tokens per request.
        """

        device = token_ids_gpu.device

        # Infer batch size to preserve dynamic shape.
        actual_batch_size = token_ids_gpu.shape[0]

        # Allocate in forward so torch.compile can optimize.
        # NOTE(patchy): Do NOT pre-allocate this as a buffer
        # it breaks torch.compile
        draft_tokens = torch.full(
            (actual_batch_size, self.k), -1, dtype=torch.int32, device=device
        )
        # NOTE(review): this tensor is rebound by the torch.where below and
        # its contents are never read — looks like a dead allocation; confirm
        # against the torch.compile constraint above before removing.

        results = self._find_first_and_extract_all_n_parallel(
            token_ids_gpu,
            num_tokens_no_spec,
            min_ngram_len=self.min_n,
            max_ngram_len=self.max_n,
            num_draft_tokens=self.k,
        )

        # Zero out (set to -1) every row excluded by the combined mask.
        draft_tokens = torch.where(combined_mask.unsqueeze(1), results, -1)

        # Count leading contiguous valid (non -1) tokens per request.
        # cum_valid == positions holds exactly while no -1 has appeared yet.
        is_valid = draft_tokens != -1  # [batch, k]
        cum_valid = is_valid.int().cumsum(dim=1)  # [batch, k]
        positions = torch.arange(1, self.k + 1, device=device).unsqueeze(0)
        num_valid_draft_tokens = (cum_valid == positions).int().sum(dim=1)

        return draft_tokens, num_valid_draft_tokens

    def load_model(self, *args, **kwargs):
        """No model to load for N-gram proposer."""
        pass
|
||||
|
||||
|
||||
class NgramProposerGPU:
    """Driver around NgramGPUKernel: builds a dedicated compile config,
    warms the kernel up, and exposes propose()/update helpers that run
    entirely on-device (no CPU-GPU sync on the hot path).
    """

    def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
        # NOTE(review): `runner` is never referenced in this class — presumably
        # kept for signature parity with other proposers; confirm.
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None

        # Dedicated compilation config for the kernel: full VLLM_COMPILE with
        # aggressive inductor tuning, and CUDA graphs disabled.
        compilation_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["none"],
            splitting_ops=[],
            compile_sizes=[],
            inductor_compile_config={
                "enable_auto_functionalized_v2": False,
                "max_autotune": True,
                "aggressive_fusion": True,
                "triton.autotune_pointwise": True,
                "coordinate_descent_tuning": True,
                "use_mixed_mm": False,
            },
            cudagraph_mode=CUDAGraphMode.NONE,
        )
        model_config = vllm_config.model_config
        speculative_config = vllm_config.speculative_config
        scheduler_config = vllm_config.scheduler_config

        # Narrowed VllmConfig carrying only what the kernel needs, with the
        # compilation config above substituted in.
        self.vllm_config = VllmConfig(
            compilation_config=compilation_config,
            model_config=model_config,
            speculative_config=speculative_config,
            scheduler_config=scheduler_config,
        )

        # Mirror of the kernel's lookup-window / draft-length parameters.
        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        self.k = vllm_config.speculative_config.num_speculative_tokens
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.device = device

        self.kernel = NgramGPUKernel(
            vllm_config=self.vllm_config, prefix="ngram_gpu_kernel", device=device
        )
        self.kernel.to(device)
        self.kernel.eval()

        # Warm up (and presumably compile) the kernel at max batch/seq shape.
        self._dummy_run()

    def _dummy_run(self):
        """Run the kernel a few times on synthetic max-shape data so that
        compilation happens at init rather than on the first real step."""
        token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data(
            batch_size=self.max_num_seqs,
            max_seq_len=self.max_model_len,
            pattern_len=self.k,
            device=self.device,
        )

        # Same eligibility mask construction as propose().
        combined_mask = sampled_flags & valid_mask & (num_tokens >= self.min_n)

        for _ in range(3):
            with set_forward_context(None, self.vllm_config):
                _, _ = self.kernel(num_tokens, token_ids, combined_mask)

    def _generate_dummy_data(
        self,
        batch_size: int,
        max_seq_len: int,
        pattern_len: int,
        device: str = "cuda",
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate random test data with n-gram repetitions.

        Args:
            batch_size: Number of sequences in the batch
            max_seq_len: Maximum sequence length
            pattern_len: Length of patterns to inject for matching
            device: Device to place tensors on

        Returns:
            token_ids: [batch_size, max_seq_len] tensor
            num_tokens: [batch_size] tensor
            sampled_flags: [batch_size] bool tensor
            valid_mask: [batch_size] bool tensor
        """
        # All-zero tokens are sufficient for warmup; only shapes matter.
        token_ids = torch.zeros(
            batch_size,
            max_seq_len,
            dtype=torch.int32,
            device=device,
        )

        # Random lengths in [pattern_len, max_seq_len).
        num_tokens = torch.randint(
            pattern_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
        )

        sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device)
        valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device)

        return token_ids, num_tokens, sampled_flags, valid_mask

    def propose(
        self,
        num_tokens_no_spec: torch.Tensor,  # [batch_size]
        token_ids_gpu: torch.Tensor,  # [batch_size, max_len]
        valid_sampled_token_ids_gpu: torch.Tensor,  # [batch_size, num_spec_tokens + 1]
        valid_sampled_tokens_count: torch.Tensor,  # [batch_size]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Propose draft tokens using GPU-accelerated n-gram matching.

        Scatter sampled tokens into `token_ids_gpu`, compute temporary
        updated lengths, then run the kernel.

        Args:
            num_tokens_no_spec: Number of tokens per sequence (read-only)
            token_ids_gpu: Token IDs tensor (modified in-place with new tokens)
            valid_sampled_token_ids_gpu: Newly sampled tokens to scatter
            valid_sampled_tokens_count: Count of valid tokens per sequence

        Returns:
            draft_tokens: Proposed draft token IDs [batch_size, k]
            num_valid_draft_tokens: Count of leading valid draft tokens
                per request [batch_size]
        """
        assert token_ids_gpu.device == self.device
        assert num_tokens_no_spec.device == self.device

        batch_size = num_tokens_no_spec.shape[0]
        max_seq_len = token_ids_gpu.shape[1]
        max_new_tokens = valid_sampled_token_ids_gpu.shape[1]  # num_spec_tokens + 1

        # Scatter newly sampled tokens into token_ids_gpu.
        # A slot is written only if it is within the per-request valid count,
        # holds a real token (!= -1), and lands inside the tensor bounds.
        offsets = torch.arange(max_new_tokens, device=self.device)
        write_positions = num_tokens_no_spec.unsqueeze(1) + offsets.unsqueeze(0)
        valid_write_mask = offsets.unsqueeze(0) < valid_sampled_tokens_count.unsqueeze(
            1
        )
        in_bounds = write_positions < max_seq_len
        scatter_mask = (
            valid_write_mask & (valid_sampled_token_ids_gpu != -1) & in_bounds
        )

        # Clamp so the gather/scatter are always in-bounds; masked-out slots
        # re-write their existing value, making the scatter a no-op there.
        write_positions_long = write_positions.clamp(max=max_seq_len - 1).long()
        existing_values = token_ids_gpu.gather(1, write_positions_long)

        tokens_cast = valid_sampled_token_ids_gpu.to(token_ids_gpu.dtype)
        tokens_to_scatter = torch.where(
            scatter_mask,
            tokens_cast,
            existing_values,
        )
        token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter)

        # Temporary lengths including the freshly appended tokens.
        num_tokens_tmp = num_tokens_no_spec + valid_sampled_tokens_count

        # Compute validity masks.
        sampled_flags = valid_sampled_tokens_count > 0
        valid_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)

        with set_forward_context(None, self.vllm_config):
            combined_mask = sampled_flags & valid_mask & (num_tokens_tmp >= self.min_n)

            with record_function_or_nullcontext("ngram_proposer_gpu: kernel"):
                draft_tokens, num_valid_draft_tokens = self.kernel(
                    num_tokens_tmp,
                    token_ids_gpu,
                    combined_mask,
                )

        return draft_tokens, num_valid_draft_tokens

    def update_token_ids_ngram(
        self,
        sampled_token_ids: torch.Tensor | list[list[int]],
        gpu_input_batch: InputBatch,
        token_ids_gpu: torch.Tensor,
        num_tokens_no_spec: torch.Tensor,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare speculative decoding inputs on device:
        compute next token ids and valid counts, honoring discarded requests
        and rejected tokens, without CPU-GPU sync.

        Returns:
            next_token_ids: last accepted token per request [num_reqs]
            valid_sampled_tokens_count: accepted-token count per request
            valid_sampled_token_ids_gpu: sampled ids with discards set to -1
        """
        num_reqs = gpu_input_batch.num_reqs

        if isinstance(sampled_token_ids, list):
            # When disable_padded_drafter_batch=True, sampled_token_ids is
            # an irregular list[list[int]] where sublists may have different
            # lengths (including empty lists for discarded requests).
            # Pad all sublists to the same length with -1 before converting
            # to tensor.
            max_len = max(
                (len(sublist) for sublist in sampled_token_ids),
                default=0,
            )
            # Ensure at least length 1 for tensor creation
            max_len = max(max_len, 1)
            padded_list = [
                sublist + [-1] * (max_len - len(sublist))
                for sublist in sampled_token_ids
            ]
            sampled_token_ids = torch.tensor(
                padded_list, dtype=torch.int32, device=self.device
            )
        assert isinstance(sampled_token_ids, torch.Tensor), (
            "sampled_token_ids should be a torch.Tensor for ngram_gpu"
        )

        # Backup last valid token before speculative tokens.
        backup_indices = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long()
        backup_next_token_ids = torch.gather(
            token_ids_gpu[:num_reqs], dim=1, index=backup_indices.unsqueeze(1)
        ).squeeze(1)

        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        # Invalidate sampled tokens for discarded requests.
        discard_mask_expanded = discard_request_mask[:num_reqs].unsqueeze(1)
        valid_sampled_token_ids_gpu.masked_fill_(discard_mask_expanded, -1)

        # Mask valid tokens within each request (non-placeholder and
        # inside the vocabulary).
        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
        )

        # Count valid tokens per request.
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Rightmost valid index per row. NOTE(review): this assumes valid
        # tokens form a leading prefix of each row (count-1 == last index);
        # appears to hold for rejection-sampled output — confirm.
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Last valid token from each row; undefined if none.
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
        ).squeeze(1)

        # Use last token if valid; otherwise fallback to backup.
        next_token_ids = torch.where(
            last_valid_indices != -1,
            selected_tokens,
            backup_next_token_ids,
        )

        return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu

    def load_model(self, *args, **kwargs):
        """Delegate to the kernel (which has no weights to load)."""
        self.kernel.load_model(*args, **kwargs)
|
||||
|
||||
|
||||
def update_scheduler_for_invalid_drafts(
    num_valid_draft_tokens_event: torch.cuda.Event,
    num_valid_draft_tokens_cpu: torch.Tensor,
    scheduler_output: "SchedulerOutput",
    req_id_to_index: dict[str, int],
) -> None:
    """Trim speculative-token slots down to the per-request valid draft count.

    Args:
        num_valid_draft_tokens_event: Event marking completion of the async
            D2H copy of the valid-draft counts.
        num_valid_draft_tokens_cpu: CPU buffer holding the per-request counts.
        scheduler_output: Scheduler metadata, mutated in-place.
        req_id_to_index: Maps request ids to batch indices.
    """
    # Block until the device-to-host count transfer has landed.
    num_valid_draft_tokens_event.synchronize()

    spec_tokens_by_req = scheduler_output.scheduled_spec_decode_tokens
    for req_id in scheduler_output.scheduled_cached_reqs.req_ids:
        batch_idx = req_id_to_index.get(req_id)
        if batch_idx is None:
            continue

        draft_ids = spec_tokens_by_req.get(req_id)
        if draft_ids is None:
            continue

        n_scheduled = len(draft_ids)

        # Clamp the reported count into [0, n_scheduled].
        n_valid = int(num_valid_draft_tokens_cpu[batch_idx].item())
        n_valid = min(max(n_valid, 0), n_scheduled)

        # Shrink the scheduled-token accounting by the rejected tail.
        trimmed = n_scheduled - n_valid
        scheduler_output.total_num_scheduled_tokens -= trimmed
        scheduler_output.num_scheduled_tokens[req_id] -= trimmed

        if n_valid:
            spec_tokens_by_req[req_id] = draft_ids[:n_valid]
        else:
            spec_tokens_by_req.pop(req_id, None)
|
||||
|
||||
|
||||
def update_ngram_gpu_tensors_incremental(
    input_batch: InputBatch,
    token_ids_gpu_tensor: torch.Tensor,
    num_tokens_no_spec_gpu: torch.Tensor,
    new_reqs: list[CachedRequestState],
    device: torch.device,
    _pinned_idx_buf: torch.Tensor,
    _pinned_val_buf: torch.Tensor,
) -> None:
    """Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu
    for ngram GPU proposer.

    Three phases: (1) reorder rows whose batch index changed since the last
    step, (2) fully copy token ids for new/resumed requests from CPU, and
    (3) batch-sync per-request lengths via _sync_num_tokens. All H2D copies
    are non_blocking; the CPU side stages through the resident pinned
    buffers passed in (assumed pinned — required for truly async copies).

    Args:
        input_batch: Batch container with CPU-side token/length tensors and
            the previous/current request-id -> index maps.
        token_ids_gpu_tensor: Per-request token ids on GPU (updated in-place).
        num_tokens_no_spec_gpu: Per-request lengths on GPU (updated in-place).
        new_reqs: Requests added (or resumed) this step.
        device: Target CUDA device.
        _pinned_idx_buf: Resident pinned long staging buffer for indices.
        _pinned_val_buf: Resident pinned int32 staging buffer for lengths.
    """
    prev_req_id_to_index = input_batch.prev_req_id_to_index
    curr_req_id_to_index = input_batch.req_id_to_index

    # Nothing scheduled: leave GPU state untouched.
    if not curr_req_id_to_index:
        return

    active_indices = list(curr_req_id_to_index.values())
    n_active = len(active_indices)

    # Use resident pinned buffers to avoid per-call allocation.
    active_idx_cpu = _pinned_idx_buf[:n_active]
    active_idx_cpu.copy_(torch.as_tensor(active_indices, dtype=torch.long))

    active_idx_gpu = active_idx_cpu.to(device=device, non_blocking=True)

    new_req_ids = {req.req_id for req in new_reqs}

    # First run, no previous state: copy everything from CPU and sync lengths.
    if prev_req_id_to_index is None:
        for idx in active_indices:
            num_tokens = input_batch.num_tokens_no_spec[idx]
            if num_tokens > 0:
                token_ids_gpu_tensor[idx, :num_tokens].copy_(
                    input_batch.token_ids_cpu_tensor[idx, :num_tokens],
                    non_blocking=True,
                )

        _sync_num_tokens(
            input_batch,
            num_tokens_no_spec_gpu,
            active_idx_cpu,
            active_idx_gpu,
            n_active,
            device,
            _pinned_val_buf,
        )
        return

    # Detect index changes for reorder (persisted requests only; new
    # requests are handled by the full copy below).
    reorder_src: list[int] = []
    reorder_dst: list[int] = []

    for req_id, curr_idx in curr_req_id_to_index.items():
        if req_id in new_req_ids:
            continue
        prev_idx = prev_req_id_to_index.get(req_id)
        if prev_idx is not None and prev_idx != curr_idx:
            reorder_src.append(prev_idx)
            reorder_dst.append(curr_idx)

    if reorder_src:
        src_tensor = torch.tensor(reorder_src, dtype=torch.long, device=device)
        dst_tensor = torch.tensor(reorder_dst, dtype=torch.long, device=device)

        # Clone before writing: src and dst row sets may overlap, so an
        # in-place gather/scatter without the copy could read clobbered rows.
        temp_token_ids = token_ids_gpu_tensor[src_tensor].clone()
        temp_num_tokens = num_tokens_no_spec_gpu[src_tensor].clone()

        token_ids_gpu_tensor[dst_tensor] = temp_token_ids
        num_tokens_no_spec_gpu[dst_tensor] = temp_num_tokens

    # Full copy for new/resumed requests.
    for req_state in new_reqs:
        new_req_idx = curr_req_id_to_index.get(req_state.req_id)
        if new_req_idx is None:
            continue

        num_tokens = input_batch.num_tokens_no_spec[new_req_idx]
        if num_tokens > 0:
            token_ids_gpu_tensor[new_req_idx, :num_tokens].copy_(
                input_batch.token_ids_cpu_tensor[new_req_idx, :num_tokens],
                non_blocking=True,
            )

    # Always batch-sync sequence lengths from CPU for ALL active requests.
    _sync_num_tokens(
        input_batch,
        num_tokens_no_spec_gpu,
        active_idx_cpu,
        active_idx_gpu,
        n_active,
        device,
        _pinned_val_buf,
    )
|
||||
|
||||
|
||||
def _sync_num_tokens(
    input_batch: InputBatch,
    num_tokens_no_spec_gpu: torch.Tensor,
    active_idx_cpu: torch.Tensor,
    active_idx_gpu: torch.Tensor,
    n_active: int,
    device: torch.device,
    _pinned_val_buf: torch.Tensor,
) -> None:
    """Propagate per-request sequence lengths from CPU batch state to GPU.

    ``input_batch.num_tokens_no_spec_cpu_tensor`` is the source of truth.
    The active entries are gathered into a resident pinned staging buffer,
    then scattered into ``num_tokens_no_spec_gpu`` via a single async H2D
    transfer.

    Args:
        input_batch: Batch container holding the CPU length tensor.
        num_tokens_no_spec_gpu: GPU length tensor, updated in place.
        active_idx_cpu: Indices of active requests (CPU copy).
        active_idx_gpu: Indices of active requests (GPU copy).
        n_active: Number of active requests.
        device: Target CUDA device.
        _pinned_val_buf: Resident pinned int32 staging buffer.

    Returns:
        None; ``num_tokens_no_spec_gpu`` is mutated in place.
    """
    # Gather the active lengths into the pinned staging area so the
    # host-to-device copy below can be truly asynchronous.
    staged = _pinned_val_buf[:n_active]
    staged.copy_(
        input_batch.num_tokens_no_spec_cpu_tensor.index_select(0, active_idx_cpu)
    )

    # Scatter the staged values into the GPU-resident length tensor.
    staged_on_device = staged.to(device=device, non_blocking=True)
    num_tokens_no_spec_gpu.index_copy_(0, active_idx_gpu, staged_on_device)
|
||||
|
||||
|
||||
def copy_num_valid_draft_tokens(
|
||||
num_valid_draft_tokens_cpu: torch.Tensor,
|
||||
num_valid_draft_tokens_copy_stream: torch.cuda.Stream,
|
||||
num_valid_draft_tokens_event: torch.cuda.Event,
|
||||
num_valid_draft_tokens: torch.Tensor | None,
|
||||
batch_size: int,
|
||||
) -> None:
|
||||
"""
|
||||
Async D2H copy of per-request valid draft counts.
|
||||
"""
|
||||
if num_valid_draft_tokens is None:
|
||||
return
|
||||
|
||||
num_reqs_to_copy = min(batch_size, num_valid_draft_tokens.shape[0])
|
||||
if num_reqs_to_copy <= 0:
|
||||
return
|
||||
|
||||
default_stream = torch.cuda.current_stream()
|
||||
with torch.cuda.stream(num_valid_draft_tokens_copy_stream):
|
||||
num_valid_draft_tokens_copy_stream.wait_stream(default_stream)
|
||||
num_valid_draft_tokens_cpu[:num_reqs_to_copy].copy_(
|
||||
num_valid_draft_tokens[:num_reqs_to_copy], non_blocking=True
|
||||
)
|
||||
num_valid_draft_tokens_event.record()
|
||||
@@ -127,7 +127,13 @@ class InputBatch:
|
||||
# allocation if max_model_len is big.
|
||||
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
|
||||
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
|
||||
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_tokens_no_spec_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs,),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.num_tokens_no_spec = self.num_tokens_no_spec_cpu_tensor.numpy()
|
||||
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_computed_tokens_cpu_tensor = torch.zeros(
|
||||
(max_num_reqs,),
|
||||
|
||||
@@ -10,7 +10,7 @@ from collections import defaultdict
|
||||
from collections.abc import Iterable, Iterator, Sequence
|
||||
from contextlib import contextmanager
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, replace
|
||||
from functools import reduce
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast
|
||||
|
||||
@@ -164,6 +164,12 @@ from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesProposer
|
||||
from vllm.v1.spec_decode.medusa import MedusaProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.ngram_proposer_gpu import (
|
||||
NgramProposerGPU,
|
||||
copy_num_valid_draft_tokens,
|
||||
update_ngram_gpu_tensors_incremental,
|
||||
update_scheduler_for_invalid_drafts,
|
||||
)
|
||||
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
|
||||
from vllm.v1.structured_output.utils import apply_grammar_bitmask
|
||||
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
|
||||
@@ -424,7 +430,7 @@ class GPUModelRunner(
|
||||
|
||||
# Broadcast PP output for external_launcher (torchrun)
|
||||
# to make sure we are synced across pp ranks
|
||||
# TODO: Support overlapping mirco-batches
|
||||
# TODO: Support overlapping micro-batches
|
||||
# https://github.com/vllm-project/vllm/issues/18019
|
||||
self.broadcast_pp_output = (
|
||||
self.parallel_config.distributed_executor_backend == "external_launcher"
|
||||
@@ -493,6 +499,7 @@ class GPUModelRunner(
|
||||
if self.speculative_config and get_pp_group().is_last_rank:
|
||||
self.drafter: (
|
||||
NgramProposer # noqa: F823
|
||||
| NgramProposerGPU
|
||||
| SuffixDecodingProposer
|
||||
| EagleProposer
|
||||
| DraftModelProposer
|
||||
@@ -509,6 +516,23 @@ class GPUModelRunner(
|
||||
device=self.device,
|
||||
runner=self,
|
||||
)
|
||||
elif self.speculative_config.use_ngram_gpu():
|
||||
self.drafter = NgramProposerGPU(self.vllm_config, self.device, self)
|
||||
self.num_tokens_no_spec_gpu = torch.zeros(
|
||||
self.max_num_reqs, dtype=torch.int32, device=device
|
||||
)
|
||||
self.token_ids_gpu_tensor = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.max_model_len,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
self._ngram_pinned_idx_buf = torch.zeros(
|
||||
self.max_num_reqs, dtype=torch.long, pin_memory=True
|
||||
)
|
||||
self._ngram_pinned_val_buf = torch.zeros(
|
||||
self.max_num_reqs, dtype=torch.int32, pin_memory=True
|
||||
)
|
||||
elif self.speculative_config.method == "suffix":
|
||||
self.drafter = SuffixDecodingProposer(self.vllm_config)
|
||||
elif self.speculative_config.use_eagle():
|
||||
@@ -564,7 +588,7 @@ class GPUModelRunner(
|
||||
)
|
||||
self.input_batch = InputBatch(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
# We need to use the encoder length for encoder-decoer
|
||||
# We need to use the encoder length for encoder-decoder
|
||||
# because of KV cache for cross-attention.
|
||||
max_model_len=max(self.max_model_len, self.max_encoder_len),
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
@@ -721,6 +745,21 @@ class GPUModelRunner(
|
||||
|
||||
# Cached outputs.
|
||||
self._draft_token_ids: list[list[int]] | torch.Tensor | None = None
|
||||
# N-gram GPU path: async D2H buffer/event for per-request valid draft counts.
|
||||
self._num_valid_draft_tokens: torch.Tensor | None = None
|
||||
self._num_valid_draft_tokens_cpu: torch.Tensor | None = None
|
||||
self._num_valid_draft_tokens_event: torch.cuda.Event | None = None
|
||||
self._num_valid_draft_tokens_copy_stream: torch.cuda.Stream | None = None
|
||||
if (
|
||||
self.speculative_config is not None
|
||||
and self.speculative_config.use_ngram_gpu()
|
||||
):
|
||||
self._num_valid_draft_tokens_cpu = torch.empty(
|
||||
self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory
|
||||
)
|
||||
self._num_valid_draft_tokens_event = torch.cuda.Event()
|
||||
self._num_valid_draft_tokens_copy_stream = torch.cuda.Stream()
|
||||
|
||||
self._draft_token_req_ids: list[str] | None = None
|
||||
self.transfer_event = torch.Event()
|
||||
self.sampled_token_ids_pinned_cpu = torch.empty(
|
||||
@@ -992,6 +1031,13 @@ class GPUModelRunner(
|
||||
for req_id in unscheduled_req_ids:
|
||||
self.input_batch.remove_request(req_id)
|
||||
|
||||
is_ngram_gpu = (
|
||||
self.speculative_config is not None
|
||||
and self.speculative_config.use_ngram_gpu()
|
||||
)
|
||||
if is_ngram_gpu:
|
||||
ngram_gpu_new_reqs: list[CachedRequestState] = []
|
||||
|
||||
reqs_to_add: list[CachedRequestState] = []
|
||||
# Add new requests to the cached states.
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
@@ -1054,12 +1100,31 @@ class GPUModelRunner(
|
||||
self._init_xdrope_positions(req_state)
|
||||
|
||||
reqs_to_add.append(req_state)
|
||||
# Track new requests for ngram_gpu full tensor copy
|
||||
if is_ngram_gpu:
|
||||
ngram_gpu_new_reqs.append(req_state)
|
||||
|
||||
# Update the states of the running/resumed requests.
|
||||
is_last_rank = get_pp_group().is_last_rank
|
||||
req_data = scheduler_output.scheduled_cached_reqs
|
||||
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
|
||||
|
||||
# Save scheduler-allocated spec lengths before trimming so
|
||||
# prev_num_draft_len keeps the optimistic count for rejection correction.
|
||||
original_num_spec_per_req: dict[str, int] = {}
|
||||
if (
|
||||
self.speculative_config is not None
|
||||
and self.speculative_config.use_ngram_gpu()
|
||||
):
|
||||
for req_id, toks in scheduled_spec_tokens.items():
|
||||
original_num_spec_per_req[req_id] = len(toks)
|
||||
update_scheduler_for_invalid_drafts(
|
||||
self._num_valid_draft_tokens_event,
|
||||
self._num_valid_draft_tokens_cpu,
|
||||
scheduler_output,
|
||||
self.input_batch.req_id_to_index,
|
||||
)
|
||||
|
||||
# Wait until valid_sampled_tokens_count is copied to cpu,
|
||||
# then use it to update actual num_computed_tokens of each request.
|
||||
valid_sampled_token_count = self._get_valid_sampled_token_count()
|
||||
@@ -1076,13 +1141,13 @@ class GPUModelRunner(
|
||||
# prev_num_draft_len is used in async scheduling mode with
|
||||
# spec decode. it indicates if need to update num_computed_tokens
|
||||
# of the request. for example:
|
||||
# fist step: num_computed_tokens = 0, spec_tokens = [],
|
||||
# first step: num_computed_tokens = 0, spec_tokens = [],
|
||||
# prev_num_draft_len = 0.
|
||||
# second step: num_computed_tokens = 100(prompt length),
|
||||
# spec_tokens = [a,b], prev_num_draft_len = 0.
|
||||
# third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
|
||||
# prev_num_draft_len = 2.
|
||||
# num_computed_tokens in first step and second step does't contain
|
||||
# num_computed_tokens in first step and second step doesn't contain
|
||||
# the spec tokens length, but in third step it contains the
|
||||
# spec tokens length. we only need to update num_computed_tokens
|
||||
# when prev_num_draft_len > 0.
|
||||
@@ -1096,6 +1161,9 @@ class GPUModelRunner(
|
||||
num_computed_tokens -= num_rejected
|
||||
req_state.output_token_ids.extend([-1] * num_accepted)
|
||||
|
||||
if is_ngram_gpu and num_accepted > 0 and req_index is not None:
|
||||
self.input_batch.num_tokens_no_spec[req_index] += num_accepted
|
||||
|
||||
# Update the cached states.
|
||||
req_state.num_computed_tokens = num_computed_tokens
|
||||
|
||||
@@ -1156,6 +1224,9 @@ class GPUModelRunner(
|
||||
req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
|
||||
|
||||
reqs_to_add.append(req_state)
|
||||
# Track resumed requests for ngram_gpu full tensor copy
|
||||
if is_ngram_gpu:
|
||||
ngram_gpu_new_reqs.append(req_state)
|
||||
continue
|
||||
|
||||
# Update the persistent batch.
|
||||
@@ -1176,6 +1247,11 @@ class GPUModelRunner(
|
||||
|
||||
# Add spec_token_ids to token_ids_cpu.
|
||||
self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
|
||||
# Restore scheduler-side draft count after ngram trimming.
|
||||
if original_num_spec_per_req:
|
||||
orig = original_num_spec_per_req.get(req_id, 0)
|
||||
if orig != req_state.prev_num_draft_len:
|
||||
req_state.prev_num_draft_len = orig
|
||||
|
||||
# Add the new or resumed requests to the persistent batch.
|
||||
# The smaller empty indices are filled first.
|
||||
@@ -1190,6 +1266,18 @@ class GPUModelRunner(
|
||||
# Refresh batch metadata with any pending updates.
|
||||
self.input_batch.refresh_metadata()
|
||||
|
||||
# Incrementally update ngram_gpu tensors after batch is stable
|
||||
if is_ngram_gpu:
|
||||
update_ngram_gpu_tensors_incremental(
|
||||
self.input_batch,
|
||||
self.token_ids_gpu_tensor,
|
||||
self.num_tokens_no_spec_gpu,
|
||||
ngram_gpu_new_reqs,
|
||||
self.device,
|
||||
_pinned_idx_buf=self._ngram_pinned_idx_buf,
|
||||
_pinned_val_buf=self._ngram_pinned_val_buf,
|
||||
)
|
||||
|
||||
def _update_states_after_model_execute(
|
||||
self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
|
||||
) -> None:
|
||||
@@ -3412,6 +3500,23 @@ class GPUModelRunner(
|
||||
else:
|
||||
logger.error("RoutedExpertsCapturer not initialized.")
|
||||
|
||||
# If ngram_gpu is used, we need to copy the scheduler_output to avoid
|
||||
# the modification has influence on the scheduler_output in engine core process.
|
||||
# The replace is much faster than deepcopy.
|
||||
if (
|
||||
self.speculative_config is not None
|
||||
and self.speculative_config.use_ngram_gpu()
|
||||
):
|
||||
num_scheduled_tokens_copy = scheduler_output.num_scheduled_tokens.copy()
|
||||
spec_decode_tokens_copy = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.copy()
|
||||
)
|
||||
scheduler_output = replace(
|
||||
scheduler_output,
|
||||
num_scheduled_tokens=num_scheduled_tokens_copy,
|
||||
scheduled_spec_decode_tokens=spec_decode_tokens_copy,
|
||||
)
|
||||
|
||||
if scheduler_output.preempted_req_ids and has_kv_transfer_group():
|
||||
get_kv_transfer_group().handle_preemptions(
|
||||
scheduler_output.preempted_req_ids
|
||||
@@ -3825,6 +3930,32 @@ class GPUModelRunner(
|
||||
self._copy_valid_sampled_token_count(
|
||||
next_token_ids, valid_sampled_tokens_count
|
||||
)
|
||||
self._draft_token_ids = torch.zeros(
|
||||
1, device=self.device, dtype=torch.int32
|
||||
).expand(len(self.input_batch.req_ids), self.num_spec_tokens)
|
||||
self._copy_draft_token_ids_to_cpu(scheduler_output, zeros_only=True)
|
||||
elif (
|
||||
spec_config.use_ngram_gpu()
|
||||
and not spec_config.disable_padded_drafter_batch
|
||||
):
|
||||
assert isinstance(self.drafter, NgramProposerGPU)
|
||||
sampled_token_ids = sampler_output.sampled_token_ids
|
||||
if input_fits_in_drafter:
|
||||
propose_draft_token_ids(sampled_token_ids)
|
||||
elif self.valid_sampled_token_count_event is not None:
|
||||
assert spec_decode_common_attn_metadata is not None
|
||||
next_token_ids, valid_sampled_tokens_count, _ = (
|
||||
self.drafter.update_token_ids_ngram(
|
||||
sampled_token_ids,
|
||||
self.input_batch,
|
||||
self.token_ids_gpu_tensor,
|
||||
self.num_tokens_no_spec_gpu,
|
||||
self.discard_request_mask.gpu,
|
||||
)
|
||||
)
|
||||
self._copy_valid_sampled_token_count(
|
||||
next_token_ids, valid_sampled_tokens_count
|
||||
)
|
||||
# Since we couldn't run the drafter,
|
||||
# just use zeros for the draft tokens.
|
||||
self._draft_token_ids = torch.zeros(
|
||||
@@ -4064,6 +4195,52 @@ class GPUModelRunner(
|
||||
self.input_batch.token_ids_cpu,
|
||||
slot_mappings=slot_mappings,
|
||||
)
|
||||
if isinstance(self.drafter, NgramProposer):
|
||||
assert isinstance(sampled_token_ids, list), (
|
||||
"sampled_token_ids should be a python list when ngram is used."
|
||||
)
|
||||
draft_token_ids = self.drafter.propose(
|
||||
sampled_token_ids,
|
||||
self.input_batch.num_tokens_no_spec,
|
||||
self.input_batch.token_ids_cpu,
|
||||
)
|
||||
elif spec_config.use_ngram_gpu():
|
||||
assert isinstance(self.drafter, NgramProposerGPU)
|
||||
(
|
||||
next_token_ids,
|
||||
valid_sampled_tokens_count,
|
||||
valid_sampled_token_ids_gpu,
|
||||
) = self.drafter.update_token_ids_ngram(
|
||||
sampled_token_ids,
|
||||
self.input_batch,
|
||||
self.token_ids_gpu_tensor,
|
||||
self.num_tokens_no_spec_gpu,
|
||||
self.discard_request_mask.gpu,
|
||||
)
|
||||
self._copy_valid_sampled_token_count(
|
||||
next_token_ids, valid_sampled_tokens_count
|
||||
)
|
||||
|
||||
batch_size = next_token_ids.shape[0]
|
||||
|
||||
draft_token_ids, num_valid_draft_tokens = self.drafter.propose(
|
||||
self.num_tokens_no_spec_gpu[:batch_size],
|
||||
self.token_ids_gpu_tensor[:batch_size],
|
||||
valid_sampled_token_ids_gpu,
|
||||
valid_sampled_tokens_count,
|
||||
)
|
||||
|
||||
# Cache valid draft counts for scheduler-side trimming.
|
||||
self._num_valid_draft_tokens = num_valid_draft_tokens
|
||||
|
||||
# Async D2H copy on a dedicated stream.
|
||||
copy_num_valid_draft_tokens(
|
||||
self._num_valid_draft_tokens_cpu,
|
||||
self._num_valid_draft_tokens_copy_stream,
|
||||
self._num_valid_draft_tokens_event,
|
||||
self._num_valid_draft_tokens,
|
||||
self.input_batch.num_reqs,
|
||||
)
|
||||
elif spec_config.method == "suffix":
|
||||
assert isinstance(sampled_token_ids, list)
|
||||
assert isinstance(self.drafter, SuffixDecodingProposer)
|
||||
|
||||
Reference in New Issue
Block a user