Fix priority preemption regression test in scheduler (#37051)
Signed-off-by: HarshRathva <harshrathvaai@gmail.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -2040,91 +2040,103 @@ def test_priority_scheduling_mixed_priority_and_arrival():
|
||||
assert scheduled_req_ids == ["3", "2", "1", "0"]
|
||||
|
||||
|
||||
def test_priority_scheduling_preemption():
    """Test that under KV block pressure the scheduler preempts the
    lowest-priority *running* request, not the highest-priority one.

    A low-priority request starts running first. Then a high-priority
    request arrives and is admitted to running. When block pressure
    builds, the scheduler preempts the low-priority running request
    while keeping the high-priority one.

    Block math
    ----------
    block_size = 16, num_blocks = 6 (1 null → 5 usable).

    Phase 1: lo1 (priority 5, 32 tokens) → 2 blocks. 3 free.
        Decode → lo1 has 33 tokens (needs 3rd block on next schedule).
    Phase 2: hi1 (priority 0, 32 tokens) arrives.
        schedule() allocates lo1's 3rd block (3 used) and admits
        hi1 (2 blocks) → 5 used, 0 free. Both running.
        Decode → lo1 34 tokens, hi1 33 tokens.
    Phase 3: schedule() → hi1 needs 3rd block, 0 free → preemption.
        lo1 (priority 5) is preempted, hi1 (priority 0) survives.
    """
    block_size = 16
    num_blocks = 6  # 1 null block → 5 usable
    num_tokens = block_size * 2  # 32 tokens = exactly 2 blocks

    scheduler = create_scheduler_with_priority(
        max_num_seqs=3,
        max_num_batched_tokens=200,
        num_blocks=num_blocks,
        block_size=block_size,
    )

    # --- Phase 1: low-priority request starts running ---
    lo1 = create_requests_with_priority(
        num_requests=1,
        priorities=[5],
        arrival_times=[1.0],
        num_tokens=num_tokens,
        req_ids=["lo1"],
    )[0]
    scheduler.add_request(lo1)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 1

    # Decode: lo1 now has 33 tokens (crosses 32-token boundary).
    model_output = ModelRunnerOutput(
        req_ids=["lo1"],
        req_id_to_index={"lo1": 0},
        sampled_token_ids=[[100]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # --- Phase 2: high-priority request arrives AFTER lo1 is running ---
    hi1 = create_requests_with_priority(
        num_requests=1,
        priorities=[0],
        arrival_times=[2.0],
        num_tokens=num_tokens,
        req_ids=["hi1"],
    )[0]
    scheduler.add_request(hi1)

    # schedule(): lo1 gets its 3rd block (3 used), hi1 admitted (5 used,
    # 0 free). Both are now running.
    output = scheduler.schedule()
    assert any(r.req_id == "hi1" for r in output.scheduled_new_reqs)
    assert len(scheduler.running) == 2

    # Decode: lo1 → 34 tokens, hi1 → 33 tokens.
    model_output = ModelRunnerOutput(
        req_ids=["lo1", "hi1"],
        req_id_to_index={"lo1": 0, "hi1": 1},
        sampled_token_ids=[[101], [100]],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    scheduler.update_from_output(output, model_output)

    # --- Phase 3: preemption with mixed-priority running requests ---
    # hi1 needs a 3rd block but 0 are free. The scheduler picks the
    # lowest-priority running request to preempt:
    #   max(running, key=(priority, arrival_time)) → lo1 (5 > 0).
    output = scheduler.schedule()

    lo1_req = scheduler.requests["lo1"]
    assert lo1_req.status == RequestStatus.PREEMPTED, (
        "Expected low-priority 'lo1' to be preempted"
    )
    assert any(req.request_id == "hi1" for req in scheduler.running), (
        "High-priority 'hi1' should still be running"
    )
|
||||
|
||||
|
||||
def test_priority_scheduling_no_preemption_when_space_available():
|
||||
|
||||
Reference in New Issue
Block a user