[Sched] Enhance the logic to remove stopped requests from queues (#20739)
This commit is contained in:
@@ -451,6 +451,7 @@ def test_stop_via_update_from_output():
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
@@ -504,6 +505,7 @@ def test_stop_via_update_from_output():
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
@@ -556,6 +558,7 @@ def test_stop_via_update_from_output():
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
@@ -703,6 +706,65 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
|
||||
scheduler.update_from_output(scheduler_output1, model_runner_output)
|
||||
|
||||
|
||||
def test_preempt_during_execution():
|
||||
# NOTE(woosuk): The actual number of available blocks is 10 instead of 11
|
||||
# because block 0 is reserved as the null block.
|
||||
scheduler = create_scheduler(max_num_batched_tokens=100,
|
||||
block_size=16,
|
||||
num_blocks=11,
|
||||
enable_prefix_caching=False)
|
||||
requests = create_requests(num_requests=2, num_tokens=80)
|
||||
|
||||
# Schedule the first request.
|
||||
scheduler.add_request(requests[0])
|
||||
scheduler_output0 = scheduler.schedule()
|
||||
assert len(scheduler_output0.num_scheduled_tokens) == 1
|
||||
assert len(scheduler_output0.scheduled_new_reqs[0].block_ids[0]) == 5
|
||||
|
||||
# Schedule the second request while the first request is still running.
|
||||
# This scenario can occur in certain cases, when max_concurrent_batches > 1
|
||||
# (e.g., when pipeline parallelism is used).
|
||||
scheduler.add_request(requests[1])
|
||||
scheduler_output1 = scheduler.schedule()
|
||||
assert len(scheduler_output1.num_scheduled_tokens) == 1
|
||||
assert len(scheduler_output1.scheduled_new_reqs[0].block_ids[0]) == 5
|
||||
|
||||
# Get the output of the first request.
|
||||
model_runner_output0 = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output0, model_runner_output0)
|
||||
|
||||
# Schedule the first request again. This will cause the preemption
|
||||
# of the second request because the KV cache is full.
|
||||
_ = scheduler.schedule()
|
||||
assert len(scheduler.running) == 1
|
||||
assert scheduler.running[0] == requests[0]
|
||||
assert requests[1].status == RequestStatus.PREEMPTED
|
||||
|
||||
model_runner_output1 = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[[42]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output1, model_runner_output1)
|
||||
|
||||
# The second request (that is preempted) should be updated with the
|
||||
# sampled token id.
|
||||
assert len(requests[1].output_token_ids) == 1
|
||||
assert requests[1].output_token_ids[0] == 42
|
||||
|
||||
|
||||
# Note - these test cases mirror some of those in test_rejection_sampler.py
|
||||
@pytest.mark.parametrize(
|
||||
"spec_tokens,output_tokens,expected",
|
||||
|
||||
Reference in New Issue
Block a user