[AsyncScheduling] Don't schedule past request max_tokens (#27922)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-11-04 09:06:28 -08:00
committed by GitHub
parent c9f66da8fd
commit 938a81692e
3 changed files with 14 additions and 4 deletions

View File

@@ -34,15 +34,20 @@ def test_stop_by_max_tokens(max_tokens: int):
requests = create_requests(num_requests=2, max_tokens=max_tokens)
req0, req1 = requests
expected_total_num_scheduled_tokens = 0
sched_outputs: deque[SchedulerOutput] = deque()
scheduler.add_request(req0)
sched_outputs.append(scheduler.schedule())
expected_total_num_scheduled_tokens += req0.num_prompt_tokens + max_tokens - 1
scheduler.add_request(req1)
sched_outputs.append(scheduler.schedule())
expected_total_num_scheduled_tokens += req1.num_prompt_tokens + max_tokens - 1
total_num_scheduled_tokens = 0
while sched_outputs:
sched_output = sched_outputs.popleft()
total_num_scheduled_tokens += sched_output.total_num_scheduled_tokens
model_runner_output = _make_model_runner_output(sched_output)
scheduler.update_from_output(sched_output, model_runner_output)
@@ -53,6 +58,8 @@ def test_stop_by_max_tokens(max_tokens: int):
assert scheduler.get_num_unfinished_requests() == 0
assert req0.num_output_tokens == max_tokens
assert req1.num_output_tokens == max_tokens
# Ensure we aren't scheduling more tokens than necessary.
assert total_num_scheduled_tokens == expected_total_num_scheduled_tokens
def test_abort():