[Core] Don't schedule spec tokens with prefill chunks (#33652)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-02-04 15:40:22 -08:00
committed by GitHub
parent ce498a6d61
commit fa4e0fb028
5 changed files with 129 additions and 29 deletions

View File

@@ -945,6 +945,100 @@ def test_spec_decoding_stats_empty_output():
assert scheduler_stats is None or scheduler_stats.spec_decoding_stats is None
def test_no_spec_tokens_scheduled_for_prefill_chunks():
    """Draft tokens must not be scheduled alongside prefill chunks.

    Scenario exercised here:
      1. An 80-token prompt is scheduled under a 50-token budget, so the
         first step only covers a 50-token prefill chunk.
      2. Draft tokens arrive through `update_draft_token_ids` while the
         request is still mid-prefill.
      3. The following step has spare budget (30 remaining prompt tokens
         against a budget of 50), which is exactly the situation where the
         drafts could wrongly be tacked on: 33 + 50 - 80 = 3 spec tokens.

    Expected behavior: the second step schedules the 30 remaining prompt
    tokens only, and speculative tokens first show up in the decode step
    that follows the completed prefill.
    """
    num_spec_tokens = 3
    token_budget = 50
    prompt_len = 80

    scheduler = create_scheduler(
        num_speculative_tokens=num_spec_tokens,
        max_num_batched_tokens=token_budget,
        enable_chunked_prefill=True,
    )
    req = create_requests(num_requests=1, num_tokens=prompt_len)[0]
    scheduler.add_request(req)
    req_id = req.request_id
    req_to_index = {req_id: 0}

    def runner_output(sampled):
        # Minimal model-runner result for the single request under test.
        return ModelRunnerOutput(
            req_ids=[req_id],
            req_id_to_index=req_to_index,
            sampled_token_ids=[sampled],
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[],
        )

    def submit_drafts():
        # Hand the scheduler a fresh batch of three draft tokens.
        scheduler.update_draft_token_ids(DraftTokenIds([req_id], [[1, 2, 3]]))

    # Step 1: first prefill chunk, capped by the token budget (50 of 80).
    first_step = scheduler.schedule()
    assert len(first_step.scheduled_new_reqs) == 1
    assert first_step.num_scheduled_tokens[req_id] == 50

    # Still prefilling, so the runner reports no sampled token.
    scheduler.update_from_output(first_step, runner_output([]))

    # Drafts delivered mid-prefill; they must have no effect on step 2.
    submit_drafts()

    # Step 2: the rest of the prompt. Exactly 30 tokens, never 30 + 3.
    second_step = scheduler.schedule()
    assert second_step.num_scheduled_tokens[req_id] == 30, (
        f"Expected 30 tokens (remaining prefill only), "
        f"got {second_step.num_scheduled_tokens[req_id]}. "
        "Spec tokens should not be scheduled with prefill chunks."
    )
    assert req_id not in second_step.scheduled_spec_decode_tokens, (
        "Spec tokens should not be scheduled with prefill chunks"
    )

    # Prefill is complete: the runner now returns one sampled token.
    scheduler.update_from_output(second_step, runner_output([42]))

    # Drafts delivered after prefill must be stored on the request.
    submit_drafts()
    assert req.spec_token_ids == [1, 2, 3], (
        f"spec_token_ids should be set after prefill, got {req.spec_token_ids}"
    )

    # Step 3: decode with speculation, 1 new token + 3 spec tokens = 4.
    third_step = scheduler.schedule()
    assert third_step.num_scheduled_tokens[req_id] == 4
    assert req_id in third_step.scheduled_spec_decode_tokens
    assert len(third_step.scheduled_spec_decode_tokens[req_id]) == num_spec_tokens
def _assert_right_scheduler_output(
output: SchedulerOutput,
num_requests: int,