[Model Runner V2] Support spec decoding [1/N] (#29274)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-11-23 10:09:06 -08:00
Committed by: GitHub
Parent: 7f12c82fa6
Commit: b004c00418
5 changed files with 347 additions and 26 deletions


@@ -40,11 +40,12 @@ from vllm.v1.worker.gpu.input_batch import (
InputBatch,
InputBuffers,
combine_sampled_and_draft_tokens,
post_update,
prepare_pos_seq_lens,
prepare_prefill_inputs,
update_num_computed_tokens,
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -100,10 +101,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_prep_event = None
self.structured_outputs_event = None
if self.speculative_config is not None:
self.do_spec_decode = True
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
else:
self.do_spec_decode = False
self.num_speculative_steps = 0
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
pin_memory=self.pin_memory,
@@ -427,6 +436,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
idx_mapping_np = idx_mapping.np[:num_reqs]
idx_mapping = idx_mapping.copy_to_gpu(num_reqs)
# Get the number of draft tokens for each request.
if not scheduler_output.scheduled_spec_decode_tokens:
# No draft token scheduled (common case).
total_num_draft_tokens = 0
total_num_logits = num_reqs
cu_num_logits = torch.arange(
num_reqs + 1, device=self.device, dtype=torch.int32
)
else:
draft_tokens = scheduler_output.scheduled_spec_decode_tokens
num_draft_tokens = np.array(
[
len(draft_tokens[req_id]) if req_id in draft_tokens else 0
for req_id in req_ids
],
dtype=np.int32,
)
total_num_draft_tokens = int(num_draft_tokens.sum())
total_num_logits = num_reqs + total_num_draft_tokens
np.cumsum(
num_draft_tokens + 1,
out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1],
)
cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(num_reqs + 1)
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
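Note: for intuition, cu_num_logits records how many logit rows each request contributes (one for its last scheduled token plus one per draft token), as a cumulative sum. A minimal NumPy sketch with hypothetical numbers (not the runner code itself):

import numpy as np

# Hypothetical batch: three requests with 0, 2, and 1 scheduled draft tokens.
num_draft_tokens = np.array([0, 2, 1], dtype=np.int32)
num_reqs = len(num_draft_tokens)

cu_num_logits = np.zeros(num_reqs + 1, dtype=np.int32)
np.cumsum(num_draft_tokens + 1, out=cu_num_logits[1:])

print(cu_num_logits)                                    # [0 1 4 6]
total_num_draft_tokens = int(num_draft_tokens.sum())    # 3
total_num_logits = num_reqs + total_num_draft_tokens    # 6 == cu_num_logits[-1]

In the no-draft branch above, every request contributes exactly one logit row, so cu_num_logits is simply arange(num_reqs + 1).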
@@ -456,14 +491,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
seq_lens = self.input_buffers.seq_lens[:num_reqs]
# Some input token ids are directly read from the last sampled tokens
# and draft tokens.
combine_sampled_and_draft_tokens(
# and draft tokens. Also, get the logits indices to sample tokens from.
logits_indices = combine_sampled_and_draft_tokens(
self.input_buffers.input_ids.gpu,
idx_mapping,
self.req_states.last_sampled_tokens,
query_start_loc_gpu,
seq_lens,
self.req_states.prefill_len.gpu,
self.req_states.draft_tokens,
cu_num_logits,
total_num_logits,
)
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
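Note: combine_sampled_and_draft_tokens now also returns logits_indices, because with draft tokens each request needs logits at several positions rather than only at its last token. A plain-Python sketch of the assumed indexing (the real function is a fused GPU kernel):

# Sketch with hypothetical numbers: logits are assumed to be taken at the last
# (1 + num_draft_tokens) positions of each request's scheduled tokens.
query_start_loc = [0, 5, 8, 12]   # 3 requests with 5, 3, and 4 scheduled tokens
cu_num_logits = [0, 1, 4, 6]      # 0, 2, 1 draft tokens -> 1, 3, 2 logit rows

logits_indices = []
for i in range(len(query_start_loc) - 1):
    num_logits = cu_num_logits[i + 1] - cu_num_logits[i]
    end = query_start_loc[i + 1]
    logits_indices.extend(range(end - num_logits, end))

print(logits_indices)  # [4, 5, 6, 7, 10, 11]

With no draft tokens this degenerates to query_start_loc[1:] - 1, which is exactly the standalone computation deleted in the next hunk.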
@@ -471,9 +509,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
)
# Logits indices to sample next token from.
logits_indices = query_start_loc_gpu[1:] - 1
# Get num_computed_tokens.
# HACK(woosuk): Here, we use num_computed_tokens on GPU instead of
# num_computed_tokens_cpu. This works for most cases.
@@ -508,6 +543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding,
num_draft_tokens=total_num_draft_tokens,
query_start_loc=query_start_loc_gpu,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
@@ -516,6 +552,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions=positions,
attn_metadata=attn_metadata,
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
)
def sample(
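Note: the two new keyword arguments imply that InputBatch now carries per-batch spec-decode metadata alongside logits_indices. An illustrative subset of the relevant fields (field names taken from this diff; the real class has many more fields):

from dataclasses import dataclass
import torch

@dataclass
class InputBatchSpecFields:        # hypothetical name, for illustration only
    num_draft_tokens: int          # total draft tokens scheduled in this batch
    logits_indices: torch.Tensor   # [total_num_logits] positions to gather logits from
    cu_num_logits: torch.Tensor    # [num_reqs + 1] cumulative logit rows per request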
@@ -530,6 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
# TODO(woosuk): Make compatible with spec decoding.
assert input_batch.num_draft_tokens == 0
with async_barrier(self.structured_outputs_event):
apply_grammar_bitmask(
logits,
@@ -539,12 +577,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_buffers,
)
# Sample tokens and compute logprobs (if needed).
sampler_output = self.sampler(logits, sampling_metadata)
# Get the number of sampled tokens.
# 0 if chunked-prefilling, 1 if not.
prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
is_chunked_prefilling = input_batch.seq_lens < prefill_len
num_sampled = (~is_chunked_prefilling).int()
if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
# 0 if chunked-prefilling, 1 if not.
num_sampled = (~is_chunked_prefilling).int()
else:
# Draft tokens for spec decoding.
input_ids = input_batch.input_ids[input_batch.logits_indices]
sampled_tokens, num_sampled = rejection_sample(
sampler_output.sampled_token_ids,
input_ids,
input_batch.cu_num_logits,
self.num_speculative_steps,
)
num_sampled *= ~is_chunked_prefilling
sampler_output.sampled_token_ids = sampled_tokens
# TODO(woosuk): Support logprobs with spec decoding.
return sampler_output, num_sampled
def compute_prompt_logprobs(
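Note: rejection_sample itself is not shown in this hunk. As a rough mental model, greedy-acceptance rejection sampling keeps a draft token only while it matches what the target model sampled at the preceding position, then emits the target model's own token at the first mismatch. A per-request sketch under that assumption (the real kernel is batched over cu_num_logits and padded to num_speculative_steps + 1 tokens per request):

def greedy_rejection_sample_sketch(target_tokens, input_tokens):
    # target_tokens: target-model samples at the (1 + num_draft) logit positions.
    # input_tokens: tokens fed at those positions (last sampled token + drafts).
    # Returns (accepted tokens, num_sampled).
    num_draft = len(input_tokens) - 1
    num_accepted = 0
    for j in range(num_draft):
        # Draft token j sits at input_tokens[j + 1]; accept it only if it equals
        # what the target model sampled at the position before it.
        if input_tokens[j + 1] == target_tokens[j]:
            num_accepted += 1
        else:
            break
    # Accepted drafts plus the target's token at the first rejected/final position.
    return target_tokens[: num_accepted + 1], num_accepted + 1

For example, with drafts [7, 9] and target samples [7, 9, 4] both drafts are accepted and num_sampled is 3; with drafts [7, 5] the second draft is rejected and num_sampled is 2.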
@@ -653,11 +707,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_sampled: torch.Tensor,
) -> None:
# Update the number of computed tokens.
update_num_computed_tokens(
post_update(
input_batch.idx_mapping,
self.req_states.num_computed_tokens,
self.req_states.last_sampled_tokens,
sampled_tokens,
num_sampled,
input_batch.query_start_loc,
input_batch.cu_num_logits,
)
# Update the number of computed prefill tokens.
idx_mapping_np = input_batch.idx_mapping_np
computed_prefill = self.req_states.num_computed_prefill_tokens
# TODO(woosuk): Simplify this.
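Note: post_update subsumes update_num_computed_tokens and also writes back the last accepted token, since with spec decoding a request can advance by anywhere from 0 to num_speculative_steps + 1 tokens per step. A per-request sketch of the assumed semantics (the real implementation is a fused GPU kernel driven by query_start_loc and cu_num_logits; all names below are illustrative):

def post_update_sketch(num_computed_tokens, last_sampled_tokens, req_idx,
                       sampled_tokens, num_sampled, num_scheduled, num_draft):
    # Advance by the scheduled tokens, minus any draft tokens that were rejected
    # (rejected drafts never become part of the sequence). Chunked prefills have
    # num_draft == 0 and num_sampled == 0, so they advance by num_scheduled.
    num_accepted_draft = max(num_sampled - 1, 0)
    num_rejected = num_draft - num_accepted_draft if num_draft > 0 else 0
    num_computed_tokens[req_idx] += num_scheduled - num_rejected
    if num_sampled > 0:
        # The last accepted token seeds the next step's input ids, replacing the
        # separate "store the last sampled token ids" code removed in the next hunk.
        last_sampled_tokens[req_idx] = sampled_tokens[num_sampled - 1]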
@@ -666,10 +726,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.req_states.prefill_len.np[idx_mapping_np],
)
# Store the last sampled token ids.
last_sampled = sampled_tokens
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = last_sampled
def get_cudagraph_and_dp_padding(
self,
scheduler_output: SchedulerOutput,
@@ -761,6 +817,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_metadata = self.req_states.make_sampling_metadata(
input_batch.idx_mapping_np, pos
)
if input_batch.num_draft_tokens > 0:
sampling_metadata = self.req_states.expand_sampling_metadata(
sampling_metadata, input_batch.cu_num_logits
)
if self.lora_config:
# Activate LoRA adapters.
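Note: when draft tokens are present there are more logit rows than requests, so per-request sampling parameters have to be repeated once per logit row before sampling. A hypothetical illustration using repeat_interleave (the real expand_sampling_metadata covers all SamplingMetadata fields, not just temperature):

import torch

temperature = torch.tensor([1.0, 0.7, 0.0])   # one entry per request
cu_num_logits = torch.tensor([0, 1, 4, 6])    # 1, 3, 2 logit rows per request

repeats = cu_num_logits[1:] - cu_num_logits[:-1]
expanded = torch.repeat_interleave(temperature, repeats)
print(expanded)   # tensor([1.0000, 0.7000, 0.7000, 0.7000, 0.0000, 0.0000])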