[Core] Don't schedule spec tokens with prefill chunks (#33652)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -945,6 +945,100 @@ def test_spec_decoding_stats_empty_output():
|
||||
assert scheduler_stats is None or scheduler_stats.spec_decoding_stats is None
|
||||
|
||||
|
||||
def test_no_spec_tokens_scheduled_for_prefill_chunks():
    """Check that draft tokens offered mid-prefill are never scheduled.

    Scenario: an 80-token prompt runs with chunked prefill under a 50-token
    batch budget, so prefill takes two scheduling steps (50, then 30).
    Draft tokens are passed to `update_draft_token_ids` right after the first
    chunk, while the request is still prefilling. The second step has spare
    budget (50 - 30 = 20), which is enough room for the 3 draft tokens to be
    scheduled by mistake next to the remaining prompt tokens.

    Expected: the second step schedules exactly the 30 remaining prompt
    tokens and no spec tokens. After prefill completes and a token has been
    sampled, draft tokens are accepted and scheduled in the next step.
    """
    num_spec_tokens = 3
    # Budget 50, prompt length 80:
    #   step 1 -> 50 prompt tokens
    #   step 2 -> 30 remaining prompt tokens (33 if the 3 spec tokens leak in)
    # Buggy accounting: num_scheduled_spec_tokens = 33 + 50 - 80 = 3
    # Fixed behaviour: spec_token_ids are cleared, so nothing extra is scheduled
    scheduler = create_scheduler(
        num_speculative_tokens=num_spec_tokens,
        max_num_batched_tokens=50,
        enable_chunked_prefill=True,
    )
    req = create_requests(num_requests=1, num_tokens=80)[0]
    req_id = req.request_id
    scheduler.add_request(req)

    # The same index mapping is shared by every runner output in this test.
    req_to_index = {req_id: 0}

    def make_runner_output(sampled_token_ids):
        # Build a minimal single-request model runner result.
        return ModelRunnerOutput(
            req_ids=[req_id],
            req_id_to_index=req_to_index,
            sampled_token_ids=sampled_token_ids,
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[],
        )

    # --- Step 1: first prefill chunk (50 of 80 tokens) ---
    first_step = scheduler.schedule()
    assert len(first_step.scheduled_new_reqs) == 1
    assert first_step.num_scheduled_tokens[req_id] == 50

    # Still prefilling, so the runner reports no sampled token.
    scheduler.update_from_output(first_step, make_runner_output([[]]))

    # Offer draft tokens while the request is mid-prefill; they must be
    # ignored rather than stored on the request.
    scheduler.update_draft_token_ids(DraftTokenIds([req_id], [[1, 2, 3]]))

    # --- Step 2: the remaining 30 prompt tokens ---
    second_step = scheduler.schedule()
    # KEY ASSERTION: exactly 30 tokens, not 33 (30 prompt + 3 spec).
    assert second_step.num_scheduled_tokens[req_id] == 30, (
        f"Expected 30 tokens (remaining prefill only), "
        f"got {second_step.num_scheduled_tokens[req_id]}. "
        "Spec tokens should not be scheduled with prefill chunks."
    )
    # The request must not appear among the scheduled spec-decode entries.
    assert req_id not in second_step.scheduled_spec_decode_tokens, (
        "Spec tokens should not be scheduled with prefill chunks"
    )

    # Prefill is done: this time the runner reports one sampled token.
    scheduler.update_from_output(second_step, make_runner_output([[42]]))

    # Draft tokens offered now must be kept on the request.
    scheduler.update_draft_token_ids(DraftTokenIds([req_id], [[1, 2, 3]]))
    assert req.spec_token_ids == [1, 2, 3], (
        f"spec_token_ids should be set after prefill, got {req.spec_token_ids}"
    )

    # --- Step 3: decode with speculation: 1 new token + 3 spec tokens ---
    third_step = scheduler.schedule()
    assert third_step.num_scheduled_tokens[req_id] == 4
    assert req_id in third_step.scheduled_spec_decode_tokens
    assert len(third_step.scheduled_spec_decode_tokens[req_id]) == num_spec_tokens
|
||||
|
||||
|
||||
def _assert_right_scheduler_output(
|
||||
output: SchedulerOutput,
|
||||
num_requests: int,
|
||||
|
||||
@@ -17,33 +17,22 @@ class AsyncScheduler(Scheduler):
|
||||
|
||||
def _update_after_schedule(self, scheduler_output: SchedulerOutput) -> None:
|
||||
super()._update_after_schedule(scheduler_output)
|
||||
has_structured_output_requests = False
|
||||
pending_structured_output_tokens = False
|
||||
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
|
||||
for req_id in scheduler_output.num_scheduled_tokens:
|
||||
request = self.requests[req_id]
|
||||
has_structured_output_requests |= request.use_structured_output
|
||||
pending_structured_output_tokens |= (
|
||||
if request.is_prefill_chunk:
|
||||
continue
|
||||
|
||||
scheduler_output.pending_structured_output_tokens |= (
|
||||
request.use_structured_output and request.num_output_placeholders > 0
|
||||
)
|
||||
# The request will generate a new token plus num_spec_tokens
|
||||
# in this scheduling step.
|
||||
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
|
||||
if (
|
||||
request.num_computed_tokens
|
||||
== request.num_tokens
|
||||
+ request.num_output_placeholders
|
||||
+ cur_num_spec_tokens
|
||||
):
|
||||
# The request will generate a new token plus num_spec_tokens
|
||||
# in this scheduling step.
|
||||
request.num_output_placeholders += 1 + cur_num_spec_tokens
|
||||
# Add placeholders for the new draft/spec tokens.
|
||||
# We will update the actual spec token ids in the worker process.
|
||||
request.spec_token_ids = self._spec_token_placeholders
|
||||
|
||||
scheduler_output.has_structured_output_requests = has_structured_output_requests
|
||||
scheduler_output.pending_structured_output_tokens = (
|
||||
pending_structured_output_tokens
|
||||
)
|
||||
request.num_output_placeholders += 1 + cur_num_spec_tokens
|
||||
# Add placeholders for the new draft/spec tokens.
|
||||
# We will update the actual spec token ids in the worker process.
|
||||
request.spec_token_ids = self._spec_token_placeholders
|
||||
|
||||
def _update_request_with_output(
|
||||
self, request: Request, new_token_ids: list[int]
|
||||
|
||||
@@ -912,6 +912,12 @@ class Scheduler(SchedulerInterface):
|
||||
for req_id, num_scheduled_token in num_scheduled_tokens.items():
|
||||
request = self.requests[req_id]
|
||||
request.num_computed_tokens += num_scheduled_token
|
||||
request.is_prefill_chunk = request.num_computed_tokens < (
|
||||
request.num_tokens + request.num_output_placeholders
|
||||
)
|
||||
scheduler_output.has_structured_output_requests |= (
|
||||
request.use_structured_output
|
||||
)
|
||||
|
||||
# NOTE: _free_encoder_inputs relies on num_computed_tokens, which
|
||||
# may be updated again in _update_from_output for speculative
|
||||
@@ -1562,6 +1568,12 @@ class Scheduler(SchedulerInterface):
|
||||
# The request may have been finished. Skip.
|
||||
continue
|
||||
|
||||
if request.is_prefill_chunk:
|
||||
# Ignore draft tokens for prefill chunks.
|
||||
if request.spec_token_ids:
|
||||
request.spec_token_ids = []
|
||||
continue
|
||||
|
||||
# Add newly generated spec token ids to the request.
|
||||
if self.structured_output_manager.should_advance(request):
|
||||
metadata = request.structured_output_request
|
||||
|
||||
@@ -147,6 +147,9 @@ class Request:
|
||||
# The number of tokens with prefix cache hits.
|
||||
self.num_cached_tokens = -1
|
||||
|
||||
# True if this request is scheduled as a non-final prefill chunk.
|
||||
self.is_prefill_chunk = False
|
||||
|
||||
# The number of NaNs in logits. A value greater than 0
|
||||
# indicates that the output is corrupted
|
||||
self.num_nans_in_logits = 0
|
||||
|
||||
@@ -16,21 +16,21 @@ class DraftTokensHandler:
|
||||
|
||||
self.req_ids: list[str] = []
|
||||
self.draft_tokens_np: np.ndarray | None = None
|
||||
self.num_draft_tokens: int = 0
|
||||
|
||||
def set_draft_tokens(
|
||||
self, input_batch: InputBatch, draft_tokens: torch.Tensor
|
||||
) -> None:
|
||||
self.req_ids = input_batch.req_ids
|
||||
self.num_draft_tokens = draft_tokens.shape[1]
|
||||
if not input_batch.has_structured_output_reqs:
|
||||
# No draft token validation needs to be performed by
|
||||
# the scheduler for this batch.
|
||||
if self.req_ids:
|
||||
self.req_ids = []
|
||||
self.draft_tokens_np = None
|
||||
return
|
||||
|
||||
# For spec decoding + structured outputs, we must transfer the
|
||||
# draft tokens back to the scheduler for grammar validation.
|
||||
self.req_ids = input_batch.req_ids
|
||||
current_stream = torch.cuda.current_stream(self.device)
|
||||
self.copy_stream.wait_stream(current_stream)
|
||||
with torch.cuda.stream(self.copy_stream):
|
||||
@@ -38,8 +38,10 @@ class DraftTokensHandler:
|
||||
self.copy_event.record()
|
||||
|
||||
def get_draft_tokens(self) -> DraftTokenIds | None:
|
||||
if self.draft_tokens_np is None:
|
||||
return None
|
||||
|
||||
self.copy_event.synchronize()
|
||||
return DraftTokenIds(self.req_ids, self.draft_tokens_np.tolist())
|
||||
if self.draft_tokens_np is not None:
|
||||
self.copy_event.synchronize()
|
||||
draft_token_ids = self.draft_tokens_np.tolist()
|
||||
else:
|
||||
# This case only happens when async scheduling is disabled.
|
||||
draft_token_ids = [[-1] * self.num_draft_tokens for _ in self.req_ids]
|
||||
return DraftTokenIds(self.req_ids, draft_token_ids)
|
||||
|
||||
Reference in New Issue
Block a user