[Async][Spec Decoding] Zero-bubble async scheduling + spec decoding (#32951)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Co-authored-by: zhrrr <43847754+izhuhaoran@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
Matthew Bonanni
2026-03-23 15:37:22 -04:00
committed by GitHub
parent ffb5b32b5f
commit fafe76b4af
9 changed files with 488 additions and 209 deletions

View File

@@ -177,7 +177,7 @@ def test_prepare_next_token_ids():
next_token_ids_from_padded, valid_sampled_tokens_count = (
proposer.prepare_next_token_ids_padded(
common_attn_metadata,
common_attn_metadata.seq_lens_cpu,
sampled_token_ids_tensor,
mock_requests,
mock_input_batch,

View File

@@ -187,7 +187,7 @@ def test_prepare_next_token_ids_padded():
)
next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
common_attn_metadata,
common_attn_metadata.seq_lens_cpu,
sampled_token_ids,
mock_requests,
mock_input_batch,

View File

@@ -766,6 +766,19 @@ class VllmConfig: # type: ignore[misc]
else:
self.parallel_config.disable_nccl_for_dp_synchronization = False
if (
self.speculative_config is not None
and self.scheduler_config.async_scheduling
and self.model_config is not None
and not self.model_config.disable_cascade_attn
):
logger.warning_once(
"Disabling cascade attention (not yet compatible with "
"async speculative decoding).",
scope="local",
)
self.model_config.disable_cascade_attn = True
if (
self.model_config is not None
and self.model_config.multimodal_config is not None

View File

@@ -71,7 +71,6 @@ class SpecDecodeBaseProposer:
self.method = self.speculative_config.method
self.pass_hidden_states_to_model = pass_hidden_states_to_model
self.runner = runner
self.device = device
self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
@@ -424,8 +423,6 @@ class SpecDecodeBaseProposer:
)
)
assert self.runner is not None
per_layer_attn_metadata: dict[str, object] = {}
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
@@ -821,7 +818,7 @@ class SpecDecodeBaseProposer:
def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
seq_lens_cpu: torch.Tensor,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
@@ -836,11 +833,10 @@ class SpecDecodeBaseProposer:
"""
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
seq_lens_list = seq_lens_cpu[:num_reqs].tolist()
self.backup_next_token_ids.np[:num_reqs] = np.array(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item()
)
requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
for i in range(num_reqs)
],
dtype=np.int32,
@@ -925,7 +921,7 @@ class SpecDecodeBaseProposer:
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
max_seq_len=common_attn_metadata.max_seq_len,
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
causal=True,

View File

@@ -286,7 +286,7 @@ class ExtractHiddenStatesProposer:
def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
seq_lens: torch.Tensor,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
@@ -303,11 +303,10 @@ class ExtractHiddenStatesProposer:
device = sampled_token_ids.device
# Compute backup tokens for discarded / invalid requests
seq_lens_list = seq_lens[:num_reqs].tolist()
backup_tokens_gpu = torch.tensor(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item()
)
requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
for i in range(num_reqs)
],
dtype=torch.int32,

View File

@@ -3,6 +3,7 @@
import torch
from vllm.config import VllmConfig, replace
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
@@ -463,3 +464,36 @@ def copy_and_expand_eagle_inputs_kernel(
out_idx,
mask=is_new_token_region & in_bounds,
)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def update_num_computed_tokens_for_batch_change(
    num_computed_tokens: torch.Tensor,
    num_accepted_tokens: torch.Tensor,
    prev_positions: torch.Tensor,
    valid_sampled_token_count: torch.Tensor,
    prev_num_draft_tokens: torch.Tensor,
    cpu_num_computed_tokens: torch.Tensor,
) -> None:
    """Correct num_computed_tokens on-GPU for async spec decode drift.

    For requests that had draft tokens in the previous step, the corrected
    value is prev GPU count + this request's valid sampled token count.
    For new requests or requests without drafts (e.g. prefills), the CPU
    value is taken as-is. num_accepted_tokens is likewise overwritten only
    for requests that participated in drafting.
    """
    batch_size = prev_positions.shape[0]
    # prev_positions holds -1 for requests new to this batch; clamp so the
    # gathers below stay in bounds (those lanes are masked out afterwards).
    safe_prev = prev_positions.clamp(min=0)
    accepted = valid_sampled_token_count[safe_prev]
    gpu_prev_computed = num_computed_tokens[safe_prev]
    drafts = prev_num_draft_tokens[safe_prev]
    # A request "participates" only if it existed last step AND had drafts.
    had_drafts = (prev_positions >= 0) & (drafts > 0)
    num_computed_tokens[:batch_size].copy_(
        torch.where(
            had_drafts, gpu_prev_computed + accepted.int(), cpu_num_computed_tokens
        )
    )
    num_accepted_tokens.copy_(
        torch.where(had_drafts, accepted, num_accepted_tokens)
    )

View File

@@ -6,7 +6,9 @@ import torch
from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size
@@ -131,71 +133,33 @@ class BlockTable:
self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray
self,
num_reqs: int,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
num_tokens = positions.shape[0]
total_cp_world_size = self.pcp_world_size * self.dcp_world_size
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
if total_cp_world_size > 1:
# Note(hc): The DCP implement store kvcache with an interleave
# style, the kvcache for the token whose token_idx is i is
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
# Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation.
virtual_block_size = self.block_size * total_cp_world_size
block_table_indices = (
req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = (
virtual_block_offsets
// self.cp_kv_cache_interleave_size
% total_cp_world_size
== total_cp_rank
)
# Calculate local block_offsets
block_offsets = (
virtual_block_offsets
// (total_cp_world_size * self.cp_kv_cache_interleave_size)
* self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping.np[: req_indices.shape[0]] = np.where(
mask, slot_mapping, -1
)
else:
block_table_indices = (
req_indices * self.max_num_blocks_per_req + positions // self.block_size
)
block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(
block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping.np[: req_indices.shape[0]],
)
_compute_slot_mapping_kernel[(num_reqs + 1,)](
num_tokens,
self.max_num_batched_tokens,
query_start_loc,
positions,
self.block_table.gpu,
self.block_table.gpu.stride(0),
self.block_size,
self.slot_mapping.gpu,
TOTAL_CP_WORLD_SIZE=total_cp_world_size,
TOTAL_CP_RANK=total_cp_rank,
CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024,
)
def commit_block_table(self, num_reqs: int) -> None:
self.block_table.copy_to_gpu(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
self.slot_mapping.copy_to_gpu(num_tokens)
def clear(self) -> None:
self.block_table.gpu.fill_(0)
self.block_table.cpu.fill_(0)
@@ -320,19 +284,18 @@ class MultiGroupBlockTable:
block_table.swap_row(src, tgt)
def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray
self,
num_reqs: int,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> None:
for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions)
block_table.compute_slot_mapping(num_reqs, query_start_loc, positions)
def commit_block_table(self, num_reqs: int) -> None:
for block_table in self.block_tables:
block_table.commit_block_table(num_reqs)
def commit_slot_mapping(self, num_tokens: int) -> None:
for block_table in self.block_tables:
block_table.commit_slot_mapping(num_tokens)
def clear(self) -> None:
for block_table in self.block_tables:
block_table.clear()
@@ -340,3 +303,61 @@ class MultiGroupBlockTable:
def __getitem__(self, idx: int) -> "BlockTable":
"""Returns the BlockTable for the i-th KV cache group."""
return self.block_tables[idx]
@triton.jit
def _compute_slot_mapping_kernel(
    num_tokens,
    max_num_tokens,
    query_start_loc_ptr,  # [num_reqs + 1], int32
    positions_ptr,  # [num_tokens], int64
    block_table_ptr,  # [max_num_reqs, max_num_blocks_per_req], int32 (flat)
    block_table_stride,  # max_num_blocks_per_req
    block_size,
    slot_mapping_ptr,  # [max_num_tokens], int64
    TOTAL_CP_WORLD_SIZE: tl.constexpr,
    TOTAL_CP_RANK: tl.constexpr,
    CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
    PAD_ID: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute the KV-cache slot mapping for all scheduled tokens on GPU.

    Launched with a grid of (num_reqs + 1) programs: program i < num_reqs
    writes slot ids for request i's token range taken from query_start_loc;
    the extra last program pads slot_mapping[num_tokens:max_num_tokens]
    with PAD_ID. Tokens whose KV is stored on another context-parallel
    rank (interleaved layout) also get PAD_ID.
    """
    req_idx = tl.program_id(0)
    if req_idx == tl.num_programs(0) - 1:
        # Pad remaining slots for CUDA graph compatibility.
        for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
            offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(
                slot_mapping_ptr + offsets,
                PAD_ID,
                mask=offsets < max_num_tokens,
            )
        return
    # [start_idx, end_idx) is this request's span in the flattened token dim.
    start_idx = tl.load(query_start_loc_ptr + req_idx).to(tl.int64)
    end_idx = tl.load(query_start_loc_ptr + req_idx + 1).to(tl.int64)
    # With context parallelism the KV cache is interleaved across ranks, so
    # block-table indexing uses a "virtual" block spanning all CP ranks.
    # (Degenerates to block_size when TOTAL_CP_WORLD_SIZE == 1.)
    virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
    row_offset = req_idx * block_table_stride
    for i in range(start_idx, end_idx, BLOCK_SIZE):
        offsets = i + tl.arange(0, BLOCK_SIZE)
        mask = offsets < end_idx
        pos = tl.load(positions_ptr + offsets, mask=mask, other=0)
        block_indices = pos // virtual_block_size
        block_numbers = tl.load(block_table_ptr + row_offset + block_indices).to(
            tl.int64
        )
        virtual_block_offsets = pos - block_indices * virtual_block_size
        # A token is local iff its interleave group maps to this CP rank.
        is_local = (
            virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
        ) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
        # Collapse the virtual offset down to this rank's local block offset.
        local_block_offsets = (
            virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
        ) * CP_KV_CACHE_INTERLEAVE_SIZE + (
            virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
        )
        slot_ids = block_numbers * block_size + local_block_offsets
        # Non-local tokens are marked with PAD_ID so their KV is not written.
        slot_ids = tl.where(is_local, slot_ids, PAD_ID)
        tl.store(slot_mapping_ptr + offsets, slot_ids, mask=mask)

View File

@@ -219,7 +219,7 @@ class InputBatch:
# Speculative decoding
self.num_accepted_tokens_cpu_tensor = torch.ones(
(max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
(max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()
@@ -989,13 +989,15 @@ class InputBatch:
continue
num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
# Also account for case where there may be a smaller number of
# output placeholders (tokens can be discarded after a kv-load failure).
# output placeholders (tokens can be discarded after kv-load
# failure) or a larger number (async spec decode adds optimistic
# placeholders that may exceed the actual acceptance count).
first_placeholder = req_output_token_ids.index(-1)
num_placeholders = len(req_output_token_ids) - first_placeholder
num_to_replace = min(num_sampled_ids, num_placeholders)
del new_ids[num_to_replace:]
end_index = first_placeholder + num_to_replace
req_output_token_ids[first_placeholder:end_index] = new_ids
req_output_token_ids[first_placeholder:] = new_ids
# ^ Implicitly resizes to (first_placeholder + num_to_replace)
def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
"""

View File

@@ -7,7 +7,7 @@ import itertools
import threading
import time
from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
from contextlib import contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass, replace
@@ -172,6 +172,7 @@ from vllm.v1.spec_decode.ngram_proposer_gpu import (
update_scheduler_for_invalid_drafts,
)
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.spec_decode.utils import update_num_computed_tokens_for_batch_change
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker import mamba_utils
@@ -570,6 +571,7 @@ class GPUModelRunner(
self.rejection_sampler = RejectionSampler(self.sampler)
self.num_spec_tokens = 0
self.valid_sampled_token_count_gpu: torch.Tensor | None = None
if self.speculative_config:
self.num_spec_tokens = self.speculative_config.num_speculative_tokens
draft_config = self.speculative_config.draft_model_config
@@ -577,6 +579,9 @@ class GPUModelRunner(
self.effective_drafter_max_model_len = draft_config.max_model_len
else:
self.effective_drafter_max_model_len = self.max_model_len
self.use_async_spec_decode = (
self.use_async_scheduling and self.num_spec_tokens > 0
)
# Request states.
self.requests: dict[str, CachedRequestState] = {}
@@ -659,11 +664,31 @@ class GPUModelRunner(
# Persistent buffers for CUDA graphs.
self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
self.positions = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=self.device
)
self.query_start_loc = self._make_buffer(
self.max_num_reqs + 1, dtype=torch.int32
)
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.seq_lens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=self.device
)
self.optimistic_seq_lens_cpu = torch.zeros(
self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory
)
self.num_computed_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=self.device
)
self.prev_num_draft_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
self.req_indices = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
# Maps current batch position -> previous batch position (-1 for new reqs)
self.prev_positions = self._make_buffer(self.max_num_reqs, dtype=torch.int64)
self.num_scheduled_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
if self.dcp_world_size > 1:
self.dcp_local_seq_lens = self._make_buffer(
@@ -683,7 +708,7 @@ class GPUModelRunner(
self.max_num_reqs, dtype=torch.int32
)
self.num_accepted_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int64
self.max_num_reqs, dtype=torch.int32
)
# Only relevant for multimodal models
@@ -722,12 +747,14 @@ class GPUModelRunner(
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: IntermediateTensors | None = None
# OPTIMIZATION: Cache the tensors rather than creating them every step.
# Keep in int64 to avoid overflow with long context
self.arange_np = np.arange(
max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
dtype=np.int64,
)
# OPTIMIZATION: Cache the arange tensors rather than creating them
# every step. Keep in int64 to avoid overflow with long context.
# - arange_np: immutable [0, 1, 2, ...] used as source for batched computation
# - query_pos: CpuGpuBuffer for the computed batched arange result
arange_size = max(self.max_num_reqs + 1, self.max_num_tokens)
self.arange_np = np.arange(arange_size, dtype=np.int64)
self.query_pos = self._make_buffer(arange_size, dtype=torch.int64)
self._arange_scratch = np.empty(arange_size, dtype=np.int64)
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
@@ -812,7 +839,7 @@ class GPUModelRunner(
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
self.valid_sampled_token_count_cpu = torch.empty(
self.max_num_reqs,
dtype=torch.int64,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
@@ -903,13 +930,13 @@ class GPUModelRunner(
return self.mrope_positions.gpu[:, :num_tokens]
if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens]
return self.positions[:num_tokens]
else:
if self.uses_mrope:
return self.mrope_positions.gpu[:, num_tokens]
if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, num_tokens]
return self.positions.gpu[num_tokens]
return self.positions[num_tokens]
def _make_buffer(
self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True
@@ -953,7 +980,7 @@ class GPUModelRunner(
if len(token_type_id_requests) == 0:
return model_kwargs
seq_lens = self.seq_lens.gpu[:num_reqs]
seq_lens = self.seq_lens[:num_reqs]
token_type_ids = []
for i in range(num_reqs):
@@ -1021,7 +1048,7 @@ class GPUModelRunner(
def _sync_device(self) -> None:
torch.accelerator.synchronize()
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None:
"""Update the cached states and the persistent batch with the scheduler
output.
@@ -1086,6 +1113,8 @@ class GPUModelRunner(
ngram_gpu_new_reqs: list[CachedRequestState] = []
reqs_to_add: list[CachedRequestState] = []
deferred_spec_decode_corrections = []
# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
@@ -1172,10 +1201,8 @@ class GPUModelRunner(
scheduler_output,
self.input_batch.req_id_to_index,
)
# Wait until valid_sampled_tokens_count is copied to cpu,
# then use it to update actual num_computed_tokens of each request.
valid_sampled_token_count = self._get_valid_sampled_token_count()
if self.use_async_spec_decode:
self.prev_num_draft_tokens.np.fill(0)
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
@@ -1202,15 +1229,30 @@ class GPUModelRunner(
if req_index is None:
req_state.prev_num_draft_len = 0
else:
assert self.input_batch.prev_req_id_to_index is not None
prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
num_accepted = valid_sampled_token_count[prev_req_index] - 1
num_rejected = req_state.prev_num_draft_len - num_accepted
num_computed_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted)
# Optimistically assume all accepted; queue up a correction
# to be called after the model forward to preserve async
# scheduling. Corrected on GPU in _prepare_inputs.
optimistic_num_accepted = req_state.prev_num_draft_len
req_state.output_token_ids.extend([-1] * optimistic_num_accepted)
if is_ngram_gpu and num_accepted > 0 and req_index is not None:
self.input_batch.num_tokens_no_spec[req_index] += num_accepted
deferred_spec_decode_corrections.append(
(req_id, optimistic_num_accepted, req_state)
)
prev_req_index = (
self.input_batch.prev_req_id_to_index.get(req_id)
if self.input_batch.prev_req_id_to_index
else None
)
if prev_req_index is not None:
self.prev_num_draft_tokens.np[prev_req_index] = (
optimistic_num_accepted
)
if is_ngram_gpu and optimistic_num_accepted > 0:
self.input_batch.num_tokens_no_spec[req_index] += (
optimistic_num_accepted
)
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
@@ -1238,7 +1280,8 @@ class GPUModelRunner(
)
elif num_output_tokens < len(req_state.output_token_ids):
# Some output tokens were discarded due to a sync-KV-load
# failure. Align the cached state.
# failure, or output_token_ids was inflated by the optimistic
# extend above (async spec decode). Align the cached state.
del req_state.output_token_ids[num_output_tokens:]
if req_index is not None:
end_idx = (
@@ -1326,6 +1369,40 @@ class GPUModelRunner(
_pinned_val_buf=self._ngram_pinned_val_buf,
)
if deferred_spec_decode_corrections:
def correct_spec_decode_token_counts():
valid_sampled_token_count = self._get_valid_sampled_token_count()
if not valid_sampled_token_count:
return
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
if not prev_req_id_to_index:
return
for (
req_id,
optimistic_num_accepted,
req_state,
) in deferred_spec_decode_corrections:
prev_req_index = prev_req_id_to_index.get(req_id)
if prev_req_index is None:
continue
num_accepted = valid_sampled_token_count[prev_req_index] - 1
correction = optimistic_num_accepted - num_accepted
req_state.num_computed_tokens -= correction
cur_req_index = self.input_batch.req_id_to_index.get(req_id)
if cur_req_index is None:
continue
self.input_batch.num_computed_tokens_cpu[cur_req_index] -= (
correction
)
if is_ngram_gpu and correction > 0:
self.input_batch.num_tokens_no_spec[cur_req_index] -= correction
self.num_tokens_no_spec_gpu[cur_req_index] -= correction
return correct_spec_decode_token_counts
else:
return None
def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
) -> None:
@@ -1340,6 +1417,9 @@ class GPUModelRunner(
if not self.speculative_config or not self.model_config.is_hybrid:
return
# TODO: Remove .cpu() sync to enable fully async for hybrid model;
# Use num_computed_tokens.gpu instead of req.num_computed_tokens to
# support aligned mamba cache mode.
# Find the number of accepted tokens for each sequence.
num_reqs = output_token_ids.size(0)
self.num_accepted_tokens.gpu[:num_reqs] = (
@@ -1486,12 +1566,14 @@ class GPUModelRunner(
def _get_cumsum_and_arange(
self,
num_tokens: np.ndarray,
arange_out: np.ndarray,
cumsum_dtype: np.dtype | None = None,
) -> tuple[np.ndarray, np.ndarray]:
) -> np.ndarray:
"""Get the cumulative sum and batched arange of the given array.
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
# Equivalent to but faster than:
# np.concatenate([np.arange(n) for n in num_tokens])
E.g., [2, 5, 3] -> [2, 7, 10], arange written to
arange_out[:10] as [0, 1, 0, 1, 2, 3, 4, 0, 1, 2].
Equivalent to but faster than:
np.concatenate([np.arange(n) for n in num_tokens])
"""
# Step 1. [2, 5, 3] -> [2, 7, 10]
cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
@@ -1499,13 +1581,33 @@ class GPUModelRunner(
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = self.arange_np[:total_num_tokens] - cumsums_offsets
np.subtract(
self.arange_np[:total_num_tokens],
cumsums_offsets,
out=arange_out[:total_num_tokens],
)
return cu_num_tokens, arange
return cu_num_tokens
def _compute_prev_positions(self, num_reqs: int) -> None:
"""Build prev_positions mapping: current pos -> previous pos (-1 if new).
Populates self.prev_positions.np[:num_reqs] with the mapping.
"""
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
prev_positions = self.prev_positions.np[:num_reqs]
if not prev_req_id_to_index:
prev_positions.fill(-1)
return
for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
prev_positions[i] = prev_req_id_to_index.get(req_id, -1)
def _prepare_input_ids(
self,
scheduler_output: "SchedulerOutput",
num_reqs: int,
total_num_scheduled_tokens: int,
cu_num_tokens: np.ndarray,
) -> None:
@@ -1513,7 +1615,11 @@ class GPUModelRunner(
Carefully handles the `prev_sampled_token_ids` which can be cached
from the previous engine iteration, in which case those tokens on the
GPU need to be copied into the corresponding slots into input_ids."""
GPU need to be copied into the corresponding slots into input_ids.
Uses self.prev_positions[:num_reqs] which maps current pos -> prev pos
(-1 for new requests).
"""
if self.input_batch.prev_sampled_token_ids is None:
# Normal scheduling case
@@ -1526,47 +1632,50 @@ class GPUModelRunner(
# Async scheduling case, where some decode requests from the previous
# iteration won't have entries in input_ids_cpu and need to be copied
# on the GPU from prev_sampled_token_ids.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
assert prev_req_id_to_index is not None
prev_positions = self.prev_positions.np[:num_reqs]
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
sample_flattened_indices: list[int] = []
spec_flattened_indices: list[int] = []
prev_common_req_indices: list[int] = []
prev_draft_token_indices: list[int] = []
indices_match = True
prev_indices: list[int] = []
common_indices_match = True
max_flattened_index = -1
total_num_spec_tokens = 0
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id, cur_index in self.input_batch.req_id_to_index.items():
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
prev_common_req_indices.append(prev_index)
# We need to compute the flattened input_ids index of the
# last token in each common request.
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
total_num_spec_tokens += draft_len
flattened_index = cu_num_tokens[cur_index].item() - 1
# example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
# sample_flattened_indices = [0, 2, 5]
# spec_flattened_indices = [1, 3, 4, 6, 7]
sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend(
range(flattened_index - draft_len + 1, flattened_index + 1)
)
start = prev_index * self.num_spec_tokens
# prev_draft_token_indices is used to find which draft_tokens_id
# should be copied to input_ids
# example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
# flatten draft_tokens_id [1,2,3,4,5,6]
# draft_len of each request [1, 2, 1]
# then prev_draft_token_indices is [0, 2, 3, 4]
prev_draft_token_indices.extend(range(start, start + draft_len))
indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)
for cur_index in range(num_reqs):
prev_index = prev_positions[cur_index]
if prev_index < 0:
continue
prev_indices.append(prev_index)
req_id = self.input_batch.req_ids[cur_index]
# We need to compute the flattened input_ids index of the
# last token in each common request.
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
total_num_spec_tokens += draft_len
flattened_index = cu_num_tokens[cur_index].item() - 1
# example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
# sample_flattened_indices = [0, 2, 5]
# spec_flattened_indices = [1, 3, 4, 6, 7]
sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend(
range(flattened_index - draft_len + 1, flattened_index + 1)
)
start = prev_index * self.num_spec_tokens
# prev_draft_token_indices is used to find which draft_tokens_id
# should be copied to input_ids
# example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
# flatten draft_tokens_id [1,2,3,4,5,6]
# draft_len of each request [1, 2, 1]
# then prev_draft_token_indices is [0, 2, 3, 4]
prev_draft_token_indices.extend(range(start, start + draft_len))
common_indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)
num_common_tokens = len(sample_flattened_indices)
total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
if num_common_tokens < total_without_spec:
# If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the GPU first.
# we need to copy the input_ids_cpu to the GPU first.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.enable_prompt_embeds:
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
@@ -1575,7 +1684,7 @@ class GPUModelRunner(
# No requests in common with the previous iteration
# So input_ids.cpu will have all the input ids.
return
if indices_match and max_flattened_index == (num_common_tokens - 1):
if common_indices_match and max_flattened_index == (num_common_tokens - 1):
# Common-case optimization: the batch is unchanged
# and no reordering happened.
# The indices are both the same permutation of 0..N-1 so
@@ -1592,7 +1701,7 @@ class GPUModelRunner(
sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
prev_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
self.input_ids.gpu.scatter_(
dim=0,
@@ -1696,15 +1805,15 @@ class GPUModelRunner(
req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
# self.query_pos.np[:10]: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens = self._get_cumsum_and_arange(
num_scheduled_tokens, self.query_pos.np
)
# Get positions.
positions_np = self.positions.np[:total_num_scheduled_tokens]
np.add(
self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np,
positions_np = (
self.input_batch.num_computed_tokens_cpu[req_indices]
+ self.query_pos.np[: cu_num_tokens[-1]]
)
# Calculate M-RoPE positions.
@@ -1782,9 +1891,6 @@ class GPUModelRunner(
output_idx += num_sched
self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)
# Prepare the attention metadata.
self.query_start_loc.np[0] = 0
self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
@@ -1794,12 +1900,21 @@ class GPUModelRunner(
self.query_start_loc.copy_to_gpu()
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
self.seq_lens.np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
# Compute optimistic seq_lens (assumes all draft tokens from previous
# iteration accepted). Store in optimistic_seq_lens_cpu for use by
# _build_attention_metadata (max_seq_len) and discard_request_mask.
# seq_lens (GPU) will be computed later using the same optimistic values.
torch.add(
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs],
torch.from_numpy(num_scheduled_tokens),
out=self.optimistic_seq_lens_cpu[:num_reqs],
)
# Fill unused with 0 for full cuda graph mode.
self.seq_lens.np[num_reqs:].fill(0)
self.seq_lens.copy_to_gpu()
self.optimistic_seq_lens_cpu[num_reqs:].fill_(0)
# Build prev_positions mapping: current pos -> prev pos (-1 if new).
# Used for gathering from previous iteration's GPU tensors.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
self._compute_prev_positions(num_reqs)
num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
num_tokens_np = np.array(num_tokens, dtype=np.int32)
@@ -1807,13 +1922,78 @@ class GPUModelRunner(
# Record which requests should not be sampled,
# so that we could clear the sampled tokens before returning
self.discard_request_mask.np[:num_reqs] = (
self.seq_lens.np[:num_reqs] < num_tokens_np
self.optimistic_seq_lens_cpu[:num_reqs].numpy() < num_tokens_np
)
self.discard_request_mask.copy_to_gpu(num_reqs)
# Sync num_accepted_tokens from CPU (set by
# _update_states_after_model_execute for hybrid models).
if self.num_accepted_tokens_event is not None:
self.num_accepted_tokens_event.synchronize()
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
else:
self.num_accepted_tokens.np.fill(1)
self.num_accepted_tokens.gpu.fill_(1)
# Update num_computed_tokens on GPU. In async spec decode,
# CPU values are optimistic (all drafts accepted). The kernel
# corrects on GPU using the previous step's
# valid_sampled_token_count_gpu. Otherwise, just copy from CPU.
if (
self.use_async_spec_decode
and self.valid_sampled_token_count_gpu is not None
and prev_req_id_to_index
):
self.prev_positions.copy_to_gpu(num_reqs)
self.prev_num_draft_tokens.copy_to_gpu()
cpu_values = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs].to(
device=self.device, non_blocking=True
)
update_num_computed_tokens_for_batch_change(
self.num_computed_tokens,
self.num_accepted_tokens.gpu[:num_reqs],
self.prev_positions.gpu[:num_reqs],
self.valid_sampled_token_count_gpu,
self.prev_num_draft_tokens.gpu,
cpu_values,
)
else:
self.num_computed_tokens[:num_reqs].copy_(
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs],
non_blocking=True,
)
self.req_indices.np[:total_num_scheduled_tokens] = req_indices
self.req_indices.copy_to_gpu(total_num_scheduled_tokens)
req_indices_gpu = self.req_indices.gpu[:total_num_scheduled_tokens]
self.query_pos.copy_to_gpu(total_num_scheduled_tokens)
self.num_scheduled_tokens.np[:num_reqs] = num_scheduled_tokens
self.num_scheduled_tokens.copy_to_gpu(num_reqs)
num_scheduled_tokens_gpu = self.num_scheduled_tokens.gpu[:num_reqs]
self.positions[:total_num_scheduled_tokens] = (
self.num_computed_tokens[req_indices_gpu].to(torch.int64)
+ self.query_pos.gpu[:total_num_scheduled_tokens]
)
self.seq_lens[:num_reqs] = (
self.num_computed_tokens[:num_reqs] + num_scheduled_tokens_gpu
)
self.seq_lens[num_reqs:].fill_(0)
self.input_batch.block_table.compute_slot_mapping(
num_reqs,
self.query_start_loc.gpu[: num_reqs + 1],
self.positions[:total_num_scheduled_tokens],
)
# Copy the tensors to the GPU.
self._prepare_input_ids(
scheduler_output,
num_reqs,
total_num_scheduled_tokens,
cu_num_tokens,
)
@@ -1830,9 +2010,14 @@ class GPUModelRunner(
self.xdrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True,
)
else:
# Common case (1D positions)
self.positions.copy_to_gpu(total_num_scheduled_tokens)
if self.use_async_spec_decode and (self.uses_mrope or self.uses_xdrope_dim > 0):
drift = self.num_computed_tokens[req_indices_gpu].to(
torch.int64
) - self.input_batch.num_computed_tokens_cpu_tensor[req_indices].to(
device=self.device, dtype=torch.int64, non_blocking=True
)
target = self.mrope_positions if self.uses_mrope else self.xdrope_positions
target.gpu[:, :total_num_scheduled_tokens] += drift
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
@@ -1857,12 +2042,13 @@ class GPUModelRunner(
draft_token_ids,
) in scheduler_output.scheduled_spec_decode_tokens.items():
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
draft_len = len(draft_token_ids)
num_draft_tokens[req_idx] = draft_len
if (
self.input_batch.num_computed_tokens_cpu[req_idx]
>= self.input_batch.num_prompt_tokens[req_idx]
):
num_decode_draft_tokens[req_idx] = len(draft_token_ids)
num_decode_draft_tokens[req_idx] = draft_len
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens
)
@@ -1924,16 +2110,7 @@ class GPUModelRunner(
# window size when capturing to make sure the correct kernel is selected.
max_seq_len = self.max_model_len
else:
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
if use_spec_decode:
if self.num_accepted_tokens_event is not None:
self.num_accepted_tokens_event.synchronize()
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
max_seq_len = self.optimistic_seq_lens_cpu.numpy()[:num_reqs].max().item()
kv_cache_groups = self.kv_cache_config.kv_cache_groups
@@ -1963,22 +2140,29 @@ class GPUModelRunner(
attn_gid = self.routed_experts_attn_gid
slot_mapping_attn = slot_mappings[attn_gid]
self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy()
# Compute is_prefilling: True if request is still in prefill phase
# (num_computed_tokens < num_prompt_tokens). Used by mamba backends to
# distinguish actual decodes from short extends.
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
]
num_prompt_tokens_cpu = self.input_batch.num_prompt_tokens_cpu_tensor[
:num_reqs_padded
]
seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs_padded]
# is_prefilling: True if request is still in prefill phase.
# Used by mamba backends to distinguish actual decodes from
# short extends.
is_prefilling = num_computed_tokens_cpu < num_prompt_tokens_cpu
if self.use_async_spec_decode:
# GPU tensors are authoritative in async mode.
seq_lens_cpu = None
num_computed_tokens_cpu = None
cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
seq_lens=self.seq_lens.gpu[:num_reqs_padded],
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded],
seq_lens=self.seq_lens[:num_reqs_padded],
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs_padded,
num_actual_tokens=num_tokens_padded,
@@ -1992,7 +2176,7 @@ class GPUModelRunner(
if self.dcp_world_size > 1:
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
self.seq_lens.cpu[:num_reqs],
self.optimistic_seq_lens_cpu[:num_reqs],
self.dcp_world_size,
self.dcp_rank,
self.parallel_config.cp_kv_cache_interleave_size,
@@ -2396,33 +2580,34 @@ class GPUModelRunner(
# [4, 1, 3, 1, 2]
num_sampled_tokens = num_draft_tokens + 1
# Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11]
# arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
cu_num_sampled_tokens, arange = self._get_cumsum_and_arange(
num_sampled_tokens, cumsum_dtype=np.int32
# Step 1.
# cu_num_sampled_tokens: [4, 5, 8, 9, 11]
# _arange_scratch[:11]: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
cu_num_sampled_tokens = self._get_cumsum_and_arange(
num_sampled_tokens, self._arange_scratch, cumsum_dtype=np.int32
)
# Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_indices = np.repeat(
cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens
)
# Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
logits_indices += arange
logits_indices += self._arange_scratch[: cu_num_sampled_tokens[-1]]
# Compute the bonus logits indices.
bonus_logits_indices = cu_num_sampled_tokens - 1
# Compute the draft logits indices.
# cu_num_draft_tokens: [3, 3, 5, 5, 6]
# arange: [0, 1, 2, 0, 1, 0]
cu_num_draft_tokens, arange = self._get_cumsum_and_arange(
num_draft_tokens, cumsum_dtype=np.int32
# _arange_scratch[:6]: [0, 1, 2, 0, 1, 0]
cu_num_draft_tokens = self._get_cumsum_and_arange(
num_draft_tokens, self._arange_scratch, cumsum_dtype=np.int32
)
# [0, 0, 0, 5, 5, 9]
target_logits_indices = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens
)
# [0, 1, 2, 5, 6, 9]
target_logits_indices += arange
target_logits_indices += self._arange_scratch[: cu_num_draft_tokens[-1]]
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
@@ -2924,7 +3109,7 @@ class GPUModelRunner(
)
hidden_states = hidden_states[:num_scheduled_tokens]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs]
pooling_metadata = self.input_batch.get_pooling_metadata()
pooling_metadata.build_pooling_cursor(
@@ -3083,9 +3268,9 @@ class GPUModelRunner(
elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_input_tokens]
else:
positions = self.positions.gpu[:num_input_tokens]
positions = self.positions[:num_input_tokens]
if num_input_tokens > num_scheduled_tokens:
self.positions.gpu[num_scheduled_tokens:num_input_tokens].zero_()
self.positions[num_scheduled_tokens:num_input_tokens].zero_()
if is_first_rank:
intermediate_tensors = None
@@ -3610,7 +3795,7 @@ class GPUModelRunner(
self.synchronize_input_prep(),
):
# Update persistent batch states.
self._update_states(scheduler_output)
deferred_state_corrections_fn = self._update_states(scheduler_output)
if has_ec_transfer() and not get_ec_transfer().is_consumer:
with self.maybe_get_ec_connector_output(
@@ -3723,6 +3908,12 @@ class GPUModelRunner(
pad_attn = cudagraph_mode == CUDAGraphMode.FULL
if self.cache_config.mamba_cache_mode == "align":
# preprocess_mamba reads req_state.num_computed_tokens (CPU)
# to decide copy operations, so we must apply deferred
# corrections before it runs.
if deferred_state_corrections_fn:
deferred_state_corrections_fn()
deferred_state_corrections_fn = None
mamba_utils.preprocess_mamba(
scheduler_output,
self.kv_cache_config,
@@ -3734,6 +3925,14 @@ class GPUModelRunner(
self.model.get_mamba_state_copy_func(),
self._get_mamba_copy_bufs(),
)
# preprocess_mamba resets num_accepted_tokens_cpu to 1
# for requests whose state was copied to a new block.
# Re-sync to GPU so the mamba kernel reads from the
# correct initial state slot (init_token_idx = 0).
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
self.num_accepted_tokens.copy_to_gpu(num_reqs)
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
@@ -3894,6 +4093,12 @@ class GPUModelRunner(
slot_mappings,
)
self.kv_connector_output = kv_connector_output
# Now the batch has been launched we can wait for corrections from the
# previous model forward without breaking async scheduling.
if deferred_state_corrections_fn:
deferred_state_corrections_fn()
return None
@torch.inference_mode
@@ -3958,6 +4163,7 @@ class GPUModelRunner(
self._draft_token_ids = None
self._draft_token_req_ids = None
self.valid_sampled_token_count_gpu = None
self.input_batch.prev_sampled_token_ids = None
def propose_draft_token_ids(sampled_token_ids):
@@ -4002,7 +4208,7 @@ class GPUModelRunner(
assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
spec_decode_common_attn_metadata,
self.optimistic_seq_lens_cpu,
sampled_token_ids,
self.requests,
self.input_batch,
@@ -4237,6 +4443,9 @@ class GPUModelRunner(
counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True)
self.valid_sampled_token_count_event.record()
if self.use_async_spec_decode:
# Stash for GPU-side correction in _prepare_inputs.
self.valid_sampled_token_count_gpu = valid_sampled_tokens_count
self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)
def _get_valid_sampled_token_count(self) -> list[int]:
@@ -4366,7 +4575,7 @@ class GPUModelRunner(
)
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,
self.optimistic_seq_lens_cpu,
sampled_token_ids,
self.requests,
self.input_batch,
@@ -4405,7 +4614,7 @@ class GPUModelRunner(
)
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,
self.optimistic_seq_lens_cpu,
sampled_token_ids,
self.requests,
self.input_batch,
@@ -5148,14 +5357,19 @@ class GPUModelRunner(
# In the mixed batch mode (used for FI warmup), we use
# shorter sequence lengths to run faster.
# TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] # type: ignore[assignment]
seq_lens = torch.tensor( # type: ignore[assignment]
[1] * num_decode_tokens + [num_prefill_tokens + 1],
dtype=torch.int,
)
else:
seq_lens = max_query_len # type: ignore[assignment]
self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu()
self.optimistic_seq_lens_cpu[:num_reqs] = seq_lens
self.optimistic_seq_lens_cpu[num_reqs:].fill_(0)
self.seq_lens.copy_(self.optimistic_seq_lens_cpu, non_blocking=True)
cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
cum_num_tokens = self._get_cumsum_and_arange(
num_scheduled_tokens, self.query_pos.np
)
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
self.query_start_loc.copy_to_gpu()
@@ -5201,7 +5415,7 @@ class GPUModelRunner(
elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_tokens_padded]
else:
positions = self.positions.gpu[:num_tokens_padded]
positions = self.positions[:num_tokens_padded]
if get_pp_group().is_first_rank:
intermediate_tensors = None