diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 6ac68e055..fb4ea1bce 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -177,7 +177,7 @@ def test_prepare_next_token_ids(): next_token_ids_from_padded, valid_sampled_tokens_count = ( proposer.prepare_next_token_ids_padded( - common_attn_metadata, + common_attn_metadata.seq_lens_cpu, sampled_token_ids_tensor, mock_requests, mock_input_batch, diff --git a/tests/v1/spec_decode/test_extract_hidden_states.py b/tests/v1/spec_decode/test_extract_hidden_states.py index 6f0ac8cae..27b2a53c1 100644 --- a/tests/v1/spec_decode/test_extract_hidden_states.py +++ b/tests/v1/spec_decode/test_extract_hidden_states.py @@ -187,7 +187,7 @@ def test_prepare_next_token_ids_padded(): ) next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded( - common_attn_metadata, + common_attn_metadata.seq_lens_cpu, sampled_token_ids, mock_requests, mock_input_batch, diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 8ff8f79b9..014400fa9 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -766,6 +766,19 @@ class VllmConfig: # type: ignore[misc] else: self.parallel_config.disable_nccl_for_dp_synchronization = False + if ( + self.speculative_config is not None + and self.scheduler_config.async_scheduling + and self.model_config is not None + and not self.model_config.disable_cascade_attn + ): + logger.warning_once( + "Disabling cascade attention (not yet compatible with " + "async speculative decoding).", + scope="local", + ) + self.model_config.disable_cascade_attn = True + if ( self.model_config is not None and self.model_config.multimodal_config is not None diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 445bb403b..4b20413ca 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -71,7 +71,6 @@ class SpecDecodeBaseProposer: self.method = self.speculative_config.method self.pass_hidden_states_to_model = pass_hidden_states_to_model - self.runner = runner self.device = device self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len @@ -424,8 +423,6 @@ class SpecDecodeBaseProposer: ) ) - assert self.runner is not None - per_layer_attn_metadata: dict[str, object] = {} for attn_group in self.draft_attn_groups: attn_metadata = attn_group.get_metadata_builder().build_for_drafting( @@ -821,7 +818,7 @@ class SpecDecodeBaseProposer: def prepare_next_token_ids_padded( self, - common_attn_metadata: CommonAttentionMetadata, + seq_lens_cpu: torch.Tensor, sampled_token_ids: torch.Tensor, requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, @@ -836,11 +833,10 @@ class SpecDecodeBaseProposer: """ # Precompute get_token_id for when there is no valid next token num_reqs = gpu_input_batch.num_reqs + seq_lens_list = seq_lens_cpu[:num_reqs].tolist() self.backup_next_token_ids.np[:num_reqs] = np.array( [ - requests[gpu_input_batch.req_ids[i]].get_token_id( - common_attn_metadata.seq_lens_cpu[i].item() - ) + requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i]) for i in range(num_reqs) ], dtype=np.int32, @@ -925,7 +921,7 @@ class SpecDecodeBaseProposer: num_reqs=common_attn_metadata.num_reqs, num_actual_tokens=total_num_tokens, max_query_len=new_query_len_per_req.max().item(), - max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), + max_seq_len=common_attn_metadata.max_seq_len, block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens], causal=True, diff --git a/vllm/v1/spec_decode/extract_hidden_states.py b/vllm/v1/spec_decode/extract_hidden_states.py index dd4e47d45..e26fa768a 100644 --- a/vllm/v1/spec_decode/extract_hidden_states.py +++ b/vllm/v1/spec_decode/extract_hidden_states.py @@ -286,7 +286,7 @@ class ExtractHiddenStatesProposer: def prepare_next_token_ids_padded( self, - common_attn_metadata: CommonAttentionMetadata, + seq_lens: torch.Tensor, sampled_token_ids: torch.Tensor, requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, @@ -303,11 +303,10 @@ class ExtractHiddenStatesProposer: device = sampled_token_ids.device # Compute backup tokens for discarded / invalid requests + seq_lens_list = seq_lens[:num_reqs].tolist() backup_tokens_gpu = torch.tensor( [ - requests[gpu_input_batch.req_ids[i]].get_token_id( - common_attn_metadata.seq_lens_cpu[i].item() - ) + requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i]) for i in range(num_reqs) ], dtype=torch.int32, diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index cfc30c3e6..48840967b 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -3,6 +3,7 @@ import torch from vllm.config import VllmConfig, replace +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata, @@ -463,3 +464,36 @@ def copy_and_expand_eagle_inputs_kernel( out_idx, mask=is_new_token_region & in_bounds, ) + + +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def update_num_computed_tokens_for_batch_change( + num_computed_tokens: torch.Tensor, + num_accepted_tokens: torch.Tensor, + prev_positions: torch.Tensor, + valid_sampled_token_count: torch.Tensor, + prev_num_draft_tokens: torch.Tensor, + cpu_num_computed_tokens: torch.Tensor, +) -> None: + """Correct num_computed_tokens for async spec decode drift. + + Requests that had drafts: corrected = prev_gpu + valid_count. + New requests or non-draft (e.g. prefills): use CPU value directly. + """ + # Clamp because prev_positions can be -1 for new requests + gather_indices = prev_positions.clamp(min=0) + + valid_counts = valid_sampled_token_count[gather_indices] + prev_computed = num_computed_tokens[gather_indices] + prev_drafts = prev_num_draft_tokens[gather_indices] + + participating = (prev_positions >= 0) & (prev_drafts > 0) + corrected = prev_computed + valid_counts.int() + + n = prev_positions.shape[0] + num_computed_tokens[:n].copy_( + torch.where(participating, corrected, cpu_num_computed_tokens) + ) + num_accepted_tokens.copy_( + torch.where(participating, valid_counts, num_accepted_tokens) + ) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 591f49761..0f5446b44 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -6,7 +6,9 @@ import torch from vllm.distributed import get_dcp_group, get_pcp_group from vllm.logger import init_logger +from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv +from vllm.v1.attention.backends.utils import PAD_SLOT_ID from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.cp_utils import get_total_cp_world_size @@ -131,71 +133,33 @@ class BlockTable: self.block_table.np[src_tgt] = self.block_table.np[tgt_src] def compute_slot_mapping( - self, req_indices: np.ndarray, positions: np.ndarray + self, + num_reqs: int, + query_start_loc: torch.Tensor, + positions: torch.Tensor, ) -> None: - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` - # here because M (max_model_len) is not necessarily divisible by - # block_size. + num_tokens = positions.shape[0] total_cp_world_size = self.pcp_world_size * self.dcp_world_size total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank - if total_cp_world_size > 1: - # Note(hc): The DCP implement store kvcache with an interleave - # style, the kvcache for the token whose token_idx is i is - # always stored on the GPU whose dcp_rank equals i % cp_world_size: - - # Use a "virtual block" which equals to world_size * block_size - # for block_table_indices calculation. - virtual_block_size = self.block_size * total_cp_world_size - block_table_indices = ( - req_indices * self.max_num_blocks_per_req - + positions // virtual_block_size - ) - - block_numbers = self.block_table.np.ravel()[block_table_indices] - # Use virtual_block_size for mask calculation, which marks local - # tokens. - virtual_block_offsets = positions % virtual_block_size - mask = ( - virtual_block_offsets - // self.cp_kv_cache_interleave_size - % total_cp_world_size - == total_cp_rank - ) - # Calculate local block_offsets - block_offsets = ( - virtual_block_offsets - // (total_cp_world_size * self.cp_kv_cache_interleave_size) - * self.cp_kv_cache_interleave_size - + virtual_block_offsets % self.cp_kv_cache_interleave_size - ) - # Calculate slot_mapping - slot_mapping = block_numbers * self.block_size + block_offsets - # Write final slots, use -1 for not-local - self.slot_mapping.np[: req_indices.shape[0]] = np.where( - mask, slot_mapping, -1 - ) - else: - block_table_indices = ( - req_indices * self.max_num_blocks_per_req + positions // self.block_size - ) - - block_numbers = self.block_table.np.ravel()[block_table_indices] - block_offsets = positions % self.block_size - np.add( - block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping.np[: req_indices.shape[0]], - ) + _compute_slot_mapping_kernel[(num_reqs + 1,)]( + num_tokens, + self.max_num_batched_tokens, + query_start_loc, + positions, + self.block_table.gpu, + self.block_table.gpu.stride(0), + self.block_size, + self.slot_mapping.gpu, + TOTAL_CP_WORLD_SIZE=total_cp_world_size, + TOTAL_CP_RANK=total_cp_rank, + CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size, + PAD_ID=PAD_SLOT_ID, + BLOCK_SIZE=1024, + ) def commit_block_table(self, num_reqs: int) -> None: self.block_table.copy_to_gpu(num_reqs) - def commit_slot_mapping(self, num_tokens: int) -> None: - self.slot_mapping.copy_to_gpu(num_tokens) - def clear(self) -> None: self.block_table.gpu.fill_(0) self.block_table.cpu.fill_(0) @@ -320,19 +284,18 @@ class MultiGroupBlockTable: block_table.swap_row(src, tgt) def compute_slot_mapping( - self, req_indices: np.ndarray, positions: np.ndarray + self, + num_reqs: int, + query_start_loc: torch.Tensor, + positions: torch.Tensor, ) -> None: for block_table in self.block_tables: - block_table.compute_slot_mapping(req_indices, positions) + block_table.compute_slot_mapping(num_reqs, query_start_loc, positions) def commit_block_table(self, num_reqs: int) -> None: for block_table in self.block_tables: block_table.commit_block_table(num_reqs) - def commit_slot_mapping(self, num_tokens: int) -> None: - for block_table in self.block_tables: - block_table.commit_slot_mapping(num_tokens) - def clear(self) -> None: for block_table in self.block_tables: block_table.clear() @@ -340,3 +303,61 @@ class MultiGroupBlockTable: def __getitem__(self, idx: int) -> "BlockTable": """Returns the BlockTable for the i-th KV cache group.""" return self.block_tables[idx] + + +@triton.jit +def _compute_slot_mapping_kernel( + num_tokens, + max_num_tokens, + query_start_loc_ptr, # [num_reqs + 1], int32 + positions_ptr, # [num_tokens], int64 + block_table_ptr, # [max_num_reqs, max_num_blocks_per_req], int32 (flat) + block_table_stride, # max_num_blocks_per_req + block_size, + slot_mapping_ptr, # [max_num_tokens], int64 + TOTAL_CP_WORLD_SIZE: tl.constexpr, + TOTAL_CP_RANK: tl.constexpr, + CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr, + PAD_ID: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + + if req_idx == tl.num_programs(0) - 1: + # Pad remaining slots for CUDA graph compatibility. + for i in range(num_tokens, max_num_tokens, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store( + slot_mapping_ptr + offsets, + PAD_ID, + mask=offsets < max_num_tokens, + ) + return + + start_idx = tl.load(query_start_loc_ptr + req_idx).to(tl.int64) + end_idx = tl.load(query_start_loc_ptr + req_idx + 1).to(tl.int64) + + virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE + row_offset = req_idx * block_table_stride + for i in range(start_idx, end_idx, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + mask = offsets < end_idx + pos = tl.load(positions_ptr + offsets, mask=mask, other=0) + block_indices = pos // virtual_block_size + block_numbers = tl.load(block_table_ptr + row_offset + block_indices).to( + tl.int64 + ) + + virtual_block_offsets = pos - block_indices * virtual_block_size + is_local = ( + virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE + ) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK + local_block_offsets = ( + virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE) + ) * CP_KV_CACHE_INTERLEAVE_SIZE + ( + virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE + ) + + slot_ids = block_numbers * block_size + local_block_offsets + slot_ids = tl.where(is_local, slot_ids, PAD_ID) + tl.store(slot_mapping_ptr + offsets, slot_ids, mask=mask) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index fb7795e04..e20d268fe 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -219,7 +219,7 @@ class InputBatch: # Speculative decoding self.num_accepted_tokens_cpu_tensor = torch.ones( - (max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory + (max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory ) self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy() @@ -989,13 +989,15 @@ class InputBatch: continue num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1) # Also account for case where there may be a smaller number of - # output placeholders (tokens can be discarded after a kv-load failure). + # output placeholders (tokens can be discarded after kv-load + # failure) or a larger number (async spec decode adds optimistic + # placeholders that may exceed the actual acceptance count). first_placeholder = req_output_token_ids.index(-1) num_placeholders = len(req_output_token_ids) - first_placeholder num_to_replace = min(num_sampled_ids, num_placeholders) del new_ids[num_to_replace:] - end_index = first_placeholder + num_to_replace - req_output_token_ids[first_placeholder:end_index] = new_ids + req_output_token_ids[first_placeholder:] = new_ids + # ^ Implicitly resizes to (first_placeholder + num_to_replace) def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None: """ diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c5f674d8c..08fd27573 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -7,7 +7,7 @@ import itertools import threading import time from collections import defaultdict -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence from contextlib import contextmanager from copy import copy, deepcopy from dataclasses import dataclass, replace @@ -172,6 +172,7 @@ from vllm.v1.spec_decode.ngram_proposer_gpu import ( update_scheduler_for_invalid_drafts, ) from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer +from vllm.v1.spec_decode.utils import update_num_computed_tokens_for_batch_change from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext from vllm.v1.worker import mamba_utils @@ -570,6 +571,7 @@ class GPUModelRunner( self.rejection_sampler = RejectionSampler(self.sampler) self.num_spec_tokens = 0 + self.valid_sampled_token_count_gpu: torch.Tensor | None = None if self.speculative_config: self.num_spec_tokens = self.speculative_config.num_speculative_tokens draft_config = self.speculative_config.draft_model_config @@ -577,6 +579,9 @@ class GPUModelRunner( self.effective_drafter_max_model_len = draft_config.max_model_len else: self.effective_drafter_max_model_len = self.max_model_len + self.use_async_spec_decode = ( + self.use_async_scheduling and self.num_spec_tokens > 0 + ) # Request states. self.requests: dict[str, CachedRequestState] = {} @@ -659,11 +664,31 @@ class GPUModelRunner( # Persistent buffers for CUDA graphs. self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32) - self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + self.positions = torch.zeros( + self.max_num_tokens, dtype=torch.int64, device=self.device + ) self.query_start_loc = self._make_buffer( self.max_num_reqs + 1, dtype=torch.int32 ) - self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + self.seq_lens = torch.zeros( + self.max_num_reqs, dtype=torch.int32, device=self.device + ) + self.optimistic_seq_lens_cpu = torch.zeros( + self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory + ) + self.num_computed_tokens = torch.zeros( + self.max_num_reqs, dtype=torch.int32, device=self.device + ) + self.prev_num_draft_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + self.req_indices = self._make_buffer(self.max_num_tokens, dtype=torch.int64) + # Maps current batch position -> previous batch position (-1 for new reqs) + self.prev_positions = self._make_buffer(self.max_num_reqs, dtype=torch.int64) + self.num_scheduled_tokens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) + self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) if self.dcp_world_size > 1: self.dcp_local_seq_lens = self._make_buffer( @@ -683,7 +708,7 @@ class GPUModelRunner( self.max_num_reqs, dtype=torch.int32 ) self.num_accepted_tokens = self._make_buffer( - self.max_num_reqs, dtype=torch.int64 + self.max_num_reqs, dtype=torch.int32 ) # Only relevant for multimodal models @@ -722,12 +747,14 @@ class GPUModelRunner( # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: IntermediateTensors | None = None - # OPTIMIZATION: Cache the tensors rather than creating them every step. - # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange( - max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens), - dtype=np.int64, - ) + # OPTIMIZATION: Cache the arange tensors rather than creating them + # every step. Keep in int64 to avoid overflow with long context. + # - arange_np: immutable [0, 1, 2, ...] used as source for batched computation + # - query_pos: CpuGpuBuffer for the computed batched arange result + arange_size = max(self.max_num_reqs + 1, self.max_num_tokens) + self.arange_np = np.arange(arange_size, dtype=np.int64) + self.query_pos = self._make_buffer(arange_size, dtype=torch.int64) + self._arange_scratch = np.empty(arange_size, dtype=np.int64) # Layer pairings for cross-layer KV sharing. # If an Attention layer `layer_name` is in the keys of this dict, it @@ -812,7 +839,7 @@ class GPUModelRunner( self.valid_sampled_token_count_copy_stream = torch.cuda.Stream() self.valid_sampled_token_count_cpu = torch.empty( self.max_num_reqs, - dtype=torch.int64, + dtype=torch.int32, device="cpu", pin_memory=self.pin_memory, ) @@ -903,13 +930,13 @@ class GPUModelRunner( return self.mrope_positions.gpu[:, :num_tokens] if self.uses_xdrope_dim > 0: return self.xdrope_positions.gpu[:, :num_tokens] - return self.positions.gpu[:num_tokens] + return self.positions[:num_tokens] else: if self.uses_mrope: return self.mrope_positions.gpu[:, num_tokens] if self.uses_xdrope_dim > 0: return self.xdrope_positions.gpu[:, num_tokens] - return self.positions.gpu[num_tokens] + return self.positions[num_tokens] def _make_buffer( self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True @@ -953,7 +980,7 @@ class GPUModelRunner( if len(token_type_id_requests) == 0: return model_kwargs - seq_lens = self.seq_lens.gpu[:num_reqs] + seq_lens = self.seq_lens[:num_reqs] token_type_ids = [] for i in range(num_reqs): @@ -1021,7 +1048,7 @@ class GPUModelRunner( def _sync_device(self) -> None: torch.accelerator.synchronize() - def _update_states(self, scheduler_output: "SchedulerOutput") -> None: + def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None: """Update the cached states and the persistent batch with the scheduler output. @@ -1086,6 +1113,8 @@ class GPUModelRunner( ngram_gpu_new_reqs: list[CachedRequestState] = [] reqs_to_add: list[CachedRequestState] = [] + deferred_spec_decode_corrections = [] + # Add new requests to the cached states. for new_req_data in scheduler_output.scheduled_new_reqs: req_id = new_req_data.req_id @@ -1172,10 +1201,8 @@ class GPUModelRunner( scheduler_output, self.input_batch.req_id_to_index, ) - - # Wait until valid_sampled_tokens_count is copied to cpu, - # then use it to update actual num_computed_tokens of each request. - valid_sampled_token_count = self._get_valid_sampled_token_count() + if self.use_async_spec_decode: + self.prev_num_draft_tokens.np.fill(0) for i, req_id in enumerate(req_data.req_ids): req_state = self.requests[req_id] @@ -1202,15 +1229,30 @@ class GPUModelRunner( if req_index is None: req_state.prev_num_draft_len = 0 else: - assert self.input_batch.prev_req_id_to_index is not None - prev_req_index = self.input_batch.prev_req_id_to_index[req_id] - num_accepted = valid_sampled_token_count[prev_req_index] - 1 - num_rejected = req_state.prev_num_draft_len - num_accepted - num_computed_tokens -= num_rejected - req_state.output_token_ids.extend([-1] * num_accepted) + # Optimistically assume all accepted; queue up a correction + # to be called after the model forward to preserve async + # scheduling. Corrected on GPU in _prepare_inputs. + optimistic_num_accepted = req_state.prev_num_draft_len + req_state.output_token_ids.extend([-1] * optimistic_num_accepted) - if is_ngram_gpu and num_accepted > 0 and req_index is not None: - self.input_batch.num_tokens_no_spec[req_index] += num_accepted + deferred_spec_decode_corrections.append( + (req_id, optimistic_num_accepted, req_state) + ) + + prev_req_index = ( + self.input_batch.prev_req_id_to_index.get(req_id) + if self.input_batch.prev_req_id_to_index + else None + ) + if prev_req_index is not None: + self.prev_num_draft_tokens.np[prev_req_index] = ( + optimistic_num_accepted + ) + + if is_ngram_gpu and optimistic_num_accepted > 0: + self.input_batch.num_tokens_no_spec[req_index] += ( + optimistic_num_accepted + ) # Update the cached states. req_state.num_computed_tokens = num_computed_tokens @@ -1238,7 +1280,8 @@ class GPUModelRunner( ) elif num_output_tokens < len(req_state.output_token_ids): # Some output tokens were discarded due to a sync-KV-load - # failure. Align the cached state. + # failure, or output_token_ids was inflated by the optimistic + # extend above (async spec decode). Align the cached state. del req_state.output_token_ids[num_output_tokens:] if req_index is not None: end_idx = ( @@ -1326,6 +1369,40 @@ class GPUModelRunner( _pinned_val_buf=self._ngram_pinned_val_buf, ) + if deferred_spec_decode_corrections: + + def correct_spec_decode_token_counts(): + valid_sampled_token_count = self._get_valid_sampled_token_count() + if not valid_sampled_token_count: + return + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + if not prev_req_id_to_index: + return + for ( + req_id, + optimistic_num_accepted, + req_state, + ) in deferred_spec_decode_corrections: + prev_req_index = prev_req_id_to_index.get(req_id) + if prev_req_index is None: + continue + num_accepted = valid_sampled_token_count[prev_req_index] - 1 + correction = optimistic_num_accepted - num_accepted + req_state.num_computed_tokens -= correction + cur_req_index = self.input_batch.req_id_to_index.get(req_id) + if cur_req_index is None: + continue + self.input_batch.num_computed_tokens_cpu[cur_req_index] -= ( + correction + ) + if is_ngram_gpu and correction > 0: + self.input_batch.num_tokens_no_spec[cur_req_index] -= correction + self.num_tokens_no_spec_gpu[cur_req_index] -= correction + + return correct_spec_decode_token_counts + else: + return None + def _update_states_after_model_execute( self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput" ) -> None: @@ -1340,6 +1417,9 @@ class GPUModelRunner( if not self.speculative_config or not self.model_config.is_hybrid: return + # TODO: Remove .cpu() sync to enable fully async for hybrid model; + # Use num_computed_tokens.gpu instead of req.num_computed_tokens to + # support aligned mamba cache mode. # Find the number of accepted tokens for each sequence. num_reqs = output_token_ids.size(0) self.num_accepted_tokens.gpu[:num_reqs] = ( @@ -1486,12 +1566,14 @@ class GPUModelRunner( def _get_cumsum_and_arange( self, num_tokens: np.ndarray, + arange_out: np.ndarray, cumsum_dtype: np.dtype | None = None, - ) -> tuple[np.ndarray, np.ndarray]: + ) -> np.ndarray: """Get the cumulative sum and batched arange of the given array. - # E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]) - # Equivalent to but faster than: - # np.concatenate([np.arange(n) for n in num_tokens]) + E.g., [2, 5, 3] -> [2, 7, 10], arange written to + arange_out[:10] as [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]. + Equivalent to but faster than: + np.concatenate([np.arange(n) for n in num_tokens]) """ # Step 1. [2, 5, 3] -> [2, 7, 10] cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype) @@ -1499,13 +1581,33 @@ class GPUModelRunner( # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens) # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - arange = self.arange_np[:total_num_tokens] - cumsums_offsets + np.subtract( + self.arange_np[:total_num_tokens], + cumsums_offsets, + out=arange_out[:total_num_tokens], + ) - return cu_num_tokens, arange + return cu_num_tokens + + def _compute_prev_positions(self, num_reqs: int) -> None: + """Build prev_positions mapping: current pos -> previous pos (-1 if new). + + Populates self.prev_positions.np[:num_reqs] with the mapping. + """ + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + prev_positions = self.prev_positions.np[:num_reqs] + + if not prev_req_id_to_index: + prev_positions.fill(-1) + return + + for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + prev_positions[i] = prev_req_id_to_index.get(req_id, -1) def _prepare_input_ids( self, scheduler_output: "SchedulerOutput", + num_reqs: int, total_num_scheduled_tokens: int, cu_num_tokens: np.ndarray, ) -> None: @@ -1513,7 +1615,11 @@ class GPUModelRunner( Carefully handles the `prev_sampled_token_ids` which can be cached from the previous engine iteration, in which case those tokens on the - GPU need to be copied into the corresponding slots into input_ids.""" + GPU need to be copied into the corresponding slots into input_ids. + + Uses self.prev_positions[:num_reqs] which maps current pos -> prev pos + (-1 for new requests). + """ if self.input_batch.prev_sampled_token_ids is None: # Normal scheduling case @@ -1526,47 +1632,50 @@ class GPUModelRunner( # Async scheduling case, where some decode requests from the previous # iteration won't have entries in input_ids_cpu and need to be copied # on the GPU from prev_sampled_token_ids. - prev_req_id_to_index = self.input_batch.prev_req_id_to_index - assert prev_req_id_to_index is not None + prev_positions = self.prev_positions.np[:num_reqs] + scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens sample_flattened_indices: list[int] = [] spec_flattened_indices: list[int] = [] - prev_common_req_indices: list[int] = [] prev_draft_token_indices: list[int] = [] - indices_match = True + prev_indices: list[int] = [] + common_indices_match = True max_flattened_index = -1 total_num_spec_tokens = 0 - scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens - for req_id, cur_index in self.input_batch.req_id_to_index.items(): - if (prev_index := prev_req_id_to_index.get(req_id)) is not None: - prev_common_req_indices.append(prev_index) - # We need to compute the flattened input_ids index of the - # last token in each common request. - draft_len = len(scheduled_spec_tokens.get(req_id, ())) - total_num_spec_tokens += draft_len - flattened_index = cu_num_tokens[cur_index].item() - 1 - # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2] - # sample_flattened_indices = [0, 2, 5] - # spec_flattened_indices = [1, 3, 4, 6, 7] - sample_flattened_indices.append(flattened_index - draft_len) - spec_flattened_indices.extend( - range(flattened_index - draft_len + 1, flattened_index + 1) - ) - start = prev_index * self.num_spec_tokens - # prev_draft_token_indices is used to find which draft_tokens_id - # should be copied to input_ids - # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] - # flatten draft_tokens_id [1,2,3,4,5,6] - # draft_len of each request [1, 2, 1] - # then prev_draft_token_indices is [0, 2, 3, 4] - prev_draft_token_indices.extend(range(start, start + draft_len)) - indices_match &= prev_index == flattened_index - max_flattened_index = max(max_flattened_index, flattened_index) + for cur_index in range(num_reqs): + prev_index = prev_positions[cur_index] + if prev_index < 0: + continue + prev_indices.append(prev_index) + req_id = self.input_batch.req_ids[cur_index] + # We need to compute the flattened input_ids index of the + # last token in each common request. + draft_len = len(scheduled_spec_tokens.get(req_id, ())) + total_num_spec_tokens += draft_len + flattened_index = cu_num_tokens[cur_index].item() - 1 + # example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2] + # sample_flattened_indices = [0, 2, 5] + # spec_flattened_indices = [1, 3, 4, 6, 7] + sample_flattened_indices.append(flattened_index - draft_len) + spec_flattened_indices.extend( + range(flattened_index - draft_len + 1, flattened_index + 1) + ) + start = prev_index * self.num_spec_tokens + # prev_draft_token_indices is used to find which draft_tokens_id + # should be copied to input_ids + # example: prev draft_tokens_id [[1,2], [3,4], [5, 6]] + # flatten draft_tokens_id [1,2,3,4,5,6] + # draft_len of each request [1, 2, 1] + # then prev_draft_token_indices is [0, 2, 3, 4] + prev_draft_token_indices.extend(range(start, start + draft_len)) + common_indices_match &= prev_index == flattened_index + max_flattened_index = max(max_flattened_index, flattened_index) + num_common_tokens = len(sample_flattened_indices) total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens if num_common_tokens < total_without_spec: # If not all requests are decodes from the last iteration, - # We need to copy the input_ids_cpu to the GPU first. + # we need to copy the input_ids_cpu to the GPU first. self.input_ids.copy_to_gpu(total_num_scheduled_tokens) if self.enable_prompt_embeds: self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens) @@ -1575,7 +1684,7 @@ class GPUModelRunner( # No requests in common with the previous iteration # So input_ids.cpu will have all the input ids. return - if indices_match and max_flattened_index == (num_common_tokens - 1): + if common_indices_match and max_flattened_index == (num_common_tokens - 1): # Common-case optimization: the batch is unchanged # and no reordering happened. # The indices are both the same permutation of 0..N-1 so @@ -1592,7 +1701,7 @@ class GPUModelRunner( sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory ).to(self.device, non_blocking=True) prev_common_req_indices_tensor = torch.tensor( - prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory + prev_indices, dtype=torch.int64, pin_memory=self.pin_memory ).to(self.device, non_blocking=True) self.input_ids.gpu.scatter_( dim=0, @@ -1696,15 +1805,15 @@ class GPUModelRunner( req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] - # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens) + # self.query_pos.np[:10]: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + cu_num_tokens = self._get_cumsum_and_arange( + num_scheduled_tokens, self.query_pos.np + ) # Get positions. - positions_np = self.positions.np[:total_num_scheduled_tokens] - np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np, + positions_np = ( + self.input_batch.num_computed_tokens_cpu[req_indices] + + self.query_pos.np[: cu_num_tokens[-1]] ) # Calculate M-RoPE positions. @@ -1782,9 +1891,6 @@ class GPUModelRunner( output_idx += num_sched - self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np) - self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens) - # Prepare the attention metadata. self.query_start_loc.np[0] = 0 self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens @@ -1794,12 +1900,21 @@ class GPUModelRunner( self.query_start_loc.copy_to_gpu() query_start_loc = self.query_start_loc.gpu[: num_reqs + 1] - self.seq_lens.np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens + # Compute optimistic seq_lens (assumes all draft tokens from previous + # iteration accepted). Store in optimistic_seq_lens_cpu for use by + # _build_attention_metadata (max_seq_len) and discard_request_mask. + # seq_lens (GPU) will be computed later using the same optimistic values. + torch.add( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs], + torch.from_numpy(num_scheduled_tokens), + out=self.optimistic_seq_lens_cpu[:num_reqs], ) - # Fill unused with 0 for full cuda graph mode. - self.seq_lens.np[num_reqs:].fill(0) - self.seq_lens.copy_to_gpu() + self.optimistic_seq_lens_cpu[num_reqs:].fill_(0) + + # Build prev_positions mapping: current pos -> prev pos (-1 if new). + # Used for gathering from previous iteration's GPU tensors. + prev_req_id_to_index = self.input_batch.prev_req_id_to_index + self._compute_prev_positions(num_reqs) num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids] num_tokens_np = np.array(num_tokens, dtype=np.int32) @@ -1807,13 +1922,78 @@ class GPUModelRunner( # Record which requests should not be sampled, # so that we could clear the sampled tokens before returning self.discard_request_mask.np[:num_reqs] = ( - self.seq_lens.np[:num_reqs] < num_tokens_np + self.optimistic_seq_lens_cpu[:num_reqs].numpy() < num_tokens_np ) self.discard_request_mask.copy_to_gpu(num_reqs) + # Sync num_accepted_tokens from CPU (set by + # _update_states_after_model_execute for hybrid models). + if self.num_accepted_tokens_event is not None: + self.num_accepted_tokens_event.synchronize() + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) + self.num_accepted_tokens.np[num_reqs:].fill(1) + self.num_accepted_tokens.copy_to_gpu() + else: + self.num_accepted_tokens.np.fill(1) + self.num_accepted_tokens.gpu.fill_(1) + + # Update num_computed_tokens on GPU. In async spec decode, + # CPU values are optimistic (all drafts accepted). The kernel + # corrects on GPU using the previous step's + # valid_sampled_token_count_gpu. Otherwise, just copy from CPU. + if ( + self.use_async_spec_decode + and self.valid_sampled_token_count_gpu is not None + and prev_req_id_to_index + ): + self.prev_positions.copy_to_gpu(num_reqs) + self.prev_num_draft_tokens.copy_to_gpu() + cpu_values = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs].to( + device=self.device, non_blocking=True + ) + update_num_computed_tokens_for_batch_change( + self.num_computed_tokens, + self.num_accepted_tokens.gpu[:num_reqs], + self.prev_positions.gpu[:num_reqs], + self.valid_sampled_token_count_gpu, + self.prev_num_draft_tokens.gpu, + cpu_values, + ) + else: + self.num_computed_tokens[:num_reqs].copy_( + self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs], + non_blocking=True, + ) + + self.req_indices.np[:total_num_scheduled_tokens] = req_indices + self.req_indices.copy_to_gpu(total_num_scheduled_tokens) + req_indices_gpu = self.req_indices.gpu[:total_num_scheduled_tokens] + + self.query_pos.copy_to_gpu(total_num_scheduled_tokens) + self.num_scheduled_tokens.np[:num_reqs] = num_scheduled_tokens + self.num_scheduled_tokens.copy_to_gpu(num_reqs) + num_scheduled_tokens_gpu = self.num_scheduled_tokens.gpu[:num_reqs] + self.positions[:total_num_scheduled_tokens] = ( + self.num_computed_tokens[req_indices_gpu].to(torch.int64) + + self.query_pos.gpu[:total_num_scheduled_tokens] + ) + self.seq_lens[:num_reqs] = ( + self.num_computed_tokens[:num_reqs] + num_scheduled_tokens_gpu + ) + self.seq_lens[num_reqs:].fill_(0) + + self.input_batch.block_table.compute_slot_mapping( + num_reqs, + self.query_start_loc.gpu[: num_reqs + 1], + self.positions[:total_num_scheduled_tokens], + ) + # Copy the tensors to the GPU. self._prepare_input_ids( scheduler_output, + num_reqs, total_num_scheduled_tokens, cu_num_tokens, ) @@ -1830,9 +2010,14 @@ class GPUModelRunner( self.xdrope_positions.cpu[:, :total_num_scheduled_tokens], non_blocking=True, ) - else: - # Common case (1D positions) - self.positions.copy_to_gpu(total_num_scheduled_tokens) + if self.use_async_spec_decode and (self.uses_mrope or self.uses_xdrope_dim > 0): + drift = self.num_computed_tokens[req_indices_gpu].to( + torch.int64 + ) - self.input_batch.num_computed_tokens_cpu_tensor[req_indices].to( + device=self.device, dtype=torch.int64, non_blocking=True + ) + target = self.mrope_positions if self.uses_mrope else self.xdrope_positions + target.gpu[:, :total_num_scheduled_tokens] += drift use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -1857,12 +2042,13 @@ class GPUModelRunner( draft_token_ids, ) in scheduler_output.scheduled_spec_decode_tokens.items(): req_idx = self.input_batch.req_id_to_index[req_id] - num_draft_tokens[req_idx] = len(draft_token_ids) + draft_len = len(draft_token_ids) + num_draft_tokens[req_idx] = draft_len if ( self.input_batch.num_computed_tokens_cpu[req_idx] >= self.input_batch.num_prompt_tokens[req_idx] ): - num_decode_draft_tokens[req_idx] = len(draft_token_ids) + num_decode_draft_tokens[req_idx] = draft_len spec_decode_metadata = self._calc_spec_decode_metadata( num_draft_tokens, cu_num_tokens ) @@ -1924,16 +2110,7 @@ class GPUModelRunner( # window size when capturing to make sure the correct kernel is selected. max_seq_len = self.max_model_len else: - max_seq_len = self.seq_lens.np[:num_reqs].max().item() - - if use_spec_decode: - if self.num_accepted_tokens_event is not None: - self.num_accepted_tokens_event.synchronize() - self.num_accepted_tokens.np[:num_reqs] = ( - self.input_batch.num_accepted_tokens_cpu[:num_reqs] - ) - self.num_accepted_tokens.np[num_reqs:].fill(1) - self.num_accepted_tokens.copy_to_gpu() + max_seq_len = self.optimistic_seq_lens_cpu.numpy()[:num_reqs].max().item() kv_cache_groups = self.kv_cache_config.kv_cache_groups @@ -1963,22 +2140,29 @@ class GPUModelRunner( attn_gid = self.routed_experts_attn_gid slot_mapping_attn = slot_mappings[attn_gid] self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy() - # Compute is_prefilling: True if request is still in prefill phase - # (num_computed_tokens < num_prompt_tokens). Used by mamba backends to - # distinguish actual decodes from short extends. num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[ :num_reqs_padded ] num_prompt_tokens_cpu = self.input_batch.num_prompt_tokens_cpu_tensor[ :num_reqs_padded ] + seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs_padded] + + # is_prefilling: True if request is still in prefill phase. + # Used by mamba backends to distinguish actual decodes from + # short extends. is_prefilling = num_computed_tokens_cpu < num_prompt_tokens_cpu + if self.use_async_spec_decode: + # GPU tensors are authoritative in async mode. + seq_lens_cpu = None + num_computed_tokens_cpu = None + cm_base = CommonAttentionMetadata( query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], - seq_lens=self.seq_lens.gpu[:num_reqs_padded], - _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], + seq_lens=self.seq_lens[:num_reqs_padded], + _seq_lens_cpu=seq_lens_cpu, _num_computed_tokens_cpu=num_computed_tokens_cpu, num_reqs=num_reqs_padded, num_actual_tokens=num_tokens_padded, @@ -1992,7 +2176,7 @@ class GPUModelRunner( if self.dcp_world_size > 1: self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( - self.seq_lens.cpu[:num_reqs], + self.optimistic_seq_lens_cpu[:num_reqs], self.dcp_world_size, self.dcp_rank, self.parallel_config.cp_kv_cache_interleave_size, @@ -2396,33 +2580,34 @@ class GPUModelRunner( # [4, 1, 3, 1, 2] num_sampled_tokens = num_draft_tokens + 1 - # Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11] - # arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] - cu_num_sampled_tokens, arange = self._get_cumsum_and_arange( - num_sampled_tokens, cumsum_dtype=np.int32 + # Step 1. + # cu_num_sampled_tokens: [4, 5, 8, 9, 11] + # _arange_scratch[:11]: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1] + cu_num_sampled_tokens = self._get_cumsum_and_arange( + num_sampled_tokens, self._arange_scratch, cumsum_dtype=np.int32 ) # Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] logits_indices = np.repeat( cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens ) # Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208] - logits_indices += arange + logits_indices += self._arange_scratch[: cu_num_sampled_tokens[-1]] # Compute the bonus logits indices. bonus_logits_indices = cu_num_sampled_tokens - 1 # Compute the draft logits indices. # cu_num_draft_tokens: [3, 3, 5, 5, 6] - # arange: [0, 1, 2, 0, 1, 0] - cu_num_draft_tokens, arange = self._get_cumsum_and_arange( - num_draft_tokens, cumsum_dtype=np.int32 + # _arange_scratch[:6]: [0, 1, 2, 0, 1, 0] + cu_num_draft_tokens = self._get_cumsum_and_arange( + num_draft_tokens, self._arange_scratch, cumsum_dtype=np.int32 ) # [0, 0, 0, 5, 5, 9] target_logits_indices = np.repeat( cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens ) # [0, 1, 2, 5, 6, 9] - target_logits_indices += arange + target_logits_indices += self._arange_scratch[: cu_num_draft_tokens[-1]] # TODO: Optimize the CPU -> GPU copy. cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( @@ -2924,7 +3109,7 @@ class GPUModelRunner( ) hidden_states = hidden_states[:num_scheduled_tokens] - seq_lens_cpu = self.seq_lens.cpu[:num_reqs] + seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs] pooling_metadata = self.input_batch.get_pooling_metadata() pooling_metadata.build_pooling_cursor( @@ -3083,9 +3268,9 @@ class GPUModelRunner( elif self.uses_xdrope_dim > 0: positions = self.xdrope_positions.gpu[:, :num_input_tokens] else: - positions = self.positions.gpu[:num_input_tokens] + positions = self.positions[:num_input_tokens] if num_input_tokens > num_scheduled_tokens: - self.positions.gpu[num_scheduled_tokens:num_input_tokens].zero_() + self.positions[num_scheduled_tokens:num_input_tokens].zero_() if is_first_rank: intermediate_tensors = None @@ -3610,7 +3795,7 @@ class GPUModelRunner( self.synchronize_input_prep(), ): # Update persistent batch states. - self._update_states(scheduler_output) + deferred_state_corrections_fn = self._update_states(scheduler_output) if has_ec_transfer() and not get_ec_transfer().is_consumer: with self.maybe_get_ec_connector_output( @@ -3723,6 +3908,12 @@ class GPUModelRunner( pad_attn = cudagraph_mode == CUDAGraphMode.FULL if self.cache_config.mamba_cache_mode == "align": + # preprocess_mamba reads req_state.num_computed_tokens (CPU) + # to decide copy operations, so we must apply deferred + # corrections before it runs. + if deferred_state_corrections_fn: + deferred_state_corrections_fn() + deferred_state_corrections_fn = None mamba_utils.preprocess_mamba( scheduler_output, self.kv_cache_config, @@ -3734,6 +3925,14 @@ class GPUModelRunner( self.model.get_mamba_state_copy_func(), self._get_mamba_copy_bufs(), ) + # preprocess_mamba resets num_accepted_tokens_cpu to 1 + # for requests whose state was copied to a new block. + # Re-sync to GPU so the mamba kernel reads from the + # correct initial state slot (init_token_idx = 0). + self.num_accepted_tokens.np[:num_reqs] = ( + self.input_batch.num_accepted_tokens_cpu[:num_reqs] + ) + self.num_accepted_tokens.copy_to_gpu(num_reqs) use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices @@ -3894,6 +4093,12 @@ class GPUModelRunner( slot_mappings, ) self.kv_connector_output = kv_connector_output + + # Now the batch has been launched we can wait for corrections from the + # previous model forward without breaking async scheduling. + if deferred_state_corrections_fn: + deferred_state_corrections_fn() + return None @torch.inference_mode @@ -3958,6 +4163,7 @@ class GPUModelRunner( self._draft_token_ids = None self._draft_token_req_ids = None + self.valid_sampled_token_count_gpu = None self.input_batch.prev_sampled_token_ids = None def propose_draft_token_ids(sampled_token_ids): @@ -4002,7 +4208,7 @@ class GPUModelRunner( assert spec_decode_common_attn_metadata is not None next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - spec_decode_common_attn_metadata, + self.optimistic_seq_lens_cpu, sampled_token_ids, self.requests, self.input_batch, @@ -4237,6 +4443,9 @@ class GPUModelRunner( counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True) self.valid_sampled_token_count_event.record() + if self.use_async_spec_decode: + # Stash for GPU-side correction in _prepare_inputs. + self.valid_sampled_token_count_gpu = valid_sampled_tokens_count self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1) def _get_valid_sampled_token_count(self) -> list[int]: @@ -4366,7 +4575,7 @@ class GPUModelRunner( ) next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - common_attn_metadata, + self.optimistic_seq_lens_cpu, sampled_token_ids, self.requests, self.input_batch, @@ -4405,7 +4614,7 @@ class GPUModelRunner( ) next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( - common_attn_metadata, + self.optimistic_seq_lens_cpu, sampled_token_ids, self.requests, self.input_batch, @@ -5148,14 +5357,19 @@ class GPUModelRunner( # In the mixed batch mode (used for FI warmup), we use # shorter sequence lengths to run faster. # TODO(luka) better system for describing dummy batches - seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] # type: ignore[assignment] + seq_lens = torch.tensor( # type: ignore[assignment] + [1] * num_decode_tokens + [num_prefill_tokens + 1], + dtype=torch.int, + ) else: seq_lens = max_query_len # type: ignore[assignment] - self.seq_lens.np[:num_reqs] = seq_lens - self.seq_lens.np[num_reqs:] = 0 - self.seq_lens.copy_to_gpu() + self.optimistic_seq_lens_cpu[:num_reqs] = seq_lens + self.optimistic_seq_lens_cpu[num_reqs:].fill_(0) + self.seq_lens.copy_(self.optimistic_seq_lens_cpu, non_blocking=True) - cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens) + cum_num_tokens = self._get_cumsum_and_arange( + num_scheduled_tokens, self.query_pos.np + ) self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens self.query_start_loc.copy_to_gpu() @@ -5201,7 +5415,7 @@ class GPUModelRunner( elif self.uses_xdrope_dim > 0: positions = self.xdrope_positions.gpu[:, :num_tokens_padded] else: - positions = self.positions.gpu[:num_tokens_padded] + positions = self.positions[:num_tokens_padded] if get_pp_group().is_first_rank: intermediate_tensors = None