[Model Runner V2] Support spec decoding [1/N] (#29274)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

@@ -40,11 +40,12 @@ from vllm.v1.worker.gpu.input_batch import (
    InputBatch,
    InputBuffers,
    combine_sampled_and_draft_tokens,
    post_update,
    prepare_pos_seq_lens,
    prepare_prefill_inputs,
    update_num_computed_tokens,
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin

@@ -100,10 +101,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.input_prep_event = None
        self.structured_outputs_event = None

        if self.speculative_config is not None:
            self.do_spec_decode = True
            self.num_speculative_steps = self.speculative_config.num_speculative_tokens
        else:
            self.do_spec_decode = False
            self.num_speculative_steps = 0

        self.req_states = RequestState(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            num_speculative_steps=self.num_speculative_steps,
            vocab_size=self.vocab_size,
            device=self.device,
            pin_memory=self.pin_memory,

@@ -427,6 +436,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        idx_mapping_np = idx_mapping.np[:num_reqs]
        idx_mapping = idx_mapping.copy_to_gpu(num_reqs)

        # Get the number of draft tokens for each request.
        if not scheduler_output.scheduled_spec_decode_tokens:
            # No draft token scheduled (common case).
            total_num_draft_tokens = 0
            total_num_logits = num_reqs
            cu_num_logits = torch.arange(
                num_reqs + 1, device=self.device, dtype=torch.int32
            )
        else:
            draft_tokens = scheduler_output.scheduled_spec_decode_tokens
            num_draft_tokens = np.array(
                [
                    len(draft_tokens[req_id]) if req_id in draft_tokens else 0
                    for req_id in req_ids
                ],
                dtype=np.int32,
            )
            total_num_draft_tokens = int(num_draft_tokens.sum())
            total_num_logits = num_reqs + total_num_draft_tokens

            np.cumsum(
                num_draft_tokens + 1,
                out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1],
            )
            cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(num_reqs + 1)

        # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
        block_tables = self.block_tables.gather_block_tables(idx_mapping)

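Review note: each request contributes one logit position for its normally sampled token plus one per scheduled draft token, so cu_num_logits is a prefix sum of (num_draft_tokens + 1) with a leading zero; that is also why the no-draft case can use a plain torch.arange(num_reqs + 1). A minimal standalone NumPy sketch of the same bookkeeping (illustrative values, not the buffer-backed version above):

import numpy as np

# Per-request draft token counts for a batch of 4 requests
# (request 2 had no drafts scheduled).
num_draft_tokens = np.array([3, 3, 0, 2], dtype=np.int32)
num_reqs = len(num_draft_tokens)

# Each request produces logits for its last scheduled token plus one
# position per draft token, i.e. 1 + num_draft_tokens logits.
cu_num_logits = np.zeros(num_reqs + 1, dtype=np.int32)
np.cumsum(num_draft_tokens + 1, out=cu_num_logits[1:])

print(cu_num_logits)                       # [ 0  4  8  9 12]
total_num_logits = int(cu_num_logits[-1])  # 12 = num_reqs + total drafts
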
@@ -456,14 +491,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        seq_lens = self.input_buffers.seq_lens[:num_reqs]

        # Some input token ids are directly read from the last sampled tokens
        # and draft tokens.
        combine_sampled_and_draft_tokens(
        # and draft tokens. Also, get the logits indices to sample tokens from.
        logits_indices = combine_sampled_and_draft_tokens(
            self.input_buffers.input_ids.gpu,
            idx_mapping,
            self.req_states.last_sampled_tokens,
            query_start_loc_gpu,
            seq_lens,
            self.req_states.prefill_len.gpu,
            self.req_states.draft_tokens,
            cu_num_logits,
            total_num_logits,
        )

        # Compute slot mappings: [num_kv_cache_groups, num_tokens]

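Review note: combine_sampled_and_draft_tokens now also returns the logits indices. The sketch below is a hypothetical standalone reconstruction of just that index computation, under the assumption that with draft tokens each request samples from its last (1 + num_draft_tokens) query positions instead of only the last one; the real op additionally writes the previously sampled token and the scheduled draft tokens into the input_ids buffer. The helper name and shapes are illustrative.

import torch

def spec_logits_indices(query_start_loc: torch.Tensor,  # [num_reqs + 1]
                        cu_num_logits: torch.Tensor     # [num_reqs + 1]
                        ) -> torch.Tensor:
    num_reqs = query_start_loc.numel() - 1
    # Per-request number of logit rows (1 + num_draft_tokens).
    num_logits = (cu_num_logits[1:] - cu_num_logits[:-1]).long()
    total_num_logits = int(cu_num_logits[-1])
    # Which request each logit row belongs to.
    req_id = torch.repeat_interleave(torch.arange(num_reqs), num_logits)
    # Position of each logit row inside its request's run of logits.
    offset = torch.arange(total_num_logits) - cu_num_logits[:-1].long()[req_id]
    # Index into the flattened token stream: the last num_logits tokens
    # of each request's query.
    return query_start_loc[1:].long()[req_id] - num_logits[req_id] + offset

qsl = torch.tensor([0, 10, 14, 20])   # 3 requests with 10, 4, 6 query tokens
cu = torch.tensor([0, 4, 5, 8])       # 3, 0, 2 draft tokens -> 4, 1, 3 logits
print(spec_logits_indices(qsl, cu))   # tensor([ 6,  7,  8,  9, 13, 17, 18, 19])

With no draft tokens this degenerates to query_start_loc[1:] - 1, which is exactly the expression removed in the next hunk.
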
@@ -471,9 +509,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
        )

        # Logits indices to sample next token from.
        logits_indices = query_start_loc_gpu[1:] - 1

        # Get num_computed_tokens.
        # HACK(woosuk): Here, we use num_computed_tokens on GPU instead of
        # num_computed_tokens_cpu. This works for most cases.

@@ -508,6 +543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            num_scheduled_tokens=num_scheduled_tokens,
            num_tokens=num_tokens,
            num_tokens_after_padding=num_tokens_after_padding,
            num_draft_tokens=total_num_draft_tokens,
            query_start_loc=query_start_loc_gpu,
            query_start_loc_np=query_start_loc_np,
            seq_lens=seq_lens,

@@ -516,6 +552,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            positions=positions,
            attn_metadata=attn_metadata,
            logits_indices=logits_indices,
            cu_num_logits=cu_num_logits,
        )

    def sample(

@@ -530,6 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        if grammar_output is not None:
            # Apply grammar bitmask to the logits in-place.
            # TODO(woosuk): Make compatible with spec decoding.
            assert input_batch.num_draft_tokens == 0
            with async_barrier(self.structured_outputs_event):
                apply_grammar_bitmask(
                    logits,

@@ -539,12 +577,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                    self.input_buffers,
                )

        # Sample tokens and compute logprobs (if needed).
        sampler_output = self.sampler(logits, sampling_metadata)

        # Get the number of sampled tokens.
        # 0 if chunked-prefilling, 1 if not.
        prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
        is_chunked_prefilling = input_batch.seq_lens < prefill_len
        num_sampled = (~is_chunked_prefilling).int()
        if input_batch.num_draft_tokens == 0:
            # No draft tokens (common case).
            # 0 if chunked-prefilling, 1 if not.
            num_sampled = (~is_chunked_prefilling).int()
        else:
            # Draft tokens for spec decoding.
            input_ids = input_batch.input_ids[input_batch.logits_indices]
            sampled_tokens, num_sampled = rejection_sample(
                sampler_output.sampled_token_ids,
                input_ids,
                input_batch.cu_num_logits,
                self.num_speculative_steps,
            )
            num_sampled *= ~is_chunked_prefilling
            sampler_output.sampled_token_ids = sampled_tokens
            # TODO(woosuk): Support logprobs with spec decoding.
        return sampler_output, num_sampled

    def compute_prompt_logprobs(

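Review note: a hypothetical per-request sketch of the greedy accept/reject rule that rejection_sample presumably applies batch-wide over cu_num_logits (the real kernel works on the flattened batch, is bounded by num_speculative_steps, and presumably returns fixed-width padded rows so that sampled_token_ids can be replaced wholesale). Chunked-prefill requests then have their count zeroed by num_sampled *= ~is_chunked_prefilling.

import torch

def verify_greedy(target_tokens: torch.Tensor,  # [k + 1] target outputs at this request's logit positions
                  draft_tokens: torch.Tensor    # [k] draft tokens that were fed as inputs
                  ) -> tuple[torch.Tensor, int]:
    num_accepted = 0
    for i in range(draft_tokens.numel()):
        # A draft token survives only if the target model produces the
        # same token at that position.
        if int(draft_tokens[i]) != int(target_tokens[i]):
            break
        num_accepted += 1
    # One extra token always comes from the target model: the correction
    # after the first mismatch, or the bonus token after all drafts.
    num_sampled = num_accepted + 1
    return target_tokens[:num_sampled], num_sampled

target = torch.tensor([5, 9, 4, 2])   # k = 3 drafts were scheduled
draft = torch.tensor([5, 9, 7])
print(verify_greedy(target, draft))   # (tensor([5, 9, 4]), 3)
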
@@ -653,11 +707,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        num_sampled: torch.Tensor,
    ) -> None:
        # Update the number of computed tokens.
        update_num_computed_tokens(
        post_update(
            input_batch.idx_mapping,
            self.req_states.num_computed_tokens,
            self.req_states.last_sampled_tokens,
            sampled_tokens,
            num_sampled,
            input_batch.query_start_loc,
            input_batch.cu_num_logits,
        )

        # Update the number of computed prefill tokens.
        idx_mapping_np = input_batch.idx_mapping_np
        computed_prefill = self.req_states.num_computed_prefill_tokens
        # TODO(woosuk): Simplify this.

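Review note: post_update folds what update_num_computed_tokens did, plus the separate last-sampled-token store removed in the next hunk, into one batched call. The sketch below is a hypothetical per-request version of the assumed bookkeeping (names and shapes are illustrative): rejected draft positions were run through the model but must not advance the sequence, and the next step reads its input id from the last accepted token.

import torch

def post_update_one(num_computed_tokens: torch.Tensor,   # [max_num_reqs]
                    last_sampled_tokens: torch.Tensor,   # [max_num_reqs, 1]
                    req_index: int,
                    num_scheduled: int,       # query tokens run this step
                    num_draft: int,           # draft tokens among them
                    sampled: torch.Tensor,    # this request's sampled tokens
                    num_sampled: int) -> None:
    if num_sampled > 0:
        # Tokens produced per request = accepted drafts + 1, so the number
        # of rejected draft positions is num_draft + 1 - num_sampled.
        num_rejected = num_draft + 1 - num_sampled
        last_sampled_tokens[req_index, 0] = sampled[num_sampled - 1]
    else:
        # Chunked prefill: nothing was sampled, nothing is rejected.
        num_rejected = 0
    num_computed_tokens[req_index] += num_scheduled - num_rejected
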
@@ -666,10 +726,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            self.req_states.prefill_len.np[idx_mapping_np],
        )

        # Store the last sampled token ids.
        last_sampled = sampled_tokens
        self.req_states.last_sampled_tokens[input_batch.idx_mapping] = last_sampled

    def get_cudagraph_and_dp_padding(
        self,
        scheduler_output: SchedulerOutput,

@@ -761,6 +817,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        sampling_metadata = self.req_states.make_sampling_metadata(
            input_batch.idx_mapping_np, pos
        )
        if input_batch.num_draft_tokens > 0:
            sampling_metadata = self.req_states.expand_sampling_metadata(
                sampling_metadata, input_batch.cu_num_logits
            )

        if self.lora_config:
            # Activate LoRA adapters.

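Review note: make_sampling_metadata builds per-request tensors (one row per request), but with draft tokens the sampler sees 1 + num_draft_tokens logits per request, so expand_sampling_metadata presumably repeats each request's parameters across its run of logits. A hypothetical sketch for a single per-request tensor:

import torch

def expand_per_request(values: torch.Tensor,        # [num_reqs], e.g. temperature
                       cu_num_logits: torch.Tensor  # [num_reqs + 1]
                       ) -> torch.Tensor:
    # Repeat each request's value once per logit row it owns.
    num_logits_per_req = (cu_num_logits[1:] - cu_num_logits[:-1]).long()
    return torch.repeat_interleave(values, num_logits_per_req)

temperature = torch.tensor([0.7, 0.0, 1.0])
cu_num_logits = torch.tensor([0, 3, 4, 6])
print(expand_per_request(temperature, cu_num_logits))
# tensor([0.7000, 0.7000, 0.7000, 0.0000, 1.0000, 1.0000])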