diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 063d0a644..b29df468f 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -945,6 +945,100 @@ def test_spec_decoding_stats_empty_output():
     assert scheduler_stats is None or scheduler_stats.spec_decoding_stats is None
 
 
+def test_no_spec_tokens_scheduled_for_prefill_chunks():
+    """Test that draft tokens are ignored for prefill chunk requests.
+
+    When a request is being prefilled in chunks (chunked prefill), draft tokens
+    from `update_draft_token_ids` should be ignored until the prefill is complete.
+
+    The bug manifests when:
+    - A prefill chunk is scheduled
+    - Draft tokens are provided via update_draft_token_ids
+    - The next schedule has enough budget to include spec tokens
+
+    Without the fix, spec tokens would incorrectly be scheduled with the
+    remaining prefill tokens. With the fix, draft tokens are ignored for
+    prefill chunks.
+    """
+    num_spec_tokens = 3
+    # Use budget of 50, with 80 token prompt:
+    # - First chunk: 50 tokens
+    # - Second chunk: 30 remaining + potentially 3 spec tokens = 33
+    # Without fix: num_scheduled_spec_tokens = 33 + 50 - 80 = 3 (BUG!)
+    # With fix: spec_token_ids cleared, so no spec tokens scheduled
+    scheduler = create_scheduler(
+        num_speculative_tokens=num_spec_tokens,
+        max_num_batched_tokens=50,
+        enable_chunked_prefill=True,
+    )
+    requests = create_requests(num_requests=1, num_tokens=80)
+    req = requests[0]
+    scheduler.add_request(req)
+
+    # First schedule - prefill chunk (50 of 80 tokens)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 1
+    assert output.num_scheduled_tokens[req.request_id] == 50
+
+    # Update from output (no sampled token since still prefilling)
+    req_to_index = {req.request_id: 0}
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[req.request_id],
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[]],
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(output, model_runner_output)
+
+    # Provide draft tokens while request is still in prefill.
+    # The fix ensures these are ignored for prefill chunks.
+    draft_token_ids = DraftTokenIds([req.request_id], [[1, 2, 3]])
+    scheduler.update_draft_token_ids(draft_token_ids)
+
+    # Second schedule - remaining 30 tokens of prefill
+    output = scheduler.schedule()
+    # KEY ASSERTION: Should schedule exactly the remaining 30 prefill tokens,
+    # NOT 33 (30 + 3 spec). Without the fix, this would be 33.
+    assert output.num_scheduled_tokens[req.request_id] == 30, (
+        f"Expected 30 tokens (remaining prefill only), "
+        f"got {output.num_scheduled_tokens[req.request_id]}. "
+        "Spec tokens should not be scheduled with prefill chunks."
+    )
+    # No spec tokens should be in the output
+    assert req.request_id not in output.scheduled_spec_decode_tokens, (
+        "Spec tokens should not be scheduled with prefill chunks"
+    )
+
+    # Update from output with a sampled token (prefill complete)
+    model_runner_output = ModelRunnerOutput(
+        req_ids=[req.request_id],
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[42]],
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(output, model_runner_output)
+
+    # Now provide draft tokens - should be accepted since prefill is complete
+    draft_token_ids = DraftTokenIds([req.request_id], [[1, 2, 3]])
+    scheduler.update_draft_token_ids(draft_token_ids)
+
+    # spec_token_ids SHOULD be set after prefill is complete
+    assert req.spec_token_ids == [1, 2, 3], (
+        f"spec_token_ids should be set after prefill, got {req.spec_token_ids}"
+    )
+
+    # Third schedule - decode phase with spec tokens
+    output = scheduler.schedule()
+    # 1 new token + 3 spec tokens = 4
+    assert output.num_scheduled_tokens[req.request_id] == 4
+    assert req.request_id in output.scheduled_spec_decode_tokens
+    assert len(output.scheduled_spec_decode_tokens[req.request_id]) == num_spec_tokens
+
+
 def _assert_right_scheduler_output(
     output: SchedulerOutput,
     num_requests: int,
diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py
index 23c610f3b..0b3958dbc 100644
--- a/vllm/v1/core/sched/async_scheduler.py
+++ b/vllm/v1/core/sched/async_scheduler.py
@@ -17,33 +17,22 @@ class AsyncScheduler(Scheduler):
 
     def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
         super()._update_after_schedule(scheduler_output)
-        has_structured_output_requests = False
-        pending_structured_output_tokens = False
         spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
         for req_id in scheduler_output.num_scheduled_tokens:
             request = self.requests[req_id]
-            has_structured_output_requests |= request.use_structured_output
-            pending_structured_output_tokens |= (
+            if request.is_prefill_chunk:
+                continue
+
+            scheduler_output.pending_structured_output_tokens |= (
                 request.use_structured_output and request.num_output_placeholders > 0
             )
+            # The request will generate a new token plus num_spec_tokens
+            # in this scheduling step.
             cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
-            if (
-                request.num_computed_tokens
-                == request.num_tokens
-                + request.num_output_placeholders
-                + cur_num_spec_tokens
-            ):
-                # The request will generate a new token plus num_spec_tokens
-                # in this scheduling step.
-                request.num_output_placeholders += 1 + cur_num_spec_tokens
-                # Add placeholders for the new draft/spec tokens.
-                # We will update the actual spec token ids in the worker process.
-                request.spec_token_ids = self._spec_token_placeholders
-
-        scheduler_output.has_structured_output_requests = has_structured_output_requests
-        scheduler_output.pending_structured_output_tokens = (
-            pending_structured_output_tokens
-        )
+            request.num_output_placeholders += 1 + cur_num_spec_tokens
+            # Add placeholders for the new draft/spec tokens.
+            # We will update the actual spec token ids in the worker process.
+            request.spec_token_ids = self._spec_token_placeholders
 
     def _update_request_with_output(
         self, request: Request, new_token_ids: list[int]
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 869b53601..88d1a78df 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -912,6 +912,12 @@ class Scheduler(SchedulerInterface):
         for req_id, num_scheduled_token in num_scheduled_tokens.items():
             request = self.requests[req_id]
             request.num_computed_tokens += num_scheduled_token
+            request.is_prefill_chunk = request.num_computed_tokens < (
+                request.num_tokens + request.num_output_placeholders
+            )
+            scheduler_output.has_structured_output_requests |= (
+                request.use_structured_output
+            )
 
             # NOTE: _free_encoder_inputs relies on num_computed_tokens, which
             # may be updated again in _update_from_output for speculative
@@ -1562,6 +1568,12 @@ class Scheduler(SchedulerInterface):
                 # The request may have been finished. Skip.
                 continue
 
+            if request.is_prefill_chunk:
+                # Ignore draft tokens for prefill chunks.
+                if request.spec_token_ids:
+                    request.spec_token_ids = []
+                continue
+
             # Add newly generated spec token ids to the request.
             if self.structured_output_manager.should_advance(request):
                 metadata = request.structured_output_request
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 117478a92..e9d3df442 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -147,6 +147,9 @@ class Request:
         # The number of tokens with prefix cache hits.
         self.num_cached_tokens = -1
 
+        # True if this request is scheduled as a non-final prefill chunk.
+        self.is_prefill_chunk = False
+
         # The number of NaNs in logits. A value greater than 0
         # indicates that the output is corrupted
         self.num_nans_in_logits = 0
diff --git a/vllm/v1/worker/gpu/spec_decode/utils.py b/vllm/v1/worker/gpu/spec_decode/utils.py
index ddeb99a71..e1fa21aeb 100644
--- a/vllm/v1/worker/gpu/spec_decode/utils.py
+++ b/vllm/v1/worker/gpu/spec_decode/utils.py
@@ -16,21 +16,21 @@ class DraftTokensHandler:
 
         self.req_ids: list[str] = []
         self.draft_tokens_np: np.ndarray | None = None
+        self.num_draft_tokens: int = 0
 
     def set_draft_tokens(
         self, input_batch: InputBatch, draft_tokens: torch.Tensor
     ) -> None:
+        self.req_ids = input_batch.req_ids
+        self.num_draft_tokens = draft_tokens.shape[1]
         if not input_batch.has_structured_output_reqs:
             # No draft token validation needs to be performed by
             # the scheduler for this batch.
-            if self.req_ids:
-                self.req_ids = []
             self.draft_tokens_np = None
             return
 
         # For spec decoding + structured outputs, we must transfer the
         # draft tokens back to the scheduler for grammar validation.
-        self.req_ids = input_batch.req_ids
         current_stream = torch.cuda.current_stream(self.device)
         self.copy_stream.wait_stream(current_stream)
         with torch.cuda.stream(self.copy_stream):
@@ -38,8 +38,10 @@ class DraftTokensHandler:
         self.copy_event.record()
 
     def get_draft_tokens(self) -> DraftTokenIds | None:
-        if self.draft_tokens_np is None:
-            return None
-
-        self.copy_event.synchronize()
-        return DraftTokenIds(self.req_ids, self.draft_tokens_np.tolist())
+        if self.draft_tokens_np is not None:
+            self.copy_event.synchronize()
+            draft_token_ids = self.draft_tokens_np.tolist()
+        else:
+            # This case only happens when async scheduling is disabled.
+            draft_token_ids = [[-1] * self.num_draft_tokens for _ in self.req_ids]
+        return DraftTokenIds(self.req_ids, draft_token_ids)