diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index e4b77f24b..2304bf7ec 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -2040,91 +2040,103 @@ def test_priority_scheduling_mixed_priority_and_arrival():
     assert scheduled_req_ids == ["3", "2", "1", "0"]
 
 
-# This test had previously been passing due to its use of duplicate
-# request ids which resulted in incorrect behavior.
-# Now that the duplicate req ids had been fixed it fails and
-# investigation is needed into whether the priority scheduling
-# preemption logic is working as designed or not.
-@pytest.mark.skip("needs investigation")
 def test_priority_scheduling_preemption():
-    """Test that priority scheduling preempts
-    lower priority requests when memory is constrained."""
-    # Create scheduler with very limited memory to force preemption
+    """Test that under KV block pressure the scheduler preempts the
+    lowest-priority *running* request, not the highest-priority one.
+
+    A low-priority request starts running first. Then a high-priority
+    request arrives and is admitted to running. When block pressure
+    builds, the scheduler preempts the low-priority running request
+    while keeping the high-priority one.
+
+    Block math
+    ----------
+    block_size = 16, num_blocks = 6 (1 null → 5 usable).
+
+    Phase 1: lo1 (priority 5, 32 tokens) → 2 blocks. 3 free.
+             Decode → lo1 has 33 tokens (needs 3rd block on next schedule).
+    Phase 2: hi1 (priority 0, 32 tokens) arrives.
+             schedule() allocates lo1's 3rd block (3 used) and admits
+             hi1 (2 blocks) → 5 used, 0 free. Both running.
+             Decode → lo1 34 tokens, hi1 33 tokens.
+    Phase 3: schedule() → hi1 needs 3rd block, 0 free → preemption.
+             lo1 (priority 5) is preempted, hi1 (priority 0) survives.
+    """
+    block_size = 16
+    num_blocks = 6  # 1 null block → 5 usable
+    num_tokens = block_size * 2  # 32 tokens = exactly 2 blocks
+
     scheduler = create_scheduler_with_priority(
-        max_num_seqs=3,  # Allow multiple requests
+        max_num_seqs=3,
         max_num_batched_tokens=200,
-        num_blocks=6,  # Very limited blocks to force memory pressure
-        block_size=16,  # Standard block size
+        num_blocks=num_blocks,
+        block_size=block_size,
     )
 
-    # Create initial low-priority requests that will consume most memory
-    low_priority_requests = create_requests_with_priority(
-        num_requests=2,
-        priorities=[5, 5],  # Low priority
-        arrival_times=[1.0, 2.0],
-        num_tokens=30,  # Large enough to consume significant memory,
-        req_ids=["lo1", "lo2"],
-    )
-
-    # Add and schedule low priority requests
-    for request in low_priority_requests:
-        scheduler.add_request(request)
+    # --- Phase 1: low-priority request starts running ---
+    lo1 = create_requests_with_priority(
+        num_requests=1,
+        priorities=[5],
+        arrival_times=[1.0],
+        num_tokens=num_tokens,
+        req_ids=["lo1"],
+    )[0]
+    scheduler.add_request(lo1)
 
     output = scheduler.schedule()
-    assert len(output.scheduled_new_reqs) == 2
+    assert len(output.scheduled_new_reqs) == 1
 
-    # Simulate model execution to move requests to running state
+    # Decode: lo1 now has 33 tokens (crosses 32-token boundary).
     model_output = ModelRunnerOutput(
-        req_ids=[req.request_id for req in low_priority_requests],
-        req_id_to_index={
-            req.request_id: i for i, req in enumerate(low_priority_requests)
-        },
-        sampled_token_ids=[[100] for _ in low_priority_requests],
+        req_ids=["lo1"],
+        req_id_to_index={"lo1": 0},
+        sampled_token_ids=[[100]],
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=[],
     )
     scheduler.update_from_output(output, model_output)
 
-    # Verify both requests are running
-    assert len(scheduler.running) == 2
-
-    # Now add a high-priority request that requires memory allocation
-    # This should trigger preemption due to memory constraints
-    high_priority_request = create_requests_with_priority(
+    # --- Phase 2: high-priority request arrives AFTER lo1 is running ---
+    hi1 = create_requests_with_priority(
         num_requests=1,
-        priorities=[0],  # High priority
-        arrival_times=[3.0],
-        num_tokens=30,  # Large enough to require significant memory
+        priorities=[0],
+        arrival_times=[2.0],
+        num_tokens=num_tokens,
         req_ids=["hi1"],
     )[0]
+    scheduler.add_request(hi1)
 
-    scheduler.add_request(high_priority_request)
+    # schedule(): lo1 gets its 3rd block (3 used), hi1 admitted (5 used,
+    # 0 free). Both are now running.
+    output = scheduler.schedule()
+    assert any(r.req_id == "hi1" for r in output.scheduled_new_reqs)
+    assert len(scheduler.running) == 2
 
-    # Schedule again - this should trigger
-    # preemption when trying to allocate memory
+    # Decode: lo1 → 34 tokens, hi1 → 33 tokens.
+    model_output = ModelRunnerOutput(
+        req_ids=["lo1", "hi1"],
+        req_id_to_index={"lo1": 0, "hi1": 1},
+        sampled_token_ids=[[101], [100]],
+        logprobs=None,
+        prompt_logprobs_dict={},
+        pooler_output=[],
+    )
+    scheduler.update_from_output(output, model_output)
+
+    # --- Phase 3: preemption with mixed-priority running requests ---
+    # hi1 needs a 3rd block but 0 are free. The scheduler picks the
+    # lowest-priority running request to preempt:
+    #   max(running, key=(priority, arrival_time)) → lo1 (5 > 0).
     output = scheduler.schedule()
 
-    # Due to the scheduler's design, if preemption happens
-    # during running request scheduling,
-    # waiting requests won't be scheduled in the same step
-    # Let's check if preemption occurred by looking at the waiting queue
-
-    # If preemption happened, we should see requests in the
-    # waiting queue
-    if len(scheduler.waiting) > 1:  # high priority + preempted request
-        # Preemption occurred - verify the high priority request
-        # gets scheduled next
-        output2 = scheduler.schedule()
-        assert len(output2.scheduled_new_reqs) == 1
-        # High priority request
-        assert output2.scheduled_new_reqs[0].req_id == "hi1"
-    else:
-        # No preemption needed - all requests fit
-        # This is also valid behavior if memory allows
-        assert len(output.scheduled_new_reqs) == 1
-        # High priority request
-        assert output.scheduled_new_reqs[0].req_id == "hi1"
+    lo1_req = scheduler.requests["lo1"]
+    assert lo1_req.status == RequestStatus.PREEMPTED, (
+        "Expected low-priority 'lo1' to be preempted"
+    )
+    assert any(req.request_id == "hi1" for req in scheduler.running), (
+        "High-priority 'hi1' should still be running"
+    )
 
 
 def test_priority_scheduling_no_preemption_when_space_available():
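
For reference, the victim-selection rule described in the Phase 3 comment can be sketched in a few lines of Python. This is a minimal illustration of the ordering logic only, not vLLM's actual scheduler code; the `Req` tuple and `pick_preemption_victim` helper are hypothetical stand-ins for the internal request type.

    # Sketch: priority-based preemption victim selection, assuming
    # lower `priority` values mean higher priority (as in the test)
    # and ties break toward the most recently arrived request.
    from collections import namedtuple

    Req = namedtuple("Req", ["req_id", "priority", "arrival_time"])  # hypothetical

    def pick_preemption_victim(running):
        # max() over (priority, arrival_time) selects the lowest-priority,
        # latest-arriving request: lo1 (priority 5) is chosen over hi1 (0).
        return max(running, key=lambda r: (r.priority, r.arrival_time))

    running = [Req("lo1", 5, 1.0), Req("hi1", 0, 2.0)]
    assert pick_preemption_victim(running).req_id == "lo1"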