[BugFix] Fix spec decoding edge case bugs (#31944)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
Author: Nick Hill
Date: 2026-01-07 23:31:03 -08:00
Committed by: GitHub
Parent: 791b2fc30a
Commit: 287b37cda4
3 changed files with 49 additions and 46 deletions


@@ -775,6 +775,7 @@ class Scheduler(SchedulerInterface):
             self.encoder_cache_manager.free(request)
             request.status = RequestStatus.PREEMPTED
             request.num_computed_tokens = 0
+            request.spec_token_ids.clear()
             request.num_preemptions += 1
             if self.log_stats:
                 request.record_event(EngineCoreEventType.PREEMPTED, timestamp)
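This one-line change addresses drafts leaking across preemption: when a request is preempted, its computed-token count resets to zero and its KV blocks are freed, so draft tokens proposed in the previous step are no longer valid. A minimal sketch of the invariant, using a simplified stand-in class rather than vLLM's actual Request:

```python
from dataclasses import dataclass, field


# Simplified stand-in for the scheduler's request state (not vLLM's Request).
@dataclass
class Request:
    prompt_token_ids: list[int]
    num_computed_tokens: int = 0
    num_preemptions: int = 0
    spec_token_ids: list[int] = field(default_factory=list)


def preempt(request: Request) -> None:
    # Mirror the fix above: reset progress AND drop leftover draft tokens.
    # If spec_token_ids survived, the next scheduling step could count
    # drafts whose KV-cache entries have already been freed.
    request.num_computed_tokens = 0
    request.spec_token_ids.clear()
    request.num_preemptions += 1
```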


@@ -446,6 +446,32 @@ class InputBatch:
         return req_index
 
+    def update_req_spec_token_ids(
+        self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
+    ) -> None:
+        req_id = request.req_id
+        req_index = self.req_id_to_index[req_id]
+        cur_spec_token_ids = self.spec_token_ids[req_index]
+        # When speculative decoding is used with structured output,
+        # the scheduler can drop draft tokens that do not
+        # conform to the schema. This can result in
+        # scheduler_output.scheduled_spec_decode_tokens being empty,
+        # even when speculative decoding is enabled.
+        cur_spec_token_ids.clear()
+        spec_token_ids = scheduled_spec_tokens.get(req_id, ())
+        num_spec_tokens = len(spec_token_ids)
+        request.prev_num_draft_len = num_spec_tokens
+        if not spec_token_ids:
+            return
+        # For async scheduling, the token_ids_cpu entries assigned from
+        # spec_token_ids are placeholders and will be overwritten in
+        # _prepare_input_ids.
+        start_index = self.num_tokens_no_spec[req_index]
+        end_token_index = start_index + num_spec_tokens
+        self.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
+        cur_spec_token_ids.extend(spec_token_ids)
+
     def remove_request(self, req_id: str) -> int | None:
         """This method must always be followed by a call to condense().


@@ -925,6 +925,7 @@ class GPUModelRunner(
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
+        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
 
         # Wait until valid_sampled_tokens_count is copied to cpu,
         # then use it to update actual num_computed_tokens of each request.
@@ -938,20 +939,20 @@ class GPUModelRunner(
             num_output_tokens = req_data.num_output_tokens[i]
             req_index = self.input_batch.req_id_to_index.get(req_id)
-            # prev_num_draft_len is used in async scheduling mode with
-            # spec decode. it indicates if need to update num_computed_tokens
-            # of the request. for example:
-            # fist step: num_computed_tokens = 0, spec_tokens = [],
-            # prev_num_draft_len = 0.
-            # second step: num_computed_tokens = 100(prompt lenth),
-            # spec_tokens = [a,b], prev_num_draft_len = 0.
-            # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
-            # prev_num_draft_len = 2.
-            # num_computed_tokens in first step and second step does't contain
-            # the spec tokens length, but in third step it contains the
-            # spec tokens length. we only need to update num_computed_tokens
-            # when prev_num_draft_len > 0.
-            if req_state.prev_num_draft_len:
+            if req_state.prev_num_draft_len and self.use_async_scheduling:
+                # prev_num_draft_len is used in async scheduling mode with
+                # spec decode. It indicates whether num_computed_tokens of
+                # the request needs to be updated. For example:
+                # first step: num_computed_tokens = 0, spec_tokens = [],
+                # prev_num_draft_len = 0.
+                # second step: num_computed_tokens = 100 (prompt length),
+                # spec_tokens = [a,b], prev_num_draft_len = 0.
+                # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
+                # prev_num_draft_len = 2.
+                # num_computed_tokens in the first and second steps doesn't
+                # include the spec token length, but in the third step it
+                # does. We only need to update num_computed_tokens when
+                # prev_num_draft_len > 0.
                 if req_index is None:
                     req_state.prev_num_draft_len = 0
                 else:
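The new use_async_scheduling guard confines this correction to the mode where it applies. To make the comment's three steps concrete, here is a toy version of the adjustment, a simplification of mine rather than the exact code path: with async scheduling, num_computed_tokens already includes all of the previous step's drafts, so the rejected ones must be subtracted once the sampler reports how many were accepted.

```python
def correct_computed_tokens(
    num_computed_tokens: int, prev_num_draft_len: int, num_accepted: int
) -> int:
    # prev_num_draft_len == 0 on the first and second steps (no drafts were
    # counted into num_computed_tokens yet), so no correction happens there.
    if prev_num_draft_len:
        num_rejected = prev_num_draft_len - num_accepted
        num_computed_tokens -= num_rejected
    return num_computed_tokens


# Third step from the comment: 100 prompt tokens + 2 drafts were counted;
# if only 1 draft is accepted, the count rolls back to 101.
assert correct_computed_tokens(102, prev_num_draft_len=2, num_accepted=1) == 101
```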
@@ -1035,34 +1036,13 @@ class GPUModelRunner(
                 self.input_batch.num_tokens_no_spec[req_index] = end_token_index
-            # Add spec_token_ids to token_ids_cpu.
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, []
-            )
-            num_spec_tokens = len(spec_token_ids)
-            # For async scheduling, token_ids_cpu assigned from
-            # spec_token_ids are placeholders and will be overwritten in
-            # _prepare_input_ids.
-            if num_spec_tokens:
-                start_index = self.input_batch.num_tokens_no_spec[req_index]
-                end_token_index = start_index + num_spec_tokens
-                self.input_batch.token_ids_cpu[
-                    req_index, start_index:end_token_index
-                ] = spec_token_ids
-            # When speculative decoding is used with structured output,
-            # the scheduler can drop draft tokens that do not
-            # conform to the schema. This can result in
-            # scheduler_output.scheduled_spec_decode_tokens being empty,
-            # even when speculative decoding is enabled.
-            self.input_batch.spec_token_ids[req_index].clear()
-            self.input_batch.spec_token_ids[req_index].extend(spec_token_ids)
-            if self.use_async_scheduling:
-                req_state.prev_num_draft_len = num_spec_tokens
+            self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
         for request in reqs_to_add:
             self.input_batch.add_request(request)
+            self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens)
 
         # Condense the batched states if there are gaps left by removed requests
         self.input_batch.condense()
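The net effect of this hunk: cached requests and newly added or resumed requests now refresh their draft tokens through the same InputBatch helper, so the clear-then-extend edge case is handled in exactly one place. Schematically (a free-standing sketch, not the surrounding method):

```python
def refresh_spec_tokens(input_batch, cached_req_states, reqs_to_add, scheduled_spec_tokens):
    # Requests already in the persistent batch.
    for req_state in cached_req_states:
        input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
    # New or resumed requests get the same treatment once added.
    for request in reqs_to_add:
        input_batch.add_request(request)
        input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens)
```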
@@ -1519,7 +1499,6 @@ class GPUModelRunner(
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
             logits_indices = query_start_loc[1:] - 1
-            num_draft_tokens = None
             spec_decode_metadata = None
             num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
         else:
@@ -1536,14 +1515,11 @@ class GPUModelRunner(
             ) in scheduler_output.scheduled_spec_decode_tokens.items():
                 req_idx = self.input_batch.req_id_to_index[req_id]
                 num_draft_tokens[req_idx] = len(draft_token_ids)
-                num_decode_draft_tokens[req_idx] = (
-                    len(draft_token_ids)
-                    if (
-                        self.input_batch.num_computed_tokens_cpu[req_idx]
-                        >= self.input_batch.num_prompt_tokens[req_idx]
-                    )
-                    else -1
-                )
+                if (
+                    self.input_batch.num_computed_tokens_cpu[req_idx]
+                    >= self.input_batch.num_prompt_tokens[req_idx]
+                ):
+                    num_decode_draft_tokens[req_idx] = len(draft_token_ids)
 
             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens
             )
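Restructuring the assignment as a guarded statement means num_decode_draft_tokens keeps its initial value for requests still in prefill; that initial value (presumably the -1 sentinel, set where the array is allocated outside this hunk) no longer needs to be rewritten per request. A small NumPy illustration of the bookkeeping, with assumed shapes rather than _calc_spec_decode_metadata's real internals:

```python
import numpy as np

num_reqs = 3
num_draft_tokens = np.array([2, 0, 3], dtype=np.int32)     # drafts per request
num_decode_draft_tokens = np.full(num_reqs, -1, np.int32)  # -1 => still prefilling

# Only requests that have finished their prompt report a decode draft count.
in_decode = np.array([True, False, True])
num_decode_draft_tokens[in_decode] = num_draft_tokens[in_decode]
assert num_decode_draft_tokens.tolist() == [2, -1, 3]

# Each request contributes 1 sampled token plus its drafts to the flat batch,
# giving the cumulative offsets used to slice out per-request logits.
cu_num_tokens = np.cumsum(num_draft_tokens + 1)
assert cu_num_tokens.tolist() == [3, 4, 8]
```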