[BugFix] Fix spec decoding edge case bugs (#31944)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
@@ -775,6 +775,7 @@ class Scheduler(SchedulerInterface):
                 self.encoder_cache_manager.free(request)
                 request.status = RequestStatus.PREEMPTED
                 request.num_computed_tokens = 0
+                request.spec_token_ids.clear()
                 request.num_preemptions += 1
                 if self.log_stats:
                     request.record_event(EngineCoreEventType.PREEMPTED, timestamp)
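The one added line here (request.spec_token_ids.clear(), per the +1 in the hunk header) ties draft-token state into the rest of the preemption reset: once num_computed_tokens drops back to 0, any drafts proposed before preemption are stale and must not be replayed when the request is rescheduled. A minimal toy illustration of that invariant (hypothetical ToyRequest, not the scheduler's Request class):

# Toy sketch, not vLLM code: preemption discards computed progress, so the
# draft tokens proposed against that progress are discarded with it.
class ToyRequest:
    def __init__(self):
        self.num_computed_tokens = 120
        self.spec_token_ids = [11, 12]   # drafts proposed before preemption
        self.num_preemptions = 0

    def preempt(self):
        self.num_computed_tokens = 0     # KV-cache progress is given up
        self.spec_token_ids.clear()      # stale drafts must go with it
        self.num_preemptions += 1

r = ToyRequest()
r.preempt()
assert r.num_computed_tokens == 0 and r.spec_token_ids == []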
@@ -446,6 +446,32 @@ class InputBatch:
         return req_index
 
+    def update_req_spec_token_ids(
+        self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
+    ) -> None:
+        req_id = request.req_id
+        req_index = self.req_id_to_index[req_id]
+        cur_spec_token_ids = self.spec_token_ids[req_index]
+        # When speculative decoding is used with structured output,
+        # the scheduler can drop draft tokens that do not
+        # conform to the schema. This can result in
+        # scheduler_output.scheduled_spec_decode_tokens being empty,
+        # even when speculative decoding is enabled.
+        cur_spec_token_ids.clear()
+        spec_token_ids = scheduled_spec_tokens.get(req_id, ())
+        num_spec_tokens = len(spec_token_ids)
+        request.prev_num_draft_len = num_spec_tokens
+        if not spec_token_ids:
+            return
+
+        # For async scheduling, token_ids_cpu assigned from
+        # spec_token_ids are placeholders and will be overwritten in
+        # _prepare_input_ids.
+        start_index = self.num_tokens_no_spec[req_index]
+        end_token_index = start_index + num_spec_tokens
+        self.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
+        cur_spec_token_ids.extend(spec_token_ids)
+
     def remove_request(self, req_id: str) -> int | None:
         """This method must always be followed by a call to condense().
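The new helper's contract can be exercised on its own. The sketch below uses toy stand-ins (a dict for the request state, numpy arrays for the batch; hypothetical names, not the real InputBatch / CachedRequestState) to show the two cases the method covers: no entry in scheduled_spec_tokens, e.g. because structured output rejected every draft, so stale drafts are cleared and prev_num_draft_len resets to 0; or scheduled drafts, which are written into token_ids_cpu as placeholders and mirrored into the per-request list:

import numpy as np

class ToyBatch:
    def __init__(self, max_reqs=2, max_len=16):
        self.req_id_to_index = {"req-0": 0}
        self.spec_token_ids = [[] for _ in range(max_reqs)]
        self.num_tokens_no_spec = np.zeros(max_reqs, dtype=np.int64)
        self.token_ids_cpu = np.zeros((max_reqs, max_len), dtype=np.int64)

    def update_req_spec_token_ids(self, request, scheduled_spec_tokens):
        idx = self.req_id_to_index[request["req_id"]]
        cur = self.spec_token_ids[idx]
        cur.clear()                                   # drop stale drafts from the last step
        spec = scheduled_spec_tokens.get(request["req_id"], ())
        request["prev_num_draft_len"] = len(spec)
        if not spec:
            return                                    # nothing scheduled for this request
        start = int(self.num_tokens_no_spec[idx])
        self.token_ids_cpu[idx, start:start + len(spec)] = spec
        cur.extend(spec)

batch = ToyBatch()
req = {"req_id": "req-0", "prev_num_draft_len": 0}
batch.spec_token_ids[0].extend([7, 8])                # stale drafts from the previous step
batch.update_req_spec_token_ids(req, {})              # nothing scheduled: clear + reset
print(batch.spec_token_ids[0], req["prev_num_draft_len"])   # [] 0
batch.update_req_spec_token_ids(req, {"req-0": [7, 9]})     # drafts scheduled this step
print(batch.spec_token_ids[0], req["prev_num_draft_len"])   # [7, 9] 2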
@@ -925,6 +925,7 @@ class GPUModelRunner(
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
+        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
 
         # Wait until valid_sampled_tokens_count is copied to cpu,
         # then use it to update actual num_computed_tokens of each request.
@@ -938,20 +939,20 @@ class GPUModelRunner(
             num_output_tokens = req_data.num_output_tokens[i]
             req_index = self.input_batch.req_id_to_index.get(req_id)
 
-            # prev_num_draft_len is used in async scheduling mode with
-            # spec decode. It indicates whether num_computed_tokens of the
-            # request needs to be updated. For example:
-            # first step: num_computed_tokens = 0, spec_tokens = [],
-            # prev_num_draft_len = 0.
-            # second step: num_computed_tokens = 100 (prompt length),
-            # spec_tokens = [a, b], prev_num_draft_len = 0.
-            # third step: num_computed_tokens = 100 + 2, spec_tokens = [c, d],
-            # prev_num_draft_len = 2.
-            # num_computed_tokens in the first and second steps doesn't include
-            # the spec token count, but in the third step it does include it.
-            # We only need to update num_computed_tokens
-            # when prev_num_draft_len > 0.
-            if req_state.prev_num_draft_len:
+            if req_state.prev_num_draft_len and self.use_async_scheduling:
+                # prev_num_draft_len is used in async scheduling mode with
+                # spec decode. It indicates whether num_computed_tokens of the
+                # request needs to be updated. For example:
+                # first step: num_computed_tokens = 0, spec_tokens = [],
+                # prev_num_draft_len = 0.
+                # second step: num_computed_tokens = 100 (prompt length),
+                # spec_tokens = [a, b], prev_num_draft_len = 0.
+                # third step: num_computed_tokens = 100 + 2, spec_tokens = [c, d],
+                # prev_num_draft_len = 2.
+                # num_computed_tokens in the first and second steps doesn't include
+                # the spec token count, but in the third step it does include it.
+                # We only need to update num_computed_tokens
+                # when prev_num_draft_len > 0.
                 if req_index is None:
                     req_state.prev_num_draft_len = 0
                 else:
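The three-step example in the comment can be traced directly. The numbers below are the hypothetical ones from the comment, not output from the runner; the point is that only a step whose predecessor scheduled drafts (prev_num_draft_len > 0) may need its num_computed_tokens corrected:

# Trace of the comment's example with hypothetical values.
steps = [
    # (step, num_computed_tokens, spec_tokens, prev_num_draft_len)
    (1, 0, [], 0),            # nothing computed yet, no drafts in flight
    (2, 100, ["a", "b"], 0),  # prompt of 100 tokens computed; 2 drafts scheduled
    (3, 102, ["c", "d"], 2),  # count now includes the previous step's 2 drafts
]
for step, computed, spec, prev_draft in steps:
    update_needed = prev_draft > 0
    print(f"step {step}: num_computed_tokens={computed}, "
          f"spec_tokens={spec}, update_needed={update_needed}")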
@@ -1035,34 +1036,13 @@ class GPUModelRunner(
                 self.input_batch.num_tokens_no_spec[req_index] = end_token_index
 
-                # Add spec_token_ids to token_ids_cpu.
-                spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                    req_id, []
-                )
-                num_spec_tokens = len(spec_token_ids)
-                # For async scheduling, token_ids_cpu assigned from
-                # spec_token_ids are placeholders and will be overwritten in
-                # _prepare_input_ids.
-                if num_spec_tokens:
-                    start_index = self.input_batch.num_tokens_no_spec[req_index]
-                    end_token_index = start_index + num_spec_tokens
-                    self.input_batch.token_ids_cpu[
-                        req_index, start_index:end_token_index
-                    ] = spec_token_ids
+            self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
 
-                # When speculative decoding is used with structured output,
-                # the scheduler can drop draft tokens that do not
-                # conform to the schema. This can result in
-                # scheduler_output.scheduled_spec_decode_tokens being empty,
-                # even when speculative decoding is enabled.
-                self.input_batch.spec_token_ids[req_index].clear()
-                self.input_batch.spec_token_ids[req_index].extend(spec_token_ids)
-
-                if self.use_async_scheduling:
-                    req_state.prev_num_draft_len = num_spec_tokens
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
         for request in reqs_to_add:
             self.input_batch.add_request(request)
+            self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens)
 
         # Condense the batched states if there are gaps left by removed requests
         self.input_batch.condense()
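One ordering detail in the reqs_to_add path: update_req_spec_token_ids resolves the request through req_id_to_index, so it can only run after add_request has registered the request, which is the order used here. A tiny stand-in illustration (hypothetical functions, not the real classes):

# Stand-in sketch: the helper's index lookup requires prior registration.
batch_index = {}                      # stands in for input_batch.req_id_to_index

def add_request(req_id):
    batch_index[req_id] = len(batch_index)

def update_req_spec_token_ids(req_id):
    return batch_index[req_id]        # would raise KeyError if not registered yet

add_request("resumed-req")            # register first, as the diff does
assert update_req_spec_token_ids("resumed-req") == 0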
@@ -1519,7 +1499,6 @@ class GPUModelRunner(
            # We will ignore the sampled tokens from the partial requests.
            # TODO: Support prompt logprobs.
            logits_indices = query_start_loc[1:] - 1
            num_draft_tokens = None
            spec_decode_metadata = None
            num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
        else:
@@ -1536,14 +1515,11 @@ class GPUModelRunner(
             ) in scheduler_output.scheduled_spec_decode_tokens.items():
                 req_idx = self.input_batch.req_id_to_index[req_id]
                 num_draft_tokens[req_idx] = len(draft_token_ids)
-                num_decode_draft_tokens[req_idx] = (
-                    len(draft_token_ids)
-                    if (
-                        self.input_batch.num_computed_tokens_cpu[req_idx]
-                        >= self.input_batch.num_prompt_tokens[req_idx]
-                    )
-                    else -1
-                )
+                if (
+                    self.input_batch.num_computed_tokens_cpu[req_idx]
+                    >= self.input_batch.num_prompt_tokens[req_idx]
+                ):
+                    num_decode_draft_tokens[req_idx] = len(draft_token_ids)
             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens
             )
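The rewritten loop records a count in num_decode_draft_tokens only for requests whose prompt is fully computed, i.e. requests already in the decode phase. A sketch of that rule with plain numpy arrays and hypothetical per-request values (assuming the array defaults to -1 for prefill requests, which the removed "else -1" implies but whose new initialization is not shown in this hunk):

import numpy as np

# Hypothetical two-request batch: req-0 has finished its prompt, req-1 has not.
num_computed_tokens = np.array([100, 40], dtype=np.int32)
num_prompt_tokens = np.array([100, 80], dtype=np.int32)
scheduled_drafts = {"req-0": [5, 6], "req-1": [7, 8]}
req_id_to_index = {"req-0": 0, "req-1": 1}

num_draft_tokens = np.zeros(2, dtype=np.int32)
num_decode_draft_tokens = np.full(2, -1, dtype=np.int32)  # -1 = not decoding yet

for req_id, drafts in scheduled_drafts.items():
    idx = req_id_to_index[req_id]
    num_draft_tokens[idx] = len(drafts)
    if num_computed_tokens[idx] >= num_prompt_tokens[idx]:
        num_decode_draft_tokens[idx] = len(drafts)  # decode phase: count the drafts

print(num_draft_tokens.tolist())         # [2, 2]
print(num_decode_draft_tokens.tolist())  # [2, -1]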