[Core] Don't schedule spec tokens with prefill chunks (#33652)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-02-04 15:40:22 -08:00
committed by GitHub
parent ce498a6d61
commit fa4e0fb028
5 changed files with 129 additions and 29 deletions

View File

@@ -945,6 +945,100 @@ def test_spec_decoding_stats_empty_output():
assert scheduler_stats is None or scheduler_stats.spec_decoding_stats is None
def test_no_spec_tokens_scheduled_for_prefill_chunks():
    """Verify that draft tokens are dropped while a request is mid-prefill.

    With chunked prefill a long prompt is processed over several scheduling
    steps. Draft tokens handed to ``update_draft_token_ids`` between two of
    those steps must not be scheduled alongside the remaining prompt tokens.

    Scenario (token budget 50, prompt of 80 tokens, 3 speculative tokens):

    1. The first step schedules a 50-token prefill chunk.
    2. Draft tokens arrive while 30 prompt tokens are still pending.
    3. The second step has room for 30 + 3 tokens. A scheduler that kept the
       drafts would compute 33 + 50 - 80 = 3 spec tokens and schedule 33
       tokens; the expected result is exactly 30 tokens and no spec tokens.
    4. Once prefill has finished and a token was sampled, fresh draft tokens
       are accepted and the decode step schedules 1 + 3 tokens.
    """
    num_spec_tokens = 3
    scheduler = create_scheduler(
        num_speculative_tokens=num_spec_tokens,
        max_num_batched_tokens=50,
        enable_chunked_prefill=True,
    )
    (req,) = create_requests(num_requests=1, num_tokens=80)
    req_id = req.request_id
    scheduler.add_request(req)

    # Shared by both runner outputs below, exactly one request at index 0.
    req_to_index = {req_id: 0}

    def runner_output(sampled_token_ids):
        # Build the model-runner result for our single request; only the
        # sampled tokens differ between the prefill and the decode step.
        return ModelRunnerOutput(
            req_ids=[req_id],
            req_id_to_index=req_to_index,
            sampled_token_ids=sampled_token_ids,
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[],
        )

    # Step 1: first prefill chunk (50 of the 80 prompt tokens).
    first_step = scheduler.schedule()
    assert len(first_step.scheduled_new_reqs) == 1
    assert first_step.num_scheduled_tokens[req_id] == 50

    # Nothing is sampled yet because the prompt is not fully processed.
    scheduler.update_from_output(first_step, runner_output([[]]))

    # Drafts delivered in the middle of prefill have to be ignored.
    scheduler.update_draft_token_ids(DraftTokenIds([req_id], [[1, 2, 3]]))

    # Step 2: the remaining 30 prompt tokens, and nothing else.
    second_step = scheduler.schedule()
    scheduled_now = second_step.num_scheduled_tokens[req_id]
    assert scheduled_now == 30, (
        f"Expected 30 tokens (remaining prefill only), "
        f"got {scheduled_now}. "
        "Spec tokens should not be scheduled with prefill chunks."
    )
    assert req_id not in second_step.scheduled_spec_decode_tokens, (
        "Spec tokens should not be scheduled with prefill chunks"
    )

    # Prefill is complete now, so the runner reports one sampled token.
    scheduler.update_from_output(second_step, runner_output([[42]]))

    # Drafts delivered after prefill must be stored on the request.
    scheduler.update_draft_token_ids(DraftTokenIds([req_id], [[1, 2, 3]]))
    assert req.spec_token_ids == [1, 2, 3], (
        f"spec_token_ids should be set after prefill, got {req.spec_token_ids}"
    )

    # Step 3: decode with speculation, 1 new token + 3 spec tokens.
    third_step = scheduler.schedule()
    assert third_step.num_scheduled_tokens[req_id] == 4
    assert req_id in third_step.scheduled_spec_decode_tokens
    assert len(third_step.scheduled_spec_decode_tokens[req_id]) == num_spec_tokens
def _assert_right_scheduler_output(
output: SchedulerOutput,
num_requests: int,

View File

@@ -17,33 +17,22 @@ class AsyncScheduler(Scheduler):
def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
    """Reserve output placeholders for requests that sample in this step.

    The base scheduler's bookkeeping runs first. Afterwards every scheduled
    request that is not a non-final prefill chunk gets placeholders for the
    token it is about to sample plus any speculative tokens scheduled with
    it, and its ``spec_token_ids`` are replaced by placeholder ids.

    Args:
        scheduler_output: Result of the scheduling step that was just built.
            Its ``pending_structured_output_tokens`` flag is OR-ed in place.
    """
    super()._update_after_schedule(scheduler_output)
    spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
    for req_id in scheduler_output.num_scheduled_tokens:
        request = self.requests[req_id]
        if request.is_prefill_chunk:
            # A non-final prefill chunk samples nothing, so it needs neither
            # output placeholders nor spec-token placeholders.
            continue

        # Evaluated before the placeholders are bumped below, so this only
        # reports tokens that were already outstanding from earlier steps.
        scheduler_output.pending_structured_output_tokens |= (
            request.use_structured_output and request.num_output_placeholders > 0
        )
        # The request will generate a new token plus num_spec_tokens
        # in this scheduling step.
        cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
        request.num_output_placeholders += 1 + cur_num_spec_tokens
        # Add placeholders for the new draft/spec tokens.
        # We will update the actual spec token ids in the worker process.
        request.spec_token_ids = self._spec_token_placeholders
def _update_request_with_output(
self, request: Request, new_token_ids: list[int]

View File

@@ -912,6 +912,12 @@ class Scheduler(SchedulerInterface):
for req_id, num_scheduled_token in num_scheduled_tokens.items():
request = self.requests[req_id]
request.num_computed_tokens += num_scheduled_token
request.is_prefill_chunk = request.num_computed_tokens < (
request.num_tokens + request.num_output_placeholders
)
scheduler_output.has_structured_output_requests |= (
request.use_structured_output
)
# NOTE: _free_encoder_inputs relies on num_computed_tokens, which
# may be updated again in _update_from_output for speculative
@@ -1562,6 +1568,12 @@ class Scheduler(SchedulerInterface):
# The request may have been finished. Skip.
continue
if request.is_prefill_chunk:
# Ignore draft tokens for prefill chunks.
if request.spec_token_ids:
request.spec_token_ids = []
continue
# Add newly generated spec token ids to the request.
if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request

View File

@@ -147,6 +147,9 @@ class Request:
# The number of tokens with prefix cache hits.
self.num_cached_tokens = -1
# True if this request is scheduled as a non-final prefill chunk.
self.is_prefill_chunk = False
# The number of NaNs in logits. A value greater than 0
# indicates that the output is corrupted
self.num_nans_in_logits = 0

View File

@@ -16,21 +16,21 @@ class DraftTokensHandler:
self.req_ids: list[str] = []
self.draft_tokens_np: np.ndarray | None = None
self.num_draft_tokens: int = 0
def set_draft_tokens(
self, input_batch: InputBatch, draft_tokens: torch.Tensor
) -> None:
self.req_ids = input_batch.req_ids
self.num_draft_tokens = draft_tokens.shape[1]
if not input_batch.has_structured_output_reqs:
# No draft token validation needs to be performed by
# the scheduler for this batch.
if self.req_ids:
self.req_ids = []
self.draft_tokens_np = None
return
# For spec decoding + structured outputs, we must transfer the
# draft tokens back to the scheduler for grammar validation.
self.req_ids = input_batch.req_ids
current_stream = torch.cuda.current_stream(self.device)
self.copy_stream.wait_stream(current_stream)
with torch.cuda.stream(self.copy_stream):
@@ -38,8 +38,10 @@ class DraftTokensHandler:
self.copy_event.record()
def get_draft_tokens(self) -> DraftTokenIds | None:
    """Return the draft tokens for the requests recorded in ``self.req_ids``.

    Returns:
        A ``DraftTokenIds`` pairing ``self.req_ids`` with one token list per
        request. When a host copy of the drafts exists the real token ids
        are returned; otherwise each request gets ``self.num_draft_tokens``
        placeholder ids of ``-1``.
    """
    if self.draft_tokens_np is not None:
        # Wait on the event recorded for the host copy before reading the
        # array contents.
        self.copy_event.synchronize()
        draft_token_ids = self.draft_tokens_np.tolist()
    else:
        # This case only happens when async scheduling is disabled.
        draft_token_ids = [[-1] * self.num_draft_tokens for _ in self.req_ids]
    return DraftTokenIds(self.req_ids, draft_token_ids)