[V1][Spec Decode] Ngram Spec Decode (#12193)

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Authored by Lily Liu on 2025-02-15 18:05:11 -08:00, committed by GitHub.
commit 80f63a3966 (parent 367cb8ce8c)
21 changed files with 1023 additions and 82 deletions

View File

@@ -390,6 +390,7 @@ class InputBatch:
def make_sampling_metadata(
self,
req_id_output_token_ids: Dict[str, List[int]],
req_id_to_spec_token_ids: Dict[str, List[int]],
skip_copy: bool = False,
) -> SamplingMetadata:
if not skip_copy:
@@ -423,7 +424,8 @@ class InputBatch:
self.prompt_token_ids = self._make_prompt_token_ids_tensor()
output_token_ids: List[List[int]] = []
spec_token_ids: List[List[int]] = []
rejection_sampling = False
for req_id in self.req_ids[:self.num_reqs]:
assert req_id is not None
# Currently we create a tensor for output_token_ids from scratch
@@ -434,11 +436,18 @@ class InputBatch:
# TODO - Replace this with incremental update to output token
# statistics.
output_token_ids.append(req_id_output_token_ids[req_id])
req_spec_token_ids = req_id_to_spec_token_ids.get(req_id, [])
spec_token_ids.append(req_spec_token_ids)
if req_spec_token_ids:
# If any of the requests require speculative decoding, set the
# flag to True.
rejection_sampling = True
return SamplingMetadata(
temperature=self.temperature[:self.num_reqs],
all_greedy=self.all_greedy,
all_random=self.all_random,
rejection_sampling=rejection_sampling,
top_p=self.top_p[:self.num_reqs],
top_k=self.top_k[:self.num_reqs],
min_p=self.min_p[:self.num_reqs],
@@ -452,6 +461,7 @@ class InputBatch:
presence_penalties=self.presence_penalties[:self.num_reqs],
repetition_penalties=self.repetition_penalties[:self.num_reqs],
output_token_ids=output_token_ids,
spec_token_ids=spec_token_ids,
min_tokens=self.min_tokens[:self.num_reqs],
stop_token_ids=self.stop_token_ids[:self.num_reqs],
no_penalties=self.no_penalties,
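
A minimal standalone sketch of the new bookkeeping in make_sampling_metadata (illustrative only, not part of this commit): the spec token lists stay aligned with the batch's request order, and a single request with draft tokens is enough to turn rejection sampling on for the whole batch.

from typing import Dict, List

def collect_spec_tokens(
    req_ids: List[str],
    req_id_to_spec_token_ids: Dict[str, List[int]],
):
    spec_token_ids: List[List[int]] = []
    rejection_sampling = False
    for req_id in req_ids:
        # Requests without draft tokens contribute an empty list.
        req_spec = req_id_to_spec_token_ids.get(req_id, [])
        spec_token_ids.append(req_spec)
        if req_spec:
            # Any request with draft tokens forces rejection sampling.
            rejection_sampling = True
    return spec_token_ids, rejection_sampling

# Example: only "req-0" and "req-2" carry draft tokens.
print(collect_spec_tokens(["req-0", "req-1", "req-2"],
                          {"req-0": [11, 12], "req-2": [7]}))
# ([[11, 12], [], [7]], True)
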

View File

@@ -32,6 +32,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -180,6 +181,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.max_model_len,
self.max_num_tokens),
dtype=np.int32)
self.arange_cpu = torch.from_numpy(self.arange_np)
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
# not make any assumptions about the values in these tensors.
@@ -368,7 +370,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return batch_changed
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
def _prepare_inputs(
self, scheduler_output: "SchedulerOutput"
) -> Tuple[FlashAttentionMetadata, torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@@ -382,12 +386,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens_list: List[int] = []
max_num_scheduled_tokens = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
all_spec_token_ids: List[int] = []
num_spec_tokens_list: List[int] = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens_list.append(num_tokens)
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, [])
all_spec_token_ids.extend(spec_token_ids)
num_spec_tokens_list.append(len(spec_token_ids))
num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
dtype=np.int32)
assert max_num_scheduled_tokens > 0
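
As a concrete illustration of the loop above (hypothetical values; the dict stands in for scheduler_output.scheduled_spec_decode_tokens): draft tokens are flattened across the batch while their per-request counts are kept for the np.repeat-based indexing that follows.

scheduled_spec_decode_tokens = {"req-0": [11, 12, 13], "req-2": [21, 22]}
req_ids = ["req-0", "req-1", "req-2"]

all_spec_token_ids = []
num_spec_tokens_list = []
for req_id in req_ids:
    spec = scheduled_spec_decode_tokens.get(req_id, [])
    all_spec_token_ids.extend(spec)          # flat list across the batch
    num_spec_tokens_list.append(len(spec))   # per-request draft counts

print(all_spec_token_ids)    # [11, 12, 13, 21, 22]
print(num_spec_tokens_list)  # [3, 0, 2]
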
@@ -426,6 +437,79 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# where M is the max_model_len.
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
use_spec_decode = len(all_spec_token_ids) > 0
if use_spec_decode:
# 1. Write spec_token_ids to input batch.
# Step 1. Get req indices that perform spec decode and repeat
# the req indices by the number of spec tokens. Note
# for requests that don't perform spec decode, the
# number of spec tokens is 0 and the req index is
# repeated 0 times.
# E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
# spec_req_indices: [0, 0, 0, 2, 2, 4]
spec_req_indices = np.repeat(self.arange_np[:num_reqs],
num_spec_tokens_list)
# spec_offsets: offsets within each spec token list.
# E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here
spec_offsets = np.concatenate(
[self.arange_np[1:val + 1] for val in num_spec_tokens_list])
# spec_seq_offsets: offsets within each sequence.
# E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
# after repeating: [1, 1, 1, 3, 3, 2]
# spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
# = [2, 3, 4, 4, 5, 3]
spec_seq_offsets = np.repeat(
self.input_batch.num_computed_tokens_cpu[:num_reqs],
num_spec_tokens_list) + spec_offsets
# cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
cumsums_spec_offsets = (
spec_seq_offsets +
spec_req_indices * self.input_batch.token_ids_cpu.shape[1])
cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to(
torch.int64)
all_spec_token_ids = torch.tensor(all_spec_token_ids,
device="cpu",
dtype=self.input_ids_cpu.dtype)
# Step 2. Scatter the spec token ids into token_ids_cpu_tensor.
self.input_batch.token_ids_cpu_tensor.flatten().scatter_(
0, cumsums_spec_offsets, all_spec_token_ids)
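
The scatter above can be reproduced standalone (illustrative sketch, not part of this commit; max_model_len stands in for token_ids_cpu.shape[1] and the values come from the worked example in the comments):

import numpy as np
import torch

max_model_len = 8          # plays the role of token_ids_cpu.shape[1]
num_reqs = 5
num_spec_tokens_list = [3, 0, 2, 0, 1]
num_computed_tokens = np.array([1, 4, 3, 6, 2], dtype=np.int32)
arange = np.arange(num_reqs * max_model_len, dtype=np.int32)

# Request index repeated once per draft token: [0, 0, 0, 2, 2, 4]
spec_req_indices = np.repeat(arange[:num_reqs], num_spec_tokens_list)
# 1-based offset of each draft token within its request: [1, 2, 3, 1, 2, 1]
spec_offsets = np.concatenate(
    [arange[1:n + 1] for n in num_spec_tokens_list])
# Position of each draft token within its sequence: [2, 3, 4, 4, 5, 3]
spec_seq_offsets = (np.repeat(num_computed_tokens, num_spec_tokens_list) +
                    spec_offsets)
# Row-major indices into the flattened [num_reqs, max_model_len] buffer.
flat_indices = spec_seq_offsets + spec_req_indices * max_model_len
print(flat_indices)  # [ 2  3  4 20 21 35]

token_ids_cpu = torch.zeros(num_reqs, max_model_len, dtype=torch.int32)
draft_tokens = torch.tensor([11, 12, 13, 21, 22, 31], dtype=torch.int32)
token_ids_cpu.view(-1).scatter_(
    0, torch.from_numpy(flat_indices).to(torch.int64), draft_tokens)
print(token_ids_cpu[0, 2:5].tolist())  # [11, 12, 13]
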
# 2. Get spec decode logits indices.
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
# cu_num_tokens: [4, 104, 107, 207, 209]
# num_spec_tokens_list: [3, 0, 2, 0, 1]
# num_sampled_tokens: [4, 1, 3, 1, 2]
# spec_decode_logits_indices:
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32)
num_sampled_tokens = num_spec_tokens_np + 1
# logits_start_loc: [0, 103, 104, 206, 207]
logits_start_loc = cu_num_tokens - num_sampled_tokens
# [0, 103, 104, 206, 207] ->
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
# The following three lines:
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_sampled_offsets = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
total_num_sampled_tokens = num_sampled_tokens.sum()
sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
cumsums_sampled_offsets)
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
spec_decode_logits_indices = logits_start_loc + sampled_arange
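
The same worked example can be traced end to end in plain NumPy (illustrative sketch, not part of this commit); it reproduces spec_decode_logits_indices, i.e. the last (1 + num_spec_tokens) logits positions of every request:

import numpy as np

num_scheduled_tokens = np.array([4, 100, 3, 100, 2], dtype=np.int32)
num_spec_tokens = np.array([3, 0, 2, 0, 1], dtype=np.int32)

cu_num_tokens = np.cumsum(num_scheduled_tokens)    # [4, 104, 107, 207, 209]
num_sampled_tokens = num_spec_tokens + 1           # [4, 1, 3, 1, 2]

# First sampled logits position per request, repeated per sampled token.
logits_start_loc = cu_num_tokens - num_sampled_tokens  # [0, 103, 104, 206, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)

# Per-request 0-based offsets: [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
cu_num_sampled = np.cumsum(num_sampled_tokens)
offsets = np.repeat(cu_num_sampled - num_sampled_tokens, num_sampled_tokens)
sampled_arange = np.arange(num_sampled_tokens.sum()) - offsets

print((logits_start_loc + sampled_arange).tolist())
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
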
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
@@ -519,16 +603,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
suffix_kv_lens=suffix_kv_lens,
)
if use_spec_decode:
logits_indices = torch.from_numpy(spec_decode_logits_indices).to(
self.device, non_blocking=True)
else:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
# NOTE(woosuk): Due to chunked prefills, the batch may contain partial
# requests. While we should not sample any token from these partial
# requests, we do so for simplicity. We will ignore the sampled
# tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices
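
For the non-speculative branch, a tiny hypothetical example of the default selection query_start_loc[1:] - 1: it picks the logits position of the last scheduled token of each request, including partial chunked-prefill requests whose sampled token is discarded later.

import torch

# Three requests with 4, 100, and 3 scheduled tokens (hypothetical).
query_start_loc = torch.tensor([0, 4, 104, 107])
print((query_start_loc[1:] - 1).tolist())  # [3, 103, 106]
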
def _compute_cascade_attn_prefix_len(
@@ -673,6 +762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _prepare_sampling(
self,
batch_changed: bool,
req_to_spec_token_ids: Dict[str, List[int]],
) -> SamplingMetadata:
# Create the sampling metadata.
req_id_output_token_ids: Dict[str, List[int]] = \
@@ -680,7 +770,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for req_id, req in self.requests.items()}
sampling_metadata = self.input_batch.make_sampling_metadata(
req_id_output_token_ids, skip_copy=not batch_changed)
req_id_output_token_ids, req_to_spec_token_ids, not batch_changed)
return sampling_metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@@ -847,7 +937,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits = self.model.compute_logits(sample_hidden_states, None)
# Sample the next token and get logprobs if needed.
sampling_metadata = self._prepare_sampling(batch_changed)
sampling_metadata = self._prepare_sampling(
batch_changed, scheduler_output.scheduled_spec_decode_tokens)
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
@@ -857,18 +948,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the requests one by one. Optimize.
num_reqs = self.input_batch.num_reqs
request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
for i, req_id in enumerate( # type: ignore[assignment]
self.input_batch.req_ids[:num_reqs]):
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
assert seq_len <= req_state.num_tokens
if seq_len == req_state.num_tokens:
# Append the sampled token to the output token ids.
self.input_batch.num_tokens[i] += 1
# OPTIMIZATION: Priming the state updates for later updates.
req_state.output_token_ids.append(0)
if seq_len >= req_state.num_tokens:
request_seq_lens.append((i, req_state, seq_len))
else:
# Ignore the sampled token from the partial request.
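
With speculative decoding, the scheduled token count for a request includes its draft tokens, so seq_len can exceed the request's currently known num_tokens; a small hypothetical calculation (illustrative numbers only):

# Hypothetical request: 7 prompt tokens + 2 generated tokens so far.
num_tokens = 7 + 2            # req_state.num_tokens
num_computed_tokens = 8       # everything except the newest token
num_scheduled_tokens = 1 + 3  # one "real" decode token + 3 draft tokens

seq_len = num_computed_tokens + num_scheduled_tokens  # 12
print(seq_len >= num_tokens)  # True, even though seq_len != num_tokens
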
@@ -886,7 +971,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
sampled_token_ids = sampler_output.sampled_token_ids.tolist()
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None
@@ -897,16 +981,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output,
)
# Update with the actual token ids
for i, req_state, seq_len in request_seq_lens:
token_id = sampled_token_ids[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids[-1] = token_id
# Update batch with the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = sampled_token_ids.tolist()
for i, req_state, seq_len in request_seq_lens:
token_id = valid_sampled_token_ids[i][0]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids.append(token_id)
self.input_batch.num_tokens[i] += 1
else:
valid_mask = sampled_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()
valid_sampled_token_ids = [
seq.tolist()
for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
self.input_batch.num_tokens[:num_reqs] += gen_lens
for i, req_state, seq_len in request_seq_lens:
target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
self.input_batch.token_ids_cpu[
i, target_slice] = valid_sampled_token_ids[i]
req_state.output_token_ids.extend(valid_sampled_token_ids[i])
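
A standalone sketch of the multi-token path above (illustrative, not part of this commit; -1 stands in for INVALID_TOKEN_ID): rejected draft positions are masked out and split(gen_lens) recovers the variable-length accepted tokens per request.

import torch

INVALID_TOKEN_ID = -1  # placeholder for the rejection sampler's sentinel

# Two requests, up to 4 sampled tokens each (3 draft slots + 1 bonus).
# Request 0 kept all positions, request 1 only its first two.
sampled_token_ids = torch.tensor([
    [10, 11, 12, 13],
    [20, 21, INVALID_TOKEN_ID, INVALID_TOKEN_ID],
])

valid_mask = sampled_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()  # [4, 2]
valid_sampled_token_ids = [
    seq.tolist() for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
print(valid_sampled_token_ids)  # [[10, 11, 12, 13], [20, 21]]
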
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=sampled_token_ids,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
)

View File

@@ -695,7 +695,7 @@ class TPUModelRunner:
model_runner_output = ModelRunnerOutput(
req_ids=all_req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=sampled_token_ids,
sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type]
)
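
The TPU runner change is purely an interface adaptation: ModelRunnerOutput.sampled_token_ids is now a list of per-request token lists, so each scalar is wrapped. For example (hypothetical values):

sampled_token_ids = [101, 7, 42]
print([[token_id] for token_id in sampled_token_ids])  # [[101], [7], [42]]
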