diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index b3ea24dac..a6bfa7a4a 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -775,6 +775,7 @@ class Scheduler(SchedulerInterface):
             self.encoder_cache_manager.free(request)
         request.status = RequestStatus.PREEMPTED
         request.num_computed_tokens = 0
+        request.spec_token_ids.clear()
         request.num_preemptions += 1
         if self.log_stats:
             request.record_event(EngineCoreEventType.PREEMPTED, timestamp)
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 14bbd6578..dd8c66114 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -446,6 +446,32 @@ class InputBatch:
 
         return req_index
 
+    def update_req_spec_token_ids(
+        self, request: CachedRequestState, scheduled_spec_tokens: dict[str, list[int]]
+    ) -> None:
+        req_id = request.req_id
+        req_index = self.req_id_to_index[req_id]
+        cur_spec_token_ids = self.spec_token_ids[req_index]
+        # When speculative decoding is used with structured output,
+        # the scheduler can drop draft tokens that do not
+        # conform to the schema. This can result in
+        # scheduler_output.scheduled_spec_decode_tokens being empty,
+        # even when speculative decoding is enabled.
+        cur_spec_token_ids.clear()
+        spec_token_ids = scheduled_spec_tokens.get(req_id, ())
+        num_spec_tokens = len(spec_token_ids)
+        request.prev_num_draft_len = num_spec_tokens
+        if not spec_token_ids:
+            return
+
+        # For async scheduling, token_ids_cpu assigned from
+        # spec_token_ids are placeholders and will be overwritten in
+        # _prepare_input_ids.
+        start_index = self.num_tokens_no_spec[req_index]
+        end_token_index = start_index + num_spec_tokens
+        self.token_ids_cpu[req_index, start_index:end_token_index] = spec_token_ids
+        cur_spec_token_ids.extend(spec_token_ids)
+
     def remove_request(self, req_id: str) -> int | None:
         """This method must always be followed by a call to condense().
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 07d5c282c..adba78e55 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -925,6 +925,7 @@ class GPUModelRunner(
         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
         req_data = scheduler_output.scheduled_cached_reqs
+        scheduled_spec_tokens = scheduler_output.scheduled_spec_decode_tokens
 
         # Wait until valid_sampled_tokens_count is copied to cpu,
         # then use it to update actual num_computed_tokens of each request.
@@ -938,20 +939,20 @@ class GPUModelRunner(
             num_output_tokens = req_data.num_output_tokens[i]
             req_index = self.input_batch.req_id_to_index.get(req_id)
 
-            # prev_num_draft_len is used in async scheduling mode with
-            # spec decode. it indicates if need to update num_computed_tokens
-            # of the request. for example:
-            # fist step: num_computed_tokens = 0, spec_tokens = [],
-            # prev_num_draft_len = 0.
-            # second step: num_computed_tokens = 100(prompt lenth),
-            # spec_tokens = [a,b], prev_num_draft_len = 0.
-            # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
-            # prev_num_draft_len = 2.
-            # num_computed_tokens in first step and second step does't contain
-            # the spec tokens length, but in third step it contains the
-            # spec tokens length. we only need to update num_computed_tokens
-            # when prev_num_draft_len > 0.
-            if req_state.prev_num_draft_len:
+            if req_state.prev_num_draft_len and self.use_async_scheduling:
+                # prev_num_draft_len is used in async scheduling mode with
+                # spec decode. It indicates whether num_computed_tokens of
+                # the request needs to be updated. For example:
+                # first step: num_computed_tokens = 0, spec_tokens = [],
+                # prev_num_draft_len = 0.
+                # second step: num_computed_tokens = 100 (prompt length),
+                # spec_tokens = [a,b], prev_num_draft_len = 0.
+                # third step: num_computed_tokens = 100 + 2, spec_tokens = [c,d],
+                # prev_num_draft_len = 2.
+                # num_computed_tokens in the first and second steps doesn't
+                # include the spec token length, but in the third step it
+                # does. We only need to update num_computed_tokens when
+                # prev_num_draft_len > 0.
                 if req_index is None:
                     req_state.prev_num_draft_len = 0
                 else:
@@ -1035,34 +1036,13 @@ class GPUModelRunner(
            self.input_batch.num_tokens_no_spec[req_index] = end_token_index
 
             # Add spec_token_ids to token_ids_cpu.
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, []
-            )
-            num_spec_tokens = len(spec_token_ids)
-            # For async scheduling, token_ids_cpu assigned from
-            # spec_token_ids are placeholders and will be overwritten in
-            # _prepare_input_ids.
-            if num_spec_tokens:
-                start_index = self.input_batch.num_tokens_no_spec[req_index]
-                end_token_index = start_index + num_spec_tokens
-                self.input_batch.token_ids_cpu[
-                    req_index, start_index:end_token_index
-                ] = spec_token_ids
+            self.input_batch.update_req_spec_token_ids(req_state, scheduled_spec_tokens)
 
-            # When speculative decoding is used with structured output,
-            # the scheduler can drop draft tokens that do not
-            # conform to the schema. This can result in
-            # scheduler_output.scheduled_spec_decode_tokens being empty,
-            # even when speculative decoding is enabled.
-            self.input_batch.spec_token_ids[req_index].clear()
-            self.input_batch.spec_token_ids[req_index].extend(spec_token_ids)
-
-            if self.use_async_scheduling:
-                req_state.prev_num_draft_len = num_spec_tokens
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
         for request in reqs_to_add:
             self.input_batch.add_request(request)
+            self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens)
 
         # Condense the batched states if there are gaps left by removed requests
         self.input_batch.condense()
@@ -1519,7 +1499,6 @@ class GPUModelRunner(
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
             logits_indices = query_start_loc[1:] - 1
-            num_draft_tokens = None
             spec_decode_metadata = None
             num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
         else:
@@ -1536,14 +1515,11 @@ class GPUModelRunner(
             ) in scheduler_output.scheduled_spec_decode_tokens.items():
                 req_idx = self.input_batch.req_id_to_index[req_id]
                 num_draft_tokens[req_idx] = len(draft_token_ids)
-                num_decode_draft_tokens[req_idx] = (
-                    len(draft_token_ids)
-                    if (
-                        self.input_batch.num_computed_tokens_cpu[req_idx]
-                        >= self.input_batch.num_prompt_tokens[req_idx]
-                    )
-                    else -1
-                )
+                if (
+                    self.input_batch.num_computed_tokens_cpu[req_idx]
+                    >= self.input_batch.num_prompt_tokens[req_idx]
+                ):
+                    num_decode_draft_tokens[req_idx] = len(draft_token_ids)
             spec_decode_metadata = self._calc_spec_decode_metadata(
                 num_draft_tokens, cu_num_tokens
             )
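Note on the `prev_num_draft_len` gating above: the rewritten comment's three-step example can be traced with a toy stand-in. A minimal sketch, assuming a hypothetical `Req` dataclass in place of the real `CachedRequestState` (only the two fields the example mentions):

```python
from dataclasses import dataclass


@dataclass
class Req:
    # Hypothetical stand-in for CachedRequestState; only the two
    # fields the comment's example tracks.
    num_computed_tokens: int = 0
    prev_num_draft_len: int = 0


req = Req()

# First step: nothing computed yet, no drafts in flight.
assert (req.num_computed_tokens, req.prev_num_draft_len) == (0, 0)

# Second step: the 100-token prompt is computed; drafts [a, b] are
# scheduled now, but num_computed_tokens does not include them yet.
req.num_computed_tokens = 100
req.prev_num_draft_len = 2  # two drafts now in flight

# Third step: the scheduler's count optimistically includes the two
# drafts from the previous step, so prev_num_draft_len > 0 is the
# signal that num_computed_tokens may need correcting for any
# rejected drafts.
req.num_computed_tokens = 100 + 2
req.prev_num_draft_len = 2  # drafts [c, d] scheduled this step
```

Because the check is now also gated on `self.use_async_scheduling`, synchronous runs skip the correction entirely; only async scheduling reports a `num_computed_tokens` that optimistically includes the previous step's draft tokens.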
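The last hunk drops the `else -1` arm of the conditional expression. A small sketch of why that is safe, assuming `num_decode_draft_tokens` is pre-initialized to -1 (the request tuples below are hypothetical):

```python
import numpy as np

num_reqs = 3
# Assumed pre-fill with -1, matching what the old `else -1` produced.
num_decode_draft_tokens = np.full(num_reqs, -1, dtype=np.int32)

# Hypothetical per-request data: (num_computed, num_prompt, num_drafts).
reqs = [
    (120, 100, 2),  # decode phase: past the prompt
    (40, 100, 2),   # still in prefill
    (100, 100, 3),  # exactly at the prompt boundary -> decode
]

for req_idx, (computed, prompt, drafts) in enumerate(reqs):
    # Only decode-phase requests get a real count; prefill requests
    # keep the -1 sentinel left by the initialization.
    if computed >= prompt:
        num_decode_draft_tokens[req_idx] = drafts

print(num_decode_draft_tokens)  # [ 2 -1  3]
```

Skipping the assignment for prefill requests leaves the -1 sentinel in place, so the plain `if` is behaviorally equivalent to the old ternary while reading more directly.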