[BugFix][Spec Decoding] Fix negative accepted tokens metric crash (#33729)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -870,6 +870,66 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
||||
assert stats.num_accepted_tokens_per_pos == expected[3]
|
||||
|
||||
|
||||
def test_spec_decoding_stats_empty_output():
    """Check that an empty sampled-token list does not break spec decoding stats.

    Regression test: with drafts scheduled and an empty ``sampled_token_ids``
    entry, the accepted-token count was computed as ``len([]) - 1 == -1``,
    and incrementing a Prometheus counter by that negative value raised a
    ValueError.
    """
    num_spec_tokens = 3
    sched = create_scheduler(num_speculative_tokens=num_spec_tokens)
    req = create_requests(num_requests=1, num_tokens=1)[0]
    rid = req.request_id
    sched.add_request(req)

    # Step 1: the first schedule() call picks the request up as new (prefill).
    prefill_step = sched.schedule()
    assert len(prefill_step.scheduled_new_reqs) == 1

    # Step 2: finish the prefill by reporting a single sampled token.
    prefill_result = ModelRunnerOutput(
        req_ids=[rid],
        req_id_to_index={rid: 0},
        sampled_token_ids=[[0]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    sched.update_from_output(prefill_step, prefill_result)

    # Step 3: hand the scheduler three draft tokens to speculate with.
    sched.update_draft_token_ids(DraftTokenIds([rid], [[1, 2, 3]]))

    # Step 4: the next schedule() call must carry those drafts for validation.
    verify_step = sched.schedule()
    assert rid in verify_step.scheduled_spec_decode_tokens
    assert len(verify_step.scheduled_spec_decode_tokens[rid]) == 3

    # Step 5: report *no* sampled tokens for the request (e.g., due to request
    # abortion or error). Before the fix this produced num_accepted = -1 and
    # crashed; now the update must complete without raising.
    empty_result = ModelRunnerOutput(
        req_ids=[rid],
        req_id_to_index={rid: 0},
        sampled_token_ids=[[]],  # Empty output tokens
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    engine_core_outputs = sched.update_from_output(verify_step, empty_result)

    # No tokens were generated, so no spec decoding stats should be recorded.
    if engine_core_outputs:
        stats = engine_core_outputs[0].scheduler_stats
    else:
        stats = None
    assert stats is None or stats.spec_decoding_stats is None
|
||||
|
||||
|
||||
def _assert_right_scheduler_output(
|
||||
output: SchedulerOutput,
|
||||
num_requests: int,
|
||||
|
||||
@@ -1284,7 +1284,7 @@ class Scheduler(SchedulerInterface):
|
||||
scheduled_spec_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id)
|
||||
)
|
||||
if scheduled_spec_token_ids:
|
||||
if scheduled_spec_token_ids and generated_token_ids:
|
||||
num_draft_tokens = len(scheduled_spec_token_ids)
|
||||
num_accepted = len(generated_token_ids) - 1
|
||||
num_rejected = num_draft_tokens - num_accepted
|
||||
|
||||
Reference in New Issue
Block a user