[Async][Spec Decoding] Zero-bubble async scheduling + spec decoding (#32951)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Co-authored-by: zhrrr <43847754+izhuhaoran@users.noreply.github.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
@@ -177,7 +177,7 @@ def test_prepare_next_token_ids():

next_token_ids_from_padded, valid_sampled_tokens_count = (
proposer.prepare_next_token_ids_padded(
common_attn_metadata,
common_attn_metadata.seq_lens_cpu,
sampled_token_ids_tensor,
mock_requests,
mock_input_batch,

@@ -187,7 +187,7 @@ def test_prepare_next_token_ids_padded():
)

next_token_ids, valid_sampled_tokens_count = proposer.prepare_next_token_ids_padded(
common_attn_metadata,
common_attn_metadata.seq_lens_cpu,
sampled_token_ids,
mock_requests,
mock_input_batch,

@@ -766,6 +766,19 @@ class VllmConfig: # type: ignore[misc]
else:
self.parallel_config.disable_nccl_for_dp_synchronization = False

if (
self.speculative_config is not None
and self.scheduler_config.async_scheduling
and self.model_config is not None
and not self.model_config.disable_cascade_attn
):
logger.warning_once(
"Disabling cascade attention (not yet compatible with "
"async speculative decoding).",
scope="local",
)
self.model_config.disable_cascade_attn = True

if (
self.model_config is not None
and self.model_config.multimodal_config is not None

@@ -71,7 +71,6 @@ class SpecDecodeBaseProposer:
self.method = self.speculative_config.method
self.pass_hidden_states_to_model = pass_hidden_states_to_model

self.runner = runner
self.device = device
self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
@@ -424,8 +423,6 @@ class SpecDecodeBaseProposer:
)
)

assert self.runner is not None

per_layer_attn_metadata: dict[str, object] = {}
for attn_group in self.draft_attn_groups:
attn_metadata = attn_group.get_metadata_builder().build_for_drafting(
@@ -821,7 +818,7 @@ class SpecDecodeBaseProposer:

def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
seq_lens_cpu: torch.Tensor,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
@@ -836,11 +833,10 @@ class SpecDecodeBaseProposer:
"""
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
seq_lens_list = seq_lens_cpu[:num_reqs].tolist()
self.backup_next_token_ids.np[:num_reqs] = np.array(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item()
)
requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
for i in range(num_reqs)
],
dtype=np.int32,
@@ -925,7 +921,7 @@ class SpecDecodeBaseProposer:
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
max_seq_len=common_attn_metadata.max_seq_len,
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
causal=True,

@@ -286,7 +286,7 @@ class ExtractHiddenStatesProposer:

def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
seq_lens: torch.Tensor,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
@@ -303,11 +303,10 @@ class ExtractHiddenStatesProposer:
device = sampled_token_ids.device

# Compute backup tokens for discarded / invalid requests
seq_lens_list = seq_lens[:num_reqs].tolist()
backup_tokens_gpu = torch.tensor(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item()
)
requests[gpu_input_batch.req_ids[i]].get_token_id(seq_lens_list[i])
for i in range(num_reqs)
],
dtype=torch.int32,

@@ -3,6 +3,7 @@

import torch

from vllm.config import VllmConfig, replace
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
@@ -463,3 +464,36 @@ def copy_and_expand_eagle_inputs_kernel(
out_idx,
mask=is_new_token_region & in_bounds,
)


@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def update_num_computed_tokens_for_batch_change(
num_computed_tokens: torch.Tensor,
num_accepted_tokens: torch.Tensor,
prev_positions: torch.Tensor,
valid_sampled_token_count: torch.Tensor,
prev_num_draft_tokens: torch.Tensor,
cpu_num_computed_tokens: torch.Tensor,
) -> None:
"""Correct num_computed_tokens for async spec decode drift.

Requests that had drafts: corrected = prev_gpu + valid_count.
New requests or non-draft (e.g. prefills): use CPU value directly.
"""
# Clamp because prev_positions can be -1 for new requests
gather_indices = prev_positions.clamp(min=0)

valid_counts = valid_sampled_token_count[gather_indices]
prev_computed = num_computed_tokens[gather_indices]
prev_drafts = prev_num_draft_tokens[gather_indices]

participating = (prev_positions >= 0) & (prev_drafts > 0)
corrected = prev_computed + valid_counts.int()

n = prev_positions.shape[0]
num_computed_tokens[:n].copy_(
torch.where(participating, corrected, cpu_num_computed_tokens)
)
num_accepted_tokens.copy_(
torch.where(participating, valid_counts, num_accepted_tokens)
)
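A minimal standalone sketch of the correction above, with made-up values (illustrative only, not part of the diff). Requests that had drafts in the previous step get prev_gpu + valid_count; new requests and requests without drafts fall back to the CPU value:

# Hypothetical 3-request batch: current req 0 was previous slot 1, req 1 is new,
# req 2 was previous slot 0 (a prefill with no drafts).
import torch
num_computed_tokens_prev = torch.tensor([10, 20, 30], dtype=torch.int32)   # previous GPU values per prev slot
valid_sampled_token_count = torch.tensor([1, 3, 1], dtype=torch.int32)     # per prev slot
prev_num_draft_tokens = torch.tensor([0, 2, 0], dtype=torch.int32)         # per prev slot
prev_positions = torch.tensor([1, -1, 0])                                   # current pos -> prev pos (-1 = new)
cpu_num_computed_tokens = torch.tensor([24, 7, 11], dtype=torch.int32)      # optimistic CPU values per current pos

gather_indices = prev_positions.clamp(min=0)
participating = (prev_positions >= 0) & (prev_num_draft_tokens[gather_indices] > 0)
corrected = num_computed_tokens_prev[gather_indices] + valid_sampled_token_count[gather_indices]
result = torch.where(participating, corrected, cpu_num_computed_tokens)
# req 0: 20 + 3 = 23; req 1: new -> 7; req 2: no drafts -> 11
print(result)  # tensor([23,  7, 11], dtype=torch.int32)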

@@ -6,7 +6,9 @@ import torch

from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.cp_utils import get_total_cp_world_size

@@ -131,71 +133,33 @@ class BlockTable:
self.block_table.np[src_tgt] = self.block_table.np[tgt_src]

def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray
self,
num_reqs: int,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
num_tokens = positions.shape[0]
total_cp_world_size = self.pcp_world_size * self.dcp_world_size
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
if total_cp_world_size > 1:
# Note(hc): The DCP implement store kvcache with an interleave
# style, the kvcache for the token whose token_idx is i is
# always stored on the GPU whose dcp_rank equals i % cp_world_size:

# Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation.
virtual_block_size = self.block_size * total_cp_world_size
block_table_indices = (
req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size
)

block_numbers = self.block_table.np.ravel()[block_table_indices]
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = (
virtual_block_offsets
// self.cp_kv_cache_interleave_size
% total_cp_world_size
== total_cp_rank
)
# Calculate local block_offsets
block_offsets = (
virtual_block_offsets
// (total_cp_world_size * self.cp_kv_cache_interleave_size)
* self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
self.slot_mapping.np[: req_indices.shape[0]] = np.where(
mask, slot_mapping, -1
)
else:
block_table_indices = (
req_indices * self.max_num_blocks_per_req + positions // self.block_size
)

block_numbers = self.block_table.np.ravel()[block_table_indices]
block_offsets = positions % self.block_size
np.add(
block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping.np[: req_indices.shape[0]],
)
_compute_slot_mapping_kernel[(num_reqs + 1,)](
num_tokens,
self.max_num_batched_tokens,
query_start_loc,
positions,
self.block_table.gpu,
self.block_table.gpu.stride(0),
self.block_size,
self.slot_mapping.gpu,
TOTAL_CP_WORLD_SIZE=total_cp_world_size,
TOTAL_CP_RANK=total_cp_rank,
CP_KV_CACHE_INTERLEAVE_SIZE=self.cp_kv_cache_interleave_size,
PAD_ID=PAD_SLOT_ID,
BLOCK_SIZE=1024,
)

def commit_block_table(self, num_reqs: int) -> None:
self.block_table.copy_to_gpu(num_reqs)

def commit_slot_mapping(self, num_tokens: int) -> None:
self.slot_mapping.copy_to_gpu(num_tokens)

def clear(self) -> None:
self.block_table.gpu.fill_(0)
self.block_table.cpu.fill_(0)
@@ -320,19 +284,18 @@ class MultiGroupBlockTable:
block_table.swap_row(src, tgt)

def compute_slot_mapping(
self, req_indices: np.ndarray, positions: np.ndarray
self,
num_reqs: int,
query_start_loc: torch.Tensor,
positions: torch.Tensor,
) -> None:
for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions)
block_table.compute_slot_mapping(num_reqs, query_start_loc, positions)

def commit_block_table(self, num_reqs: int) -> None:
for block_table in self.block_tables:
block_table.commit_block_table(num_reqs)

def commit_slot_mapping(self, num_tokens: int) -> None:
for block_table in self.block_tables:
block_table.commit_slot_mapping(num_tokens)

def clear(self) -> None:
for block_table in self.block_tables:
block_table.clear()
@@ -340,3 +303,61 @@ class MultiGroupBlockTable:
def __getitem__(self, idx: int) -> "BlockTable":
"""Returns the BlockTable for the i-th KV cache group."""
return self.block_tables[idx]


@triton.jit
def _compute_slot_mapping_kernel(
num_tokens,
max_num_tokens,
query_start_loc_ptr, # [num_reqs + 1], int32
positions_ptr, # [num_tokens], int64
block_table_ptr, # [max_num_reqs, max_num_blocks_per_req], int32 (flat)
block_table_stride, # max_num_blocks_per_req
block_size,
slot_mapping_ptr, # [max_num_tokens], int64
TOTAL_CP_WORLD_SIZE: tl.constexpr,
TOTAL_CP_RANK: tl.constexpr,
CP_KV_CACHE_INTERLEAVE_SIZE: tl.constexpr,
PAD_ID: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)

if req_idx == tl.num_programs(0) - 1:
# Pad remaining slots for CUDA graph compatibility.
for i in range(num_tokens, max_num_tokens, BLOCK_SIZE):
offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(
slot_mapping_ptr + offsets,
PAD_ID,
mask=offsets < max_num_tokens,
)
return

start_idx = tl.load(query_start_loc_ptr + req_idx).to(tl.int64)
end_idx = tl.load(query_start_loc_ptr + req_idx + 1).to(tl.int64)

virtual_block_size = block_size * TOTAL_CP_WORLD_SIZE
row_offset = req_idx * block_table_stride
for i in range(start_idx, end_idx, BLOCK_SIZE):
offsets = i + tl.arange(0, BLOCK_SIZE)
mask = offsets < end_idx
pos = tl.load(positions_ptr + offsets, mask=mask, other=0)
block_indices = pos // virtual_block_size
block_numbers = tl.load(block_table_ptr + row_offset + block_indices).to(
tl.int64
)

virtual_block_offsets = pos - block_indices * virtual_block_size
is_local = (
virtual_block_offsets // CP_KV_CACHE_INTERLEAVE_SIZE
) % TOTAL_CP_WORLD_SIZE == TOTAL_CP_RANK
local_block_offsets = (
virtual_block_offsets // (TOTAL_CP_WORLD_SIZE * CP_KV_CACHE_INTERLEAVE_SIZE)
) * CP_KV_CACHE_INTERLEAVE_SIZE + (
virtual_block_offsets % CP_KV_CACHE_INTERLEAVE_SIZE
)

slot_ids = block_numbers * block_size + local_block_offsets
slot_ids = tl.where(is_local, slot_ids, PAD_ID)
tl.store(slot_mapping_ptr + offsets, slot_ids, mask=mask)
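For reference, a plain-NumPy sketch of what the kernel computes for a single request in the no-context-parallel case (TOTAL_CP_WORLD_SIZE == 1); the block ids and positions are made up. With CP/DCP the kernel instead indexes with a virtual block of block_size * world_size and writes PAD_ID for positions owned by other ranks:

import numpy as np
block_size = 2
block_table_row = np.array([7, 3, 9])   # physical block ids for one request (hypothetical)
positions = np.array([0, 1, 2, 3, 4])   # token positions scheduled for this request

block_indices = positions // block_size            # [0, 0, 1, 1, 2]
block_numbers = block_table_row[block_indices]     # [7, 7, 3, 3, 9]
block_offsets = positions % block_size             # [0, 1, 0, 1, 0]
slot_mapping = block_numbers * block_size + block_offsets
print(slot_mapping)  # [14 15  6  7 18]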

@@ -219,7 +219,7 @@ class InputBatch:

# Speculative decoding
self.num_accepted_tokens_cpu_tensor = torch.ones(
(max_num_reqs,), dtype=torch.int64, device="cpu", pin_memory=pin_memory
(max_num_reqs,), dtype=torch.int32, device="cpu", pin_memory=pin_memory
)
self.num_accepted_tokens_cpu = self.num_accepted_tokens_cpu_tensor.numpy()

@@ -989,13 +989,15 @@ class InputBatch:
continue
num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
# Also account for case where there may be a smaller number of
# output placeholders (tokens can be discarded after a kv-load failure).
# output placeholders (tokens can be discarded after kv-load
# failure) or a larger number (async spec decode adds optimistic
# placeholders that may exceed the actual acceptance count).
first_placeholder = req_output_token_ids.index(-1)
num_placeholders = len(req_output_token_ids) - first_placeholder
num_to_replace = min(num_sampled_ids, num_placeholders)
del new_ids[num_to_replace:]
end_index = first_placeholder + num_to_replace
req_output_token_ids[first_placeholder:end_index] = new_ids
req_output_token_ids[first_placeholder:] = new_ids
# ^ Implicitly resizes to (first_placeholder + num_to_replace)
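The implicit resize relies on Python list slice assignment; a tiny illustration with made-up token ids:

req_output_token_ids = [5, 6, -1, -1, -1]   # three optimistic placeholders
new_ids = [7, 8]                            # only two tokens actually accepted
first_placeholder = req_output_token_ids.index(-1)
req_output_token_ids[first_placeholder:] = new_ids
print(req_output_token_ids)  # [5, 6, 7, 8] -- the list shrinks to fit new_ids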

def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
"""

@@ -7,7 +7,7 @@ import itertools
import threading
import time
from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
from contextlib import contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass, replace
@@ -172,6 +172,7 @@ from vllm.v1.spec_decode.ngram_proposer_gpu import (
update_scheduler_for_invalid_drafts,
)
from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer
from vllm.v1.spec_decode.utils import update_num_computed_tokens_for_batch_change
from vllm.v1.structured_output.utils import apply_grammar_bitmask
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker import mamba_utils
@@ -570,6 +571,7 @@ class GPUModelRunner(
self.rejection_sampler = RejectionSampler(self.sampler)

self.num_spec_tokens = 0
self.valid_sampled_token_count_gpu: torch.Tensor | None = None
if self.speculative_config:
self.num_spec_tokens = self.speculative_config.num_speculative_tokens
draft_config = self.speculative_config.draft_model_config
@@ -577,6 +579,9 @@ class GPUModelRunner(
self.effective_drafter_max_model_len = draft_config.max_model_len
else:
self.effective_drafter_max_model_len = self.max_model_len
self.use_async_spec_decode = (
self.use_async_scheduling and self.num_spec_tokens > 0
)

# Request states.
self.requests: dict[str, CachedRequestState] = {}
@@ -659,11 +664,31 @@ class GPUModelRunner(

# Persistent buffers for CUDA graphs.
self.input_ids = self._make_buffer(self.max_num_tokens, dtype=torch.int32)
self.positions = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
self.positions = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=self.device
)
self.query_start_loc = self._make_buffer(
self.max_num_reqs + 1, dtype=torch.int32
)
self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.seq_lens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=self.device
)
self.optimistic_seq_lens_cpu = torch.zeros(
self.max_num_reqs, dtype=torch.int32, pin_memory=self.pin_memory
)
self.num_computed_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=self.device
)
self.prev_num_draft_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)
self.req_indices = self._make_buffer(self.max_num_tokens, dtype=torch.int64)
# Maps current batch position -> previous batch position (-1 for new reqs)
self.prev_positions = self._make_buffer(self.max_num_reqs, dtype=torch.int64)
self.num_scheduled_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int32
)

self.encoder_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
if self.dcp_world_size > 1:
self.dcp_local_seq_lens = self._make_buffer(
@@ -683,7 +708,7 @@ class GPUModelRunner(
self.max_num_reqs, dtype=torch.int32
)
self.num_accepted_tokens = self._make_buffer(
self.max_num_reqs, dtype=torch.int64
self.max_num_reqs, dtype=torch.int32
)

# Only relevant for multimodal models
@@ -722,12 +747,14 @@ class GPUModelRunner(
# None in the first PP rank. The rest are set after load_model.
self.intermediate_tensors: IntermediateTensors | None = None

# OPTIMIZATION: Cache the tensors rather than creating them every step.
# Keep in int64 to avoid overflow with long context
self.arange_np = np.arange(
max(self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
dtype=np.int64,
)
# OPTIMIZATION: Cache the arange tensors rather than creating them
# every step. Keep in int64 to avoid overflow with long context.
# - arange_np: immutable [0, 1, 2, ...] used as source for batched computation
# - query_pos: CpuGpuBuffer for the computed batched arange result
arange_size = max(self.max_num_reqs + 1, self.max_num_tokens)
self.arange_np = np.arange(arange_size, dtype=np.int64)
self.query_pos = self._make_buffer(arange_size, dtype=torch.int64)
self._arange_scratch = np.empty(arange_size, dtype=np.int64)

# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
@@ -812,7 +839,7 @@ class GPUModelRunner(
self.valid_sampled_token_count_copy_stream = torch.cuda.Stream()
self.valid_sampled_token_count_cpu = torch.empty(
self.max_num_reqs,
dtype=torch.int64,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
@@ -903,13 +930,13 @@ class GPUModelRunner(
return self.mrope_positions.gpu[:, :num_tokens]
if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, :num_tokens]
return self.positions.gpu[:num_tokens]
return self.positions[:num_tokens]
else:
if self.uses_mrope:
return self.mrope_positions.gpu[:, num_tokens]
if self.uses_xdrope_dim > 0:
return self.xdrope_positions.gpu[:, num_tokens]
return self.positions.gpu[num_tokens]
return self.positions[num_tokens]

def _make_buffer(
self, *size: int | torch.SymInt, dtype: torch.dtype, numpy: bool = True
@@ -953,7 +980,7 @@ class GPUModelRunner(
if len(token_type_id_requests) == 0:
return model_kwargs

seq_lens = self.seq_lens.gpu[:num_reqs]
seq_lens = self.seq_lens[:num_reqs]
token_type_ids = []

for i in range(num_reqs):
@@ -1021,7 +1048,7 @@ class GPUModelRunner(
def _sync_device(self) -> None:
torch.accelerator.synchronize()

def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
def _update_states(self, scheduler_output: "SchedulerOutput") -> Callable | None:
"""Update the cached states and the persistent batch with the scheduler
output.

@@ -1086,6 +1113,8 @@ class GPUModelRunner(
ngram_gpu_new_reqs: list[CachedRequestState] = []

reqs_to_add: list[CachedRequestState] = []
deferred_spec_decode_corrections = []

# Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id
@@ -1172,10 +1201,8 @@ class GPUModelRunner(
scheduler_output,
self.input_batch.req_id_to_index,
)

# Wait until valid_sampled_tokens_count is copied to cpu,
# then use it to update actual num_computed_tokens of each request.
valid_sampled_token_count = self._get_valid_sampled_token_count()
if self.use_async_spec_decode:
self.prev_num_draft_tokens.np.fill(0)

for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
@@ -1202,15 +1229,30 @@ class GPUModelRunner(
if req_index is None:
req_state.prev_num_draft_len = 0
else:
assert self.input_batch.prev_req_id_to_index is not None
prev_req_index = self.input_batch.prev_req_id_to_index[req_id]
num_accepted = valid_sampled_token_count[prev_req_index] - 1
num_rejected = req_state.prev_num_draft_len - num_accepted
num_computed_tokens -= num_rejected
req_state.output_token_ids.extend([-1] * num_accepted)
# Optimistically assume all accepted; queue up a correction
# to be called after the model forward to preserve async
# scheduling. Corrected on GPU in _prepare_inputs.
optimistic_num_accepted = req_state.prev_num_draft_len
req_state.output_token_ids.extend([-1] * optimistic_num_accepted)

if is_ngram_gpu and num_accepted > 0 and req_index is not None:
self.input_batch.num_tokens_no_spec[req_index] += num_accepted
deferred_spec_decode_corrections.append(
(req_id, optimistic_num_accepted, req_state)
)

prev_req_index = (
self.input_batch.prev_req_id_to_index.get(req_id)
if self.input_batch.prev_req_id_to_index
else None
)
if prev_req_index is not None:
self.prev_num_draft_tokens.np[prev_req_index] = (
optimistic_num_accepted
)

if is_ngram_gpu and optimistic_num_accepted > 0:
self.input_batch.num_tokens_no_spec[req_index] += (
optimistic_num_accepted
)

# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
@@ -1238,7 +1280,8 @@ class GPUModelRunner(
)
elif num_output_tokens < len(req_state.output_token_ids):
# Some output tokens were discarded due to a sync-KV-load
# failure. Align the cached state.
# failure, or output_token_ids was inflated by the optimistic
# extend above (async spec decode). Align the cached state.
del req_state.output_token_ids[num_output_tokens:]
if req_index is not None:
end_idx = (
@@ -1326,6 +1369,40 @@ class GPUModelRunner(
_pinned_val_buf=self._ngram_pinned_val_buf,
)

if deferred_spec_decode_corrections:

def correct_spec_decode_token_counts():
valid_sampled_token_count = self._get_valid_sampled_token_count()
if not valid_sampled_token_count:
return
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
if not prev_req_id_to_index:
return
for (
req_id,
optimistic_num_accepted,
req_state,
) in deferred_spec_decode_corrections:
prev_req_index = prev_req_id_to_index.get(req_id)
if prev_req_index is None:
continue
num_accepted = valid_sampled_token_count[prev_req_index] - 1
correction = optimistic_num_accepted - num_accepted
req_state.num_computed_tokens -= correction
cur_req_index = self.input_batch.req_id_to_index.get(req_id)
if cur_req_index is None:
continue
self.input_batch.num_computed_tokens_cpu[cur_req_index] -= (
correction
)
if is_ngram_gpu and correction > 0:
self.input_batch.num_tokens_no_spec[cur_req_index] -= correction
self.num_tokens_no_spec_gpu[cur_req_index] -= correction

return correct_spec_decode_token_counts
else:
return None

def _update_states_after_model_execute(
self, output_token_ids: torch.Tensor, scheduler_output: "SchedulerOutput"
) -> None:
@@ -1340,6 +1417,9 @@ class GPUModelRunner(
if not self.speculative_config or not self.model_config.is_hybrid:
return

# TODO: Remove .cpu() sync to enable fully async for hybrid model;
# Use num_computed_tokens.gpu instead of req.num_computed_tokens to
# support aligned mamba cache mode.
# Find the number of accepted tokens for each sequence.
num_reqs = output_token_ids.size(0)
self.num_accepted_tokens.gpu[:num_reqs] = (
@@ -1486,12 +1566,14 @@ class GPUModelRunner(
def _get_cumsum_and_arange(
self,
num_tokens: np.ndarray,
arange_out: np.ndarray,
cumsum_dtype: np.dtype | None = None,
) -> tuple[np.ndarray, np.ndarray]:
) -> np.ndarray:
"""Get the cumulative sum and batched arange of the given array.
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
# Equivalent to but faster than:
# np.concatenate([np.arange(n) for n in num_tokens])
E.g., [2, 5, 3] -> [2, 7, 10], arange written to
arange_out[:10] as [0, 1, 0, 1, 2, 3, 4, 0, 1, 2].
Equivalent to but faster than:
np.concatenate([np.arange(n) for n in num_tokens])
"""
# Step 1. [2, 5, 3] -> [2, 7, 10]
cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
@@ -1499,13 +1581,33 @@ class GPUModelRunner(
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = self.arange_np[:total_num_tokens] - cumsums_offsets
np.subtract(
self.arange_np[:total_num_tokens],
cumsums_offsets,
out=arange_out[:total_num_tokens],
)

return cu_num_tokens, arange
return cu_num_tokens
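A standalone NumPy sketch of the cumsum / batched-arange trick used above (illustrative only; arange_src and arange_out stand in for self.arange_np and the caller-provided scratch buffer):

import numpy as np
num_tokens = np.array([2, 5, 3])
arange_src = np.arange(16, dtype=np.int64)      # stands in for self.arange_np
arange_out = np.empty(16, dtype=np.int64)       # stands in for the scratch buffer

cu_num_tokens = np.cumsum(num_tokens)                        # [2, 7, 10]
total = cu_num_tokens[-1]
offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)  # [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
np.subtract(arange_src[:total], offsets, out=arange_out[:total])
print(cu_num_tokens)       # [ 2  7 10]
print(arange_out[:total])  # [0 1 0 1 2 3 4 0 1 2]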

def _compute_prev_positions(self, num_reqs: int) -> None:
"""Build prev_positions mapping: current pos -> previous pos (-1 if new).

Populates self.prev_positions.np[:num_reqs] with the mapping.
"""
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
prev_positions = self.prev_positions.np[:num_reqs]

if not prev_req_id_to_index:
prev_positions.fill(-1)
return

for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
prev_positions[i] = prev_req_id_to_index.get(req_id, -1)

def _prepare_input_ids(
self,
scheduler_output: "SchedulerOutput",
num_reqs: int,
total_num_scheduled_tokens: int,
cu_num_tokens: np.ndarray,
) -> None:
@@ -1513,7 +1615,11 @@ class GPUModelRunner(

Carefully handles the `prev_sampled_token_ids` which can be cached
from the previous engine iteration, in which case those tokens on the
GPU need to be copied into the corresponding slots into input_ids."""
GPU need to be copied into the corresponding slots into input_ids.

Uses self.prev_positions[:num_reqs] which maps current pos -> prev pos
(-1 for new requests).
"""

if self.input_batch.prev_sampled_token_ids is None:
# Normal scheduling case
@@ -1526,47 +1632,50 @@ class GPUModelRunner(
# Async scheduling case, where some decode requests from the previous
# iteration won't have entries in input_ids_cpu and need to be copied
# on the GPU from prev_sampled_token_ids.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
assert prev_req_id_to_index is not None
prev_positions = self.prev_positions.np[:num_reqs]
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
sample_flattened_indices: list[int] = []
spec_flattened_indices: list[int] = []
prev_common_req_indices: list[int] = []
prev_draft_token_indices: list[int] = []
indices_match = True
prev_indices: list[int] = []
common_indices_match = True
max_flattened_index = -1
total_num_spec_tokens = 0
scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens

for req_id, cur_index in self.input_batch.req_id_to_index.items():
if (prev_index := prev_req_id_to_index.get(req_id)) is not None:
prev_common_req_indices.append(prev_index)
# We need to compute the flattened input_ids index of the
# last token in each common request.
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
total_num_spec_tokens += draft_len
flattened_index = cu_num_tokens[cur_index].item() - 1
# example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
# sample_flattened_indices = [0, 2, 5]
# spec_flattened_indices = [1, 3, 4, 6, 7]
sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend(
range(flattened_index - draft_len + 1, flattened_index + 1)
)
start = prev_index * self.num_spec_tokens
# prev_draft_token_indices is used to find which draft_tokens_id
# should be copied to input_ids
# example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
# flatten draft_tokens_id [1,2,3,4,5,6]
# draft_len of each request [1, 2, 1]
# then prev_draft_token_indices is [0, 2, 3, 4]
prev_draft_token_indices.extend(range(start, start + draft_len))
indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)
for cur_index in range(num_reqs):
prev_index = prev_positions[cur_index]
if prev_index < 0:
continue
prev_indices.append(prev_index)
req_id = self.input_batch.req_ids[cur_index]
# We need to compute the flattened input_ids index of the
# last token in each common request.
draft_len = len(scheduled_spec_tokens.get(req_id, ()))
total_num_spec_tokens += draft_len
flattened_index = cu_num_tokens[cur_index].item() - 1
# example: cu_num_tokens = [2, 5, 8], draft_tokens = [1, 2, 2]
# sample_flattened_indices = [0, 2, 5]
# spec_flattened_indices = [1, 3, 4, 6, 7]
sample_flattened_indices.append(flattened_index - draft_len)
spec_flattened_indices.extend(
range(flattened_index - draft_len + 1, flattened_index + 1)
)
start = prev_index * self.num_spec_tokens
# prev_draft_token_indices is used to find which draft_tokens_id
# should be copied to input_ids
# example: prev draft_tokens_id [[1,2], [3,4], [5, 6]]
# flatten draft_tokens_id [1,2,3,4,5,6]
# draft_len of each request [1, 2, 1]
# then prev_draft_token_indices is [0, 2, 3, 4]
prev_draft_token_indices.extend(range(start, start + draft_len))
common_indices_match &= prev_index == flattened_index
max_flattened_index = max(max_flattened_index, flattened_index)

num_common_tokens = len(sample_flattened_indices)
total_without_spec = total_num_scheduled_tokens - total_num_spec_tokens
if num_common_tokens < total_without_spec:
# If not all requests are decodes from the last iteration,
# We need to copy the input_ids_cpu to the GPU first.
# we need to copy the input_ids_cpu to the GPU first.
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
if self.enable_prompt_embeds:
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
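A minimal sketch of the flattened-index bookkeeping described in the comments above, using the same made-up numbers (three persisted decode requests):

cu_num_tokens = [2, 5, 8]   # cumulative scheduled tokens per request
draft_lens = [1, 2, 2]      # scheduled draft tokens per request

sample_flattened_indices = []
spec_flattened_indices = []
for cur_index, draft_len in enumerate(draft_lens):
    flattened_index = cu_num_tokens[cur_index] - 1       # last token of this request
    sample_flattened_indices.append(flattened_index - draft_len)
    spec_flattened_indices.extend(
        range(flattened_index - draft_len + 1, flattened_index + 1)
    )

print(sample_flattened_indices)  # [0, 2, 5]
print(spec_flattened_indices)    # [1, 3, 4, 6, 7]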
@@ -1575,7 +1684,7 @@ class GPUModelRunner(
# No requests in common with the previous iteration
# So input_ids.cpu will have all the input ids.
return
if indices_match and max_flattened_index == (num_common_tokens - 1):
if common_indices_match and max_flattened_index == (num_common_tokens - 1):
# Common-case optimization: the batch is unchanged
# and no reordering happened.
# The indices are both the same permutation of 0..N-1 so
@@ -1592,7 +1701,7 @@ class GPUModelRunner(
sample_flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
prev_common_req_indices_tensor = torch.tensor(
prev_common_req_indices, dtype=torch.int64, pin_memory=self.pin_memory
prev_indices, dtype=torch.int64, pin_memory=self.pin_memory
).to(self.device, non_blocking=True)
self.input_ids.gpu.scatter_(
dim=0,
@@ -1696,15 +1805,15 @@ class GPUModelRunner(
req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)

# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens, arange = self._get_cumsum_and_arange(num_scheduled_tokens)
# self.query_pos.np[:10]: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens = self._get_cumsum_and_arange(
num_scheduled_tokens, self.query_pos.np
)

# Get positions.
positions_np = self.positions.np[:total_num_scheduled_tokens]
np.add(
self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np,
positions_np = (
self.input_batch.num_computed_tokens_cpu[req_indices]
+ self.query_pos.np[: cu_num_tokens[-1]]
)

# Calculate M-RoPE positions.
@@ -1782,9 +1891,6 @@ class GPUModelRunner(

output_idx += num_sched

self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)

# Prepare the attention metadata.
self.query_start_loc.np[0] = 0
self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
@@ -1794,12 +1900,21 @@ class GPUModelRunner(
self.query_start_loc.copy_to_gpu()
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]

self.seq_lens.np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
# Compute optimistic seq_lens (assumes all draft tokens from previous
# iteration accepted). Store in optimistic_seq_lens_cpu for use by
# _build_attention_metadata (max_seq_len) and discard_request_mask.
# seq_lens (GPU) will be computed later using the same optimistic values.
torch.add(
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs],
torch.from_numpy(num_scheduled_tokens),
out=self.optimistic_seq_lens_cpu[:num_reqs],
)
# Fill unused with 0 for full cuda graph mode.
self.seq_lens.np[num_reqs:].fill(0)
self.seq_lens.copy_to_gpu()
self.optimistic_seq_lens_cpu[num_reqs:].fill_(0)

# Build prev_positions mapping: current pos -> prev pos (-1 if new).
# Used for gathering from previous iteration's GPU tensors.
prev_req_id_to_index = self.input_batch.prev_req_id_to_index
self._compute_prev_positions(num_reqs)

num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
num_tokens_np = np.array(num_tokens, dtype=np.int32)
@@ -1807,13 +1922,78 @@ class GPUModelRunner(
# Record which requests should not be sampled,
# so that we could clear the sampled tokens before returning
self.discard_request_mask.np[:num_reqs] = (
self.seq_lens.np[:num_reqs] < num_tokens_np
self.optimistic_seq_lens_cpu[:num_reqs].numpy() < num_tokens_np
)
self.discard_request_mask.copy_to_gpu(num_reqs)

# Sync num_accepted_tokens from CPU (set by
# _update_states_after_model_execute for hybrid models).
if self.num_accepted_tokens_event is not None:
self.num_accepted_tokens_event.synchronize()
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
else:
self.num_accepted_tokens.np.fill(1)
self.num_accepted_tokens.gpu.fill_(1)

# Update num_computed_tokens on GPU. In async spec decode,
# CPU values are optimistic (all drafts accepted). The kernel
# corrects on GPU using the previous step's
# valid_sampled_token_count_gpu. Otherwise, just copy from CPU.
if (
self.use_async_spec_decode
and self.valid_sampled_token_count_gpu is not None
and prev_req_id_to_index
):
self.prev_positions.copy_to_gpu(num_reqs)
self.prev_num_draft_tokens.copy_to_gpu()
cpu_values = self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs].to(
device=self.device, non_blocking=True
)
update_num_computed_tokens_for_batch_change(
self.num_computed_tokens,
self.num_accepted_tokens.gpu[:num_reqs],
self.prev_positions.gpu[:num_reqs],
self.valid_sampled_token_count_gpu,
self.prev_num_draft_tokens.gpu,
cpu_values,
)
else:
self.num_computed_tokens[:num_reqs].copy_(
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs],
non_blocking=True,
)

self.req_indices.np[:total_num_scheduled_tokens] = req_indices
self.req_indices.copy_to_gpu(total_num_scheduled_tokens)
req_indices_gpu = self.req_indices.gpu[:total_num_scheduled_tokens]

self.query_pos.copy_to_gpu(total_num_scheduled_tokens)
self.num_scheduled_tokens.np[:num_reqs] = num_scheduled_tokens
self.num_scheduled_tokens.copy_to_gpu(num_reqs)
num_scheduled_tokens_gpu = self.num_scheduled_tokens.gpu[:num_reqs]
self.positions[:total_num_scheduled_tokens] = (
self.num_computed_tokens[req_indices_gpu].to(torch.int64)
+ self.query_pos.gpu[:total_num_scheduled_tokens]
)
self.seq_lens[:num_reqs] = (
self.num_computed_tokens[:num_reqs] + num_scheduled_tokens_gpu
)
self.seq_lens[num_reqs:].fill_(0)

self.input_batch.block_table.compute_slot_mapping(
num_reqs,
self.query_start_loc.gpu[: num_reqs + 1],
self.positions[:total_num_scheduled_tokens],
)

# Copy the tensors to the GPU.
self._prepare_input_ids(
scheduler_output,
num_reqs,
total_num_scheduled_tokens,
cu_num_tokens,
)
@@ -1830,9 +2010,14 @@ class GPUModelRunner(
self.xdrope_positions.cpu[:, :total_num_scheduled_tokens],
non_blocking=True,
)
else:
# Common case (1D positions)
self.positions.copy_to_gpu(total_num_scheduled_tokens)
if self.use_async_spec_decode and (self.uses_mrope or self.uses_xdrope_dim > 0):
drift = self.num_computed_tokens[req_indices_gpu].to(
torch.int64
) - self.input_batch.num_computed_tokens_cpu_tensor[req_indices].to(
device=self.device, dtype=torch.int64, non_blocking=True
)
target = self.mrope_positions if self.uses_mrope else self.xdrope_positions
target.gpu[:, :total_num_scheduled_tokens] += drift

use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
@@ -1857,12 +2042,13 @@ class GPUModelRunner(
draft_token_ids,
) in scheduler_output.scheduled_spec_decode_tokens.items():
req_idx = self.input_batch.req_id_to_index[req_id]
num_draft_tokens[req_idx] = len(draft_token_ids)
draft_len = len(draft_token_ids)
num_draft_tokens[req_idx] = draft_len
if (
self.input_batch.num_computed_tokens_cpu[req_idx]
>= self.input_batch.num_prompt_tokens[req_idx]
):
num_decode_draft_tokens[req_idx] = len(draft_token_ids)
num_decode_draft_tokens[req_idx] = draft_len
spec_decode_metadata = self._calc_spec_decode_metadata(
num_draft_tokens, cu_num_tokens
)
@@ -1924,16 +2110,7 @@ class GPUModelRunner(
# window size when capturing to make sure the correct kernel is selected.
max_seq_len = self.max_model_len
else:
max_seq_len = self.seq_lens.np[:num_reqs].max().item()

if use_spec_decode:
if self.num_accepted_tokens_event is not None:
self.num_accepted_tokens_event.synchronize()
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
self.num_accepted_tokens.np[num_reqs:].fill(1)
self.num_accepted_tokens.copy_to_gpu()
max_seq_len = self.optimistic_seq_lens_cpu.numpy()[:num_reqs].max().item()

kv_cache_groups = self.kv_cache_config.kv_cache_groups

@@ -1963,22 +2140,29 @@ class GPUModelRunner(
attn_gid = self.routed_experts_attn_gid
slot_mapping_attn = slot_mappings[attn_gid]
self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy()
# Compute is_prefilling: True if request is still in prefill phase
# (num_computed_tokens < num_prompt_tokens). Used by mamba backends to
# distinguish actual decodes from short extends.
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
]
num_prompt_tokens_cpu = self.input_batch.num_prompt_tokens_cpu_tensor[
:num_reqs_padded
]
seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs_padded]

# is_prefilling: True if request is still in prefill phase.
# Used by mamba backends to distinguish actual decodes from
# short extends.
is_prefilling = num_computed_tokens_cpu < num_prompt_tokens_cpu

if self.use_async_spec_decode:
# GPU tensors are authoritative in async mode.
seq_lens_cpu = None
num_computed_tokens_cpu = None

cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
seq_lens=self.seq_lens.gpu[:num_reqs_padded],
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded],
seq_lens=self.seq_lens[:num_reqs_padded],
_seq_lens_cpu=seq_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs_padded,
num_actual_tokens=num_tokens_padded,
@@ -1992,7 +2176,7 @@ class GPUModelRunner(

if self.dcp_world_size > 1:
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
self.seq_lens.cpu[:num_reqs],
self.optimistic_seq_lens_cpu[:num_reqs],
self.dcp_world_size,
self.dcp_rank,
self.parallel_config.cp_kv_cache_interleave_size,
@@ -2396,33 +2580,34 @@ class GPUModelRunner(
# [4, 1, 3, 1, 2]
num_sampled_tokens = num_draft_tokens + 1

# Step 1. cu_num_sampled_tokens: [4, 5, 8, 9, 11]
# arange: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
cu_num_sampled_tokens, arange = self._get_cumsum_and_arange(
num_sampled_tokens, cumsum_dtype=np.int32
# Step 1.
# cu_num_sampled_tokens: [4, 5, 8, 9, 11]
# _arange_scratch[:11]: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
cu_num_sampled_tokens = self._get_cumsum_and_arange(
num_sampled_tokens, self._arange_scratch, cumsum_dtype=np.int32
)
# Step 2. [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_indices = np.repeat(
cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens
)
# Step 3. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
logits_indices += arange
logits_indices += self._arange_scratch[: cu_num_sampled_tokens[-1]]

# Compute the bonus logits indices.
bonus_logits_indices = cu_num_sampled_tokens - 1

# Compute the draft logits indices.
# cu_num_draft_tokens: [3, 3, 5, 5, 6]
# arange: [0, 1, 2, 0, 1, 0]
cu_num_draft_tokens, arange = self._get_cumsum_and_arange(
num_draft_tokens, cumsum_dtype=np.int32
# _arange_scratch[:6]: [0, 1, 2, 0, 1, 0]
cu_num_draft_tokens = self._get_cumsum_and_arange(
num_draft_tokens, self._arange_scratch, cumsum_dtype=np.int32
)
# [0, 0, 0, 5, 5, 9]
target_logits_indices = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_draft_tokens
)
# [0, 1, 2, 5, 6, 9]
target_logits_indices += arange
target_logits_indices += self._arange_scratch[: cu_num_draft_tokens[-1]]

# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
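The same index arithmetic as a standalone NumPy sketch; cu_num_scheduled_tokens is made up so the values line up with the 103/206 numbers in the comments above:

import numpy as np
num_draft_tokens = np.array([3, 0, 2, 0, 1])
num_sampled_tokens = num_draft_tokens + 1                      # [4, 1, 3, 1, 2]
cu_num_scheduled_tokens = np.array([4, 104, 107, 207, 209])    # hypothetical values

cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)          # [4, 5, 8, 9, 11]
arange = np.concatenate([np.arange(n) for n in num_sampled_tokens])

logits_indices = np.repeat(
    cu_num_scheduled_tokens - num_sampled_tokens, num_sampled_tokens
)
logits_indices += arange
print(logits_indices)             # [  0   1   2   3 103 104 105 106 206 207 208]
print(cu_num_sampled_tokens - 1)  # bonus logits indices: [ 3  4  7  8 10]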
@@ -2924,7 +3109,7 @@ class GPUModelRunner(
)

hidden_states = hidden_states[:num_scheduled_tokens]
seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
seq_lens_cpu = self.optimistic_seq_lens_cpu[:num_reqs]

pooling_metadata = self.input_batch.get_pooling_metadata()
pooling_metadata.build_pooling_cursor(
@@ -3083,9 +3268,9 @@ class GPUModelRunner(
elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_input_tokens]
else:
positions = self.positions.gpu[:num_input_tokens]
positions = self.positions[:num_input_tokens]
if num_input_tokens > num_scheduled_tokens:
self.positions.gpu[num_scheduled_tokens:num_input_tokens].zero_()
self.positions[num_scheduled_tokens:num_input_tokens].zero_()

if is_first_rank:
intermediate_tensors = None
@@ -3610,7 +3795,7 @@ class GPUModelRunner(
self.synchronize_input_prep(),
):
# Update persistent batch states.
self._update_states(scheduler_output)
deferred_state_corrections_fn = self._update_states(scheduler_output)

if has_ec_transfer() and not get_ec_transfer().is_consumer:
with self.maybe_get_ec_connector_output(
@@ -3723,6 +3908,12 @@ class GPUModelRunner(
pad_attn = cudagraph_mode == CUDAGraphMode.FULL

if self.cache_config.mamba_cache_mode == "align":
# preprocess_mamba reads req_state.num_computed_tokens (CPU)
# to decide copy operations, so we must apply deferred
# corrections before it runs.
if deferred_state_corrections_fn:
deferred_state_corrections_fn()
deferred_state_corrections_fn = None
mamba_utils.preprocess_mamba(
scheduler_output,
self.kv_cache_config,
@@ -3734,6 +3925,14 @@ class GPUModelRunner(
self.model.get_mamba_state_copy_func(),
self._get_mamba_copy_bufs(),
)
# preprocess_mamba resets num_accepted_tokens_cpu to 1
# for requests whose state was copied to a new block.
# Re-sync to GPU so the mamba kernel reads from the
# correct initial state slot (init_token_idx = 0).
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs]
)
self.num_accepted_tokens.copy_to_gpu(num_reqs)

use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
@@ -3894,6 +4093,12 @@ class GPUModelRunner(
slot_mappings,
)
self.kv_connector_output = kv_connector_output

# Now the batch has been launched we can wait for corrections from the
# previous model forward without breaking async scheduling.
if deferred_state_corrections_fn:
deferred_state_corrections_fn()

return None

@torch.inference_mode
@@ -3958,6 +4163,7 @@ class GPUModelRunner(

self._draft_token_ids = None
self._draft_token_req_ids = None
self.valid_sampled_token_count_gpu = None
self.input_batch.prev_sampled_token_ids = None

def propose_draft_token_ids(sampled_token_ids):
@@ -4002,7 +4208,7 @@ class GPUModelRunner(
assert spec_decode_common_attn_metadata is not None
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
spec_decode_common_attn_metadata,
self.optimistic_seq_lens_cpu,
sampled_token_ids,
self.requests,
self.input_batch,
@@ -4237,6 +4443,9 @@ class GPUModelRunner(
counts_cpu[: counts.shape[0]].copy_(counts, non_blocking=True)
self.valid_sampled_token_count_event.record()

if self.use_async_spec_decode:
# Stash for GPU-side correction in _prepare_inputs.
self.valid_sampled_token_count_gpu = valid_sampled_tokens_count
self.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(1)

def _get_valid_sampled_token_count(self) -> list[int]:
@@ -4366,7 +4575,7 @@ class GPUModelRunner(
)
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,
self.optimistic_seq_lens_cpu,
sampled_token_ids,
self.requests,
self.input_batch,
@@ -4405,7 +4614,7 @@ class GPUModelRunner(
)
next_token_ids, valid_sampled_tokens_count = (
self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,
self.optimistic_seq_lens_cpu,
sampled_token_ids,
self.requests,
self.input_batch,
@@ -5148,14 +5357,19 @@ class GPUModelRunner(
# In the mixed batch mode (used for FI warmup), we use
# shorter sequence lengths to run faster.
# TODO(luka) better system for describing dummy batches
seq_lens = [1] * num_decode_tokens + [num_prefill_tokens + 1] # type: ignore[assignment]
seq_lens = torch.tensor( # type: ignore[assignment]
[1] * num_decode_tokens + [num_prefill_tokens + 1],
dtype=torch.int,
)
else:
seq_lens = max_query_len # type: ignore[assignment]
self.seq_lens.np[:num_reqs] = seq_lens
self.seq_lens.np[num_reqs:] = 0
self.seq_lens.copy_to_gpu()
self.optimistic_seq_lens_cpu[:num_reqs] = seq_lens
self.optimistic_seq_lens_cpu[num_reqs:].fill_(0)
self.seq_lens.copy_(self.optimistic_seq_lens_cpu, non_blocking=True)

cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens)
cum_num_tokens = self._get_cumsum_and_arange(
num_scheduled_tokens, self.query_pos.np
)
self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens
self.query_start_loc.copy_to_gpu()

@@ -5201,7 +5415,7 @@ class GPUModelRunner(
elif self.uses_xdrope_dim > 0:
positions = self.xdrope_positions.gpu[:, :num_tokens_padded]
else:
positions = self.positions.gpu[:num_tokens_padded]
positions = self.positions[:num_tokens_padded]

if get_pp_group().is_first_rank:
intermediate_tensors = None