# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
GPU-accelerated N-gram proposer using fully async PyTorch tensor operations.

This version uses a fully vectorized approach with unfold and argmax for
finding the first match across all sequences in parallel.
"""

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    CUDAGraphMode,
    VllmConfig,
)
from vllm.forward_context import set_forward_context
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

@support_torch_compile()
class NgramGPUKernel(nn.Module):
    """GPU-accelerated N-gram proposer using fully async tensor operations."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        device: torch.device | str = "cuda",
    ):
        super().__init__()

        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None

        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        self.k = vllm_config.speculative_config.num_speculative_tokens
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.device = device

    def _find_first_and_extract_all_n_parallel(
        self,
        token_ids: torch.Tensor,
        seq_lengths: torch.Tensor,
        min_ngram_len: int,
        max_ngram_len: int,
        num_draft_tokens: int,
    ) -> torch.Tensor:
        """
        Find suffix n-gram matches and extract the tokens that follow them.

        Searches for the earliest prior occurrence of the trailing n-gram,
        tries multiple n-gram lengths, and picks the longest valid match.

        Args:
            token_ids: Token IDs for each sequence
            seq_lengths: Actual length of each sequence (excluding padding)
            min_ngram_len: Minimum n-gram size to search for (e.g., 2)
            max_ngram_len: Maximum n-gram size to search for (e.g., 5)
            num_draft_tokens: Number of tokens to extract after a match (k)

        Returns:
            Draft token predictions; -1 means invalid/no match.
        """
        batch_size = token_ids.shape[0]
        max_seq_len = token_ids.shape[1]
        device = token_ids.device
        num_ngram_sizes = max_ngram_len - min_ngram_len + 1

        # All n-gram sizes to try.
        ngram_lengths = torch.arange(min_ngram_len, max_ngram_len + 1, device=device)
        batch_indices = torch.arange(batch_size, device=device)

        # Earliest match per (sequence, ngram_len); -1 means no match.
        first_match_positions = torch.full(
            (batch_size, num_ngram_sizes), -1, dtype=torch.long, device=device
        )

        for i, ngram_len in enumerate(range(min_ngram_len, max_ngram_len + 1)):
            # Sliding windows of size ngram_len; unfold is an O(1) view.
            search_windows = token_ids.unfold(1, ngram_len, 1)
            num_windows = search_windows.shape[1]
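            # Example (illustrative): with tokens [5, 6, 7, 5, 6] and
            # ngram_len=2, the windows are [5, 6], [6, 7], [7, 5], [5, 6],
            # each compared in parallel against the trailing suffix [5, 6].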

            # Trailing suffix (last ngram_len tokens) for each sequence.
            suffix_starts = seq_lengths - ngram_len
            suffix_indices = suffix_starts.unsqueeze(1) + torch.arange(
                ngram_len, device=device
            )
            suffix = torch.gather(token_ids, 1, suffix_indices.clamp(min=0))

            # Window matches for each sequence.
            matches = (search_windows == suffix.unsqueeze(1)).all(dim=-1)

            # A match must leave room for at least one draft token.
            max_valid_suffix_start = seq_lengths - ngram_len - 1
            window_positions = torch.arange(num_windows, device=device)
            valid_mask = window_positions <= max_valid_suffix_start.unsqueeze(1)
            final_matches = matches & valid_mask

            # Find the earliest match (argmax is 0 when there is no match,
            # so verify with has_match).
            first_match_idx = torch.argmax(final_matches.int(), dim=1)
            has_match = final_matches[batch_indices, first_match_idx]

            # Store valid match positions (window index == position).
            first_match_positions[:, i] = torch.where(has_match, first_match_idx, -1)

        # Select the longest n-gram size that has a match.
        best_ngram_idx = (first_match_positions >= 0).int().flip(dims=[1]).argmax(dim=1)
        best_ngram_idx = num_ngram_sizes - 1 - best_ngram_idx  # Undo the flip
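        # Example (illustrative): with per-size match flags [True, True, False]
        # for n-gram sizes [2, 3, 4], the flipped flags are [False, True, True],
        # argmax gives 1, and undoing the flip yields index 1, i.e. the 3-gram,
        # the longest size that actually matched.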

        # Match position for the best n-gram.
        best_match_pos = first_match_positions[batch_indices, best_ngram_idx]

        # Avoid data-dependent branching.
        has_any_match = best_match_pos >= 0

        # Length of the best matching n-gram.
        best_ngram_lengths = ngram_lengths[best_ngram_idx]

        # Start position right after the matched n-gram occurrence.
        draft_start = torch.where(
            has_any_match,
            best_match_pos + best_ngram_lengths,
            torch.zeros_like(best_match_pos),
        )
        tokens_available = seq_lengths - draft_start

        # Gather indices for draft tokens.
        draft_indices = draft_start.unsqueeze(1) + torch.arange(
            num_draft_tokens, device=device
        )
        draft_indices = draft_indices.clamp(min=0, max=max_seq_len - 1)

        # Extract draft tokens; the gather always runs.
        draft_tokens = torch.gather(token_ids, 1, draft_indices)

        # Mask positions beyond the available tokens.
        position_indices = torch.arange(num_draft_tokens, device=device).unsqueeze(0)
        valid_positions = position_indices < tokens_available.unsqueeze(1)

        draft_tokens = torch.where(
            valid_positions,
            draft_tokens,
            torch.full_like(draft_tokens, -1),
        )

        # If there was no match, mask all positions.
        draft_tokens = torch.where(
            has_any_match.unsqueeze(1),
            draft_tokens,
            torch.full_like(draft_tokens, -1),
        )

        return draft_tokens

    def forward(
        self,
        num_tokens_no_spec: torch.Tensor,
        token_ids_gpu: torch.Tensor,
        combined_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for N-gram proposal using GPU tensor operations.

        Args:
            num_tokens_no_spec: Number of tokens for each sequence [batch_size]
            token_ids_gpu: Token IDs [batch_size, max_len]
            combined_mask: Whether each sequence is valid for spec decode [batch_size]

        Returns:
            draft_tokens: [batch_size, k] on GPU
            num_valid_draft_tokens: [batch_size] int32 on GPU, count of
                leading valid (non -1) tokens per request.
        """
        device = token_ids_gpu.device

        # Infer the batch size to preserve dynamic shapes.
        actual_batch_size = token_ids_gpu.shape[0]

        # Allocate in forward so torch.compile can optimize.
        # NOTE(patchy): Do NOT pre-allocate this as a buffer;
        # it breaks torch.compile.
        draft_tokens = torch.full(
            (actual_batch_size, self.k), -1, dtype=torch.int32, device=device
        )

        results = self._find_first_and_extract_all_n_parallel(
            token_ids_gpu,
            num_tokens_no_spec,
            min_ngram_len=self.min_n,
            max_ngram_len=self.max_n,
            num_draft_tokens=self.k,
        )

        draft_tokens = torch.where(combined_mask.unsqueeze(1), results, -1)

        # Count leading contiguous valid (non -1) tokens per request.
        is_valid = draft_tokens != -1  # [batch, k]
        cum_valid = is_valid.int().cumsum(dim=1)  # [batch, k]
        positions = torch.arange(1, self.k + 1, device=device).unsqueeze(0)
        num_valid_draft_tokens = (cum_valid == positions).int().sum(dim=1)
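        # Example (illustrative): for a row [17, 23, -1, 42] the cumulative
        # valid counts are [1, 2, 2, 3] and the positions are [1, 2, 3, 4],
        # so only the first two entries line up and the count is 2; the 42
        # that follows a -1 is deliberately not counted.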

        return draft_tokens, num_valid_draft_tokens

    def load_model(self, *args, **kwargs):
        """No model to load for the N-gram proposer."""
        pass


class NgramProposerGPU:
    def __init__(self, vllm_config: VllmConfig, device: torch.device, runner=None):
        assert vllm_config.speculative_config is not None
        assert vllm_config.speculative_config.prompt_lookup_min is not None
        assert vllm_config.speculative_config.prompt_lookup_max is not None

        compilation_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["none"],
            splitting_ops=[],
            compile_sizes=[],
            inductor_compile_config={
                "enable_auto_functionalized_v2": False,
                "max_autotune": True,
                "aggressive_fusion": True,
                "triton.autotune_pointwise": True,
                "coordinate_descent_tuning": True,
                "use_mixed_mm": False,
            },
            cudagraph_mode=CUDAGraphMode.NONE,
        )
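        # A standalone VllmConfig is assembled below for the kernel's
        # torch.compile path: it reuses the parent model/speculative/scheduler
        # configs but carries its own CompilationConfig, so the proposer's
        # compilation options do not affect the main model's settings.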
        model_config = vllm_config.model_config
        speculative_config = vllm_config.speculative_config
        scheduler_config = vllm_config.scheduler_config

        self.vllm_config = VllmConfig(
            compilation_config=compilation_config,
            model_config=model_config,
            speculative_config=speculative_config,
            scheduler_config=scheduler_config,
        )

        self.min_n = vllm_config.speculative_config.prompt_lookup_min
        self.max_n = vllm_config.speculative_config.prompt_lookup_max
        self.k = vllm_config.speculative_config.num_speculative_tokens
        self.max_model_len = vllm_config.model_config.max_model_len
        self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        self.device = device

        self.kernel = NgramGPUKernel(
            vllm_config=self.vllm_config, prefix="ngram_gpu_kernel", device=device
        )
        self.kernel.to(device)
        self.kernel.eval()

        self._dummy_run()

    def _dummy_run(self):
        token_ids, num_tokens, sampled_flags, valid_mask = self._generate_dummy_data(
            batch_size=self.max_num_seqs,
            max_seq_len=self.max_model_len,
            pattern_len=self.k,
            device=self.device,
        )

        combined_mask = sampled_flags & valid_mask & (num_tokens >= self.min_n)
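        # Warm-up: run the kernel a few times at the maximum batch size and
        # model length so torch.compile compilation happens here rather than
        # on the first real request.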

        for _ in range(3):
            with set_forward_context(None, self.vllm_config):
                _, _ = self.kernel(num_tokens, token_ids, combined_mask)

    def _generate_dummy_data(
        self,
        batch_size: int,
        max_seq_len: int,
        pattern_len: int,
        device: str = "cuda",
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate dummy warm-up data: zero-filled token IDs with random lengths.

        Args:
            batch_size: Number of sequences in the batch
            max_seq_len: Maximum sequence length
            pattern_len: Lower bound for the random sequence lengths
            device: Device to place tensors on

        Returns:
            token_ids: [batch_size, max_seq_len] tensor
            num_tokens: [batch_size] tensor
            sampled_flags: [batch_size] bool tensor
            valid_mask: [batch_size] bool tensor
        """
        token_ids = torch.zeros(
            batch_size,
            max_seq_len,
            dtype=torch.int32,
            device=device,
        )

        num_tokens = torch.randint(
            pattern_len, max_seq_len, (batch_size,), dtype=torch.int32, device=device
        )

        sampled_flags = torch.ones(batch_size, dtype=torch.bool, device=device)
        valid_mask = torch.ones(batch_size, dtype=torch.bool, device=device)

        return token_ids, num_tokens, sampled_flags, valid_mask

    def propose(
        self,
        num_tokens_no_spec: torch.Tensor,  # [batch_size]
        token_ids_gpu: torch.Tensor,  # [batch_size, max_len]
        valid_sampled_token_ids_gpu: torch.Tensor,  # [batch_size, num_spec_tokens + 1]
        valid_sampled_tokens_count: torch.Tensor,  # [batch_size]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Propose draft tokens using GPU-accelerated n-gram matching.

        Scatters the newly sampled tokens into `token_ids_gpu`, computes
        temporary updated lengths, then runs the kernel.

        Args:
            num_tokens_no_spec: Number of tokens per sequence (read-only)
            token_ids_gpu: Token IDs tensor (modified in-place with new tokens)
            valid_sampled_token_ids_gpu: Newly sampled tokens to scatter
            valid_sampled_tokens_count: Count of valid tokens per sequence

        Returns:
            draft_tokens: Proposed draft token IDs [batch_size, k]
            num_valid_draft_tokens: Count of leading valid draft tokens
                per request [batch_size]
        """
        assert token_ids_gpu.device == self.device
        assert num_tokens_no_spec.device == self.device

        batch_size = num_tokens_no_spec.shape[0]
        max_seq_len = token_ids_gpu.shape[1]
        max_new_tokens = valid_sampled_token_ids_gpu.shape[1]  # num_spec_tokens + 1
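        # Example (illustrative): a request with 100 tokens and 3 accepted
        # samples writes them to positions 100, 101, and 102; offsets beyond
        # valid_sampled_tokens_count or past max_seq_len are masked out below.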

        # Scatter the newly sampled tokens into token_ids_gpu.
        offsets = torch.arange(max_new_tokens, device=self.device)
        write_positions = num_tokens_no_spec.unsqueeze(1) + offsets.unsqueeze(0)
        valid_write_mask = offsets.unsqueeze(0) < valid_sampled_tokens_count.unsqueeze(
            1
        )
        in_bounds = write_positions < max_seq_len
        scatter_mask = (
            valid_write_mask & (valid_sampled_token_ids_gpu != -1) & in_bounds
        )

        write_positions_long = write_positions.clamp(max=max_seq_len - 1).long()
        existing_values = token_ids_gpu.gather(1, write_positions_long)

        tokens_cast = valid_sampled_token_ids_gpu.to(token_ids_gpu.dtype)
        tokens_to_scatter = torch.where(
            scatter_mask,
            tokens_cast,
            existing_values,
        )
        token_ids_gpu.scatter_(1, write_positions_long, tokens_to_scatter)
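        # token_ids_gpu now holds the accepted tokens appended in place after
        # each request's existing prefix; masked-out positions keep their
        # previous values because the scatter writes back existing_values.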

        num_tokens_tmp = num_tokens_no_spec + valid_sampled_tokens_count

        # Compute validity masks.
        sampled_flags = valid_sampled_tokens_count > 0
        valid_mask = torch.ones(batch_size, dtype=torch.bool, device=self.device)

        with set_forward_context(None, self.vllm_config):
            combined_mask = sampled_flags & valid_mask & (num_tokens_tmp >= self.min_n)

            with record_function_or_nullcontext("ngram_proposer_gpu: kernel"):
                draft_tokens, num_valid_draft_tokens = self.kernel(
                    num_tokens_tmp,
                    token_ids_gpu,
                    combined_mask,
                )

        return draft_tokens, num_valid_draft_tokens

    def update_token_ids_ngram(
        self,
        sampled_token_ids: torch.Tensor | list[list[int]],
        gpu_input_batch: InputBatch,
        token_ids_gpu: torch.Tensor,
        num_tokens_no_spec: torch.Tensor,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare speculative decoding inputs on device: compute the next token
        IDs and valid counts, honoring discarded requests and rejected tokens,
        without any CPU-GPU sync.
        """
        num_reqs = gpu_input_batch.num_reqs

        if isinstance(sampled_token_ids, list):
            # When disable_padded_drafter_batch=True, sampled_token_ids is
            # an irregular list[list[int]] whose sublists may have different
            # lengths (including empty lists for discarded requests).
            # Pad all sublists to the same length with -1 before converting
            # to a tensor.
            max_len = max(
                (len(sublist) for sublist in sampled_token_ids),
                default=0,
            )
            # Ensure at least length 1 for tensor creation.
            max_len = max(max_len, 1)
            padded_list = [
                sublist + [-1] * (max_len - len(sublist))
                for sublist in sampled_token_ids
            ]
            sampled_token_ids = torch.tensor(
                padded_list, dtype=torch.int32, device=self.device
            )
        assert isinstance(sampled_token_ids, torch.Tensor), (
            "sampled_token_ids should be a torch.Tensor for ngram_gpu"
        )

        # Back up the last valid token before the speculative tokens.
        backup_indices = (num_tokens_no_spec[:num_reqs] - 1).clamp(min=0).long()
        backup_next_token_ids = torch.gather(
            token_ids_gpu[:num_reqs], dim=1, index=backup_indices.unsqueeze(1)
        ).squeeze(1)

        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        # Invalidate sampled tokens for discarded requests.
        discard_mask_expanded = discard_request_mask[:num_reqs].unsqueeze(1)
        valid_sampled_token_ids_gpu.masked_fill_(discard_mask_expanded, -1)

        # Mask valid tokens within each request.
        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
        )

        # Count valid tokens per request.
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Rightmost valid index per row.
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Last valid token from each row; undefined if there is none.
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
        ).squeeze(1)

        # Use the last token if valid; otherwise fall back to the backup.
        next_token_ids = torch.where(
            last_valid_indices != -1,
            selected_tokens,
            backup_next_token_ids,
        )

        return next_token_ids, valid_sampled_tokens_count, valid_sampled_token_ids_gpu

    def load_model(self, *args, **kwargs):
        self.kernel.load_model(*args, **kwargs)


def update_scheduler_for_invalid_drafts(
    num_valid_draft_tokens_event: torch.cuda.Event,
    num_valid_draft_tokens_cpu: torch.Tensor,
    scheduler_output: "SchedulerOutput",
    req_id_to_index: dict[str, int],
) -> None:
    """Trim invalid speculative slots using per-request valid draft counts.

    Args:
        num_valid_draft_tokens_event: Event for async D2H copy completion.
        num_valid_draft_tokens_cpu: CPU buffer of valid draft counts.
        scheduler_output: Scheduler metadata to update in-place.
        req_id_to_index: Request-id to batch-index mapping.
    """
    req_data = scheduler_output.scheduled_cached_reqs
    num_valid_draft_tokens_event.synchronize()

    for req_id in req_data.req_ids:
        req_index = req_id_to_index.get(req_id)
        if req_index is None:
            continue

        spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(req_id)
        if spec_token_ids is None:
            continue

        scheduled_k = len(spec_token_ids)

        valid_k = int(num_valid_draft_tokens_cpu[req_index].item())
        valid_k = max(0, min(valid_k, scheduled_k))
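        # Example (illustrative): if 4 draft tokens were scheduled for this
        # request but only the first 2 are valid, trim 2 slots from the token
        # budget below and keep spec_token_ids[:2].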

        tokens_to_trim = scheduled_k - valid_k
        scheduler_output.total_num_scheduled_tokens -= tokens_to_trim
        scheduler_output.num_scheduled_tokens[req_id] -= tokens_to_trim

        if valid_k == 0:
            scheduler_output.scheduled_spec_decode_tokens.pop(req_id, None)
        else:
            scheduler_output.scheduled_spec_decode_tokens[req_id] = spec_token_ids[
                :valid_k
            ]


def update_ngram_gpu_tensors_incremental(
    input_batch: InputBatch,
    token_ids_gpu_tensor: torch.Tensor,
    num_tokens_no_spec_gpu: torch.Tensor,
    new_reqs: list[CachedRequestState],
    device: torch.device,
    _pinned_idx_buf: torch.Tensor,
    _pinned_val_buf: torch.Tensor,
) -> None:
    """Incrementally update token_ids_gpu_tensor and num_tokens_no_spec_gpu
    for the n-gram GPU proposer.
    """
    prev_req_id_to_index = input_batch.prev_req_id_to_index
    curr_req_id_to_index = input_batch.req_id_to_index

    if not curr_req_id_to_index:
        return

    active_indices = list(curr_req_id_to_index.values())
    n_active = len(active_indices)

    # Use resident pinned buffers to avoid per-call allocation.
    active_idx_cpu = _pinned_idx_buf[:n_active]
    active_idx_cpu.copy_(torch.as_tensor(active_indices, dtype=torch.long))

    active_idx_gpu = active_idx_cpu.to(device=device, non_blocking=True)
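    # Staging through the pinned CPU buffer keeps the host-to-device copy
    # above asynchronous; copying from a pageable tensor would block.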

    new_req_ids = {req.req_id for req in new_reqs}

    # First run: there is no previous state.
    if prev_req_id_to_index is None:
        for idx in active_indices:
            num_tokens = input_batch.num_tokens_no_spec[idx]
            if num_tokens > 0:
                token_ids_gpu_tensor[idx, :num_tokens].copy_(
                    input_batch.token_ids_cpu_tensor[idx, :num_tokens],
                    non_blocking=True,
                )

        _sync_num_tokens(
            input_batch,
            num_tokens_no_spec_gpu,
            active_idx_cpu,
            active_idx_gpu,
            n_active,
            device,
            _pinned_val_buf,
        )
        return

    # Detect index changes that require a reorder.
    reorder_src: list[int] = []
    reorder_dst: list[int] = []

    for req_id, curr_idx in curr_req_id_to_index.items():
        if req_id in new_req_ids:
            continue
        prev_idx = prev_req_id_to_index.get(req_id)
        if prev_idx is not None and prev_idx != curr_idx:
            reorder_src.append(prev_idx)
            reorder_dst.append(curr_idx)

    if reorder_src:
        src_tensor = torch.tensor(reorder_src, dtype=torch.long, device=device)
        dst_tensor = torch.tensor(reorder_dst, dtype=torch.long, device=device)

        temp_token_ids = token_ids_gpu_tensor[src_tensor].clone()
        temp_num_tokens = num_tokens_no_spec_gpu[src_tensor].clone()

        token_ids_gpu_tensor[dst_tensor] = temp_token_ids
        num_tokens_no_spec_gpu[dst_tensor] = temp_num_tokens
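        # The .clone() calls above materialize the source rows before any
        # write, so overlapping src/dst index sets cannot corrupt rows
        # mid-permutation.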

    # Full copy for new/resumed requests.
    for req_state in new_reqs:
        new_req_idx = curr_req_id_to_index.get(req_state.req_id)
        if new_req_idx is None:
            continue

        num_tokens = input_batch.num_tokens_no_spec[new_req_idx]
        if num_tokens > 0:
            token_ids_gpu_tensor[new_req_idx, :num_tokens].copy_(
                input_batch.token_ids_cpu_tensor[new_req_idx, :num_tokens],
                non_blocking=True,
            )

    # Always batch-sync sequence lengths from the CPU for ALL active requests.
    _sync_num_tokens(
        input_batch,
        num_tokens_no_spec_gpu,
        active_idx_cpu,
        active_idx_gpu,
        n_active,
        device,
        _pinned_val_buf,
    )


def _sync_num_tokens(
    input_batch: InputBatch,
    num_tokens_no_spec_gpu: torch.Tensor,
    active_idx_cpu: torch.Tensor,
    active_idx_gpu: torch.Tensor,
    n_active: int,
    device: torch.device,
    _pinned_val_buf: torch.Tensor,
) -> None:
    """Batch-sync GPU sequence lengths from the CPU source of truth.

    Args:
        input_batch: Batch container with the CPU length tensor.
        num_tokens_no_spec_gpu: Destination GPU length tensor.
        active_idx_cpu: Active request indices on CPU.
        active_idx_gpu: Active request indices on GPU.
        n_active: Number of active requests.
        device: Target CUDA device.
        _pinned_val_buf: Resident pinned int32 staging buffer.

    Returns:
        None (updates num_tokens_no_spec_gpu in-place).
    """
    src_cpu = input_batch.num_tokens_no_spec_cpu_tensor
    vals = _pinned_val_buf[:n_active]
    vals.copy_(src_cpu.index_select(0, active_idx_cpu))

    num_tokens_no_spec_gpu.index_copy_(
        0,
        active_idx_gpu,
        vals.to(device=device, non_blocking=True),
    )


def copy_num_valid_draft_tokens(
    num_valid_draft_tokens_cpu: torch.Tensor,
    num_valid_draft_tokens_copy_stream: torch.cuda.Stream,
    num_valid_draft_tokens_event: torch.cuda.Event,
    num_valid_draft_tokens: torch.Tensor | None,
    batch_size: int,
) -> None:
    """
    Async D2H copy of the per-request valid draft counts.
    """
    if num_valid_draft_tokens is None:
        return

    num_reqs_to_copy = min(batch_size, num_valid_draft_tokens.shape[0])
    if num_reqs_to_copy <= 0:
        return

    default_stream = torch.cuda.current_stream()
    with torch.cuda.stream(num_valid_draft_tokens_copy_stream):
        num_valid_draft_tokens_copy_stream.wait_stream(default_stream)
        num_valid_draft_tokens_cpu[:num_reqs_to_copy].copy_(
            num_valid_draft_tokens[:num_reqs_to_copy], non_blocking=True
        )
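        # Record the completion event on the copy stream;
        # update_scheduler_for_invalid_drafts() synchronizes on it before
        # reading the CPU buffer, so the default stream is never blocked here.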
        num_valid_draft_tokens_event.record()