[Perf] Optimize scheduler overhead for PD disaggregation, around 5% E2E perf improvement (#35781)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Wentao Ye
2026-03-11 00:25:30 -04:00
committed by GitHub
parent 42fadebecb
commit a8ff2cca92
6 changed files with 243 additions and 105 deletions

View File

@@ -1115,12 +1115,16 @@ def _step_until_done(
all_finished = all_done all_finished = all_done
def _num_waiting_requests(scheduler: Scheduler) -> int:
return len(scheduler.waiting) + len(scheduler.skipped_waiting)
def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]): def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
"""Cycle requests through a KV transfer cycle.""" """Cycle requests through a KV transfer cycle."""
# Requests should first transition to WAITING_FOR_REMOTE_KVS # Requests should first transition to WAITING_FOR_REMOTE_KVS
output = scheduler.schedule() output = scheduler.schedule()
assert len(scheduler.waiting) == len(req_ids) assert _num_waiting_requests(scheduler) == len(req_ids)
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(output.scheduled_new_reqs) == 0 assert len(output.scheduled_new_reqs) == 0
for req in scheduler.requests.values(): for req in scheduler.requests.values():
@@ -1139,7 +1143,7 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
# Simulate KV transfer completion using KVConnectorOutput.finished_recving # Simulate KV transfer completion using KVConnectorOutput.finished_recving
output = scheduler.schedule() output = scheduler.schedule()
assert len(scheduler.waiting) == len(req_ids) assert _num_waiting_requests(scheduler) == len(req_ids)
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
MODEL_RUNNER_OUTPUT = ModelRunnerOutput( MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
@@ -1546,7 +1550,7 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
# All can be scheduled - 1st token. # All can be scheduled - 1st token.
output = scheduler.schedule() output = scheduler.schedule()
if is_async: if is_async:
assert len(scheduler.waiting) == 2 assert _num_waiting_requests(scheduler) == 2
assert scheduler.running == [] assert scheduler.running == []
_step_until_kv_transfer_finished(scheduler, req_ids) _step_until_kv_transfer_finished(scheduler, req_ids)
output = scheduler.schedule() output = scheduler.schedule()
@@ -1604,7 +1608,11 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
# This will have a local and remote cache hit. # This will have a local and remote cache hit.
output = scheduler.schedule() output = scheduler.schedule()
if is_async: if is_async:
waiting_req_ids = [req.request_id for req in scheduler.waiting] waiting_req_ids = [
req.request_id
for req in scheduler.skipped_waiting
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
]
assert len(waiting_req_ids) == 1 assert len(waiting_req_ids) == 1
_step_until_kv_transfer_finished(scheduler, waiting_req_ids) _step_until_kv_transfer_finished(scheduler, waiting_req_ids)
output = scheduler.schedule() output = scheduler.schedule()
@@ -2439,7 +2447,8 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0 assert len(output.scheduled_new_reqs) == 0
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 0
assert len(scheduler.skipped_waiting) == 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
@@ -3626,6 +3635,9 @@ def test_prepend_skipped_requests_order():
# simulate first 2 waiting requests are waiting for remote KVs # simulate first 2 waiting requests are waiting for remote KVs
for req in expected_waiting_reqs[:2]: for req in expected_waiting_reqs[:2]:
req.status = RequestStatus.WAITING_FOR_REMOTE_KVS req.status = RequestStatus.WAITING_FOR_REMOTE_KVS
scheduler.waiting.remove_requests(expected_waiting_reqs[:2])
for req in expected_waiting_reqs[:2]:
scheduler.skipped_waiting.add_request(req)
# schedule step # schedule step
# expect the first 2 waiting to be skipped, the third running, # expect the first 2 waiting to be skipped, the third running,
@@ -3636,7 +3648,87 @@ def test_prepend_skipped_requests_order():
expected_waiting_reqs.pop(2) expected_waiting_reqs.pop(2)
# verify waiting order is preserved # verify waiting order is preserved
assert list(scheduler.waiting) == expected_waiting_reqs waiting_reqs = list(scheduler.skipped_waiting) + list(scheduler.waiting)
assert waiting_reqs == expected_waiting_reqs
def test_remote_kv_promotion_keeps_fcfs_with_fsm_prefix():
    """Check FCFS order when a remote-KV request sits behind FSM requests.

    Two FSM requests and one remote-KV request are moved into
    ``skipped_waiting`` (in that order) and marked as ready. After one
    ``schedule()`` call the first FSM request must be the first scheduled
    new request, and everything else must still be queued in its original
    arrival order.
    """
    scheduler = create_scheduler(max_num_seqs=1)
    connector = Mock()
    connector.get_num_new_matched_tokens.return_value = (0, False)
    scheduler.connector = connector

    for new_req in create_requests(num_requests=4):
        scheduler.add_request(new_req)
    req_fsm_1, req_fsm_2, req_remote, req_tail = scheduler.waiting

    # Simulate two FSM requests at the waiting head that become ready now.
    fsm_reqs = [req_fsm_1, req_fsm_2]
    for fsm_req in fsm_reqs:
        fsm_req.status = RequestStatus.WAITING_FOR_FSM
        fsm_req.structured_output_request = Mock(grammar=object())

    # Simulate a remote-KV request that is ready to be promoted now.
    req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS

    # Move the three blocked requests from the waiting queue to the skipped
    # queue, keeping their relative order.
    blocked_reqs = [*fsm_reqs, req_remote]
    scheduler.waiting.remove_requests(blocked_reqs)
    for blocked_req in blocked_reqs:
        scheduler.skipped_waiting.add_request(blocked_req)
    scheduler.finished_recving_kv_req_ids.add(req_remote.request_id)
    scheduler._update_waiting_for_remote_kv = Mock()

    output = scheduler.schedule()

    new_reqs = output.scheduled_new_reqs
    assert new_reqs
    assert new_reqs[0].req_id == req_fsm_1.request_id

    # Skipped requests are listed ahead of the plain waiting ones.
    still_queued = [*scheduler.skipped_waiting, *scheduler.waiting]
    assert [req.request_id for req in still_queued] == [
        req.request_id for req in (req_fsm_2, req_remote, req_tail)
    ]
def test_fcfs_mixed_skipped_waiting_types_keep_order():
    """Check that FCFS order survives a mix of blocked request types.

    Three requests are added in FSM / remote-KV / streaming blocked states,
    followed by a 20-token regular request and a tail request. Once the
    blocked requests are unblocked, a second ``schedule()`` call must
    schedule them, followed by the tail request, in arrival order.
    """
    scheduler = create_scheduler(max_num_batched_tokens=20)
    # Stubbed so the call made for the remote-KV request can be asserted.
    scheduler._update_waiting_for_remote_kv = Mock()

    def mk_req(req_id, num_tokens=1):
        """Create one request with the given ID via ``create_requests``."""
        return create_requests(
            num_requests=1, num_tokens=num_tokens, req_ids=[req_id]
        )[0]

    req_fsm = mk_req("fsm")
    req_remote = mk_req("remote")
    req_stream = mk_req("stream")
    req_regular = mk_req("regular", 20)
    req_tail = mk_req("tail")

    # Put the first three requests into blocked states before enqueueing.
    req_fsm.status = RequestStatus.WAITING_FOR_FSM
    req_fsm.structured_output_request = Mock(grammar=None)
    req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS
    req_stream.status = RequestStatus.WAITING_FOR_STREAMING_REQ

    for req in (req_fsm, req_remote, req_stream, req_regular, req_tail):
        scheduler.add_request(req)

    # First pass: the blocked requests must end up in the skipped queue,
    # in arrival order.
    scheduler.schedule()
    assert list(scheduler.skipped_waiting) == [req_fsm, req_remote, req_stream]

    # Abort the regular request; afterwards nothing may be left running.
    scheduler.finish_requests(req_regular.request_id, RequestStatus.FINISHED_ABORTED)
    assert not scheduler.running

    # Unblock all three skipped requests.
    req_fsm.structured_output_request = Mock(grammar=object())
    scheduler.finished_recving_kv_req_ids.add(req_remote.request_id)
    req_stream.status = RequestStatus.WAITING

    second_output = scheduler.schedule()

    expected_order = [
        req_fsm.request_id,
        req_remote.request_id,
        req_stream.request_id,
        req_tail.request_id,
    ]
    assert [req.req_id for req in second_output.scheduled_new_reqs] == expected_order
    assert [req.request_id for req in scheduler.running] == expected_order
    scheduler._update_waiting_for_remote_kv.assert_called_once_with(req_remote)
def test_abort_request_waiting_for_remote_kvs(): def test_abort_request_waiting_for_remote_kvs():

View File

@@ -119,7 +119,7 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler):
scheduler_output = fail_scheduler.schedule() scheduler_output = fail_scheduler.schedule()
assert len(fail_scheduler.waiting) == 1 assert len(fail_scheduler.skipped_waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == num_external_computed_tokens assert request.num_computed_tokens == num_external_computed_tokens
@@ -145,3 +145,4 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler):
assert output.finish_reason == FinishReason.ERROR assert output.finish_reason == FinishReason.ERROR
assert len(fail_scheduler.waiting) == 0 assert len(fail_scheduler.waiting) == 0
assert len(fail_scheduler.skipped_waiting) == 0

View File

@@ -337,7 +337,7 @@ def test_async_recompute_blocks_not_cached_when_invalid(
scheduler_output = recompute_scheduler.schedule() scheduler_output = recompute_scheduler.schedule()
# request should be waiting for remote KVs # request should be waiting for remote KVs
assert len(recompute_scheduler.waiting) == 1 assert len(recompute_scheduler.skipped_waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == num_external_computed_tokens assert request.num_computed_tokens == num_external_computed_tokens

View File

@@ -76,8 +76,9 @@ def test_async_load_failure(
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 3 assert len(scheduler.waiting) == 0
for request in scheduler.waiting: assert len(scheduler.skipped_waiting) == 3
for request in scheduler.skipped_waiting:
assert request.num_computed_tokens == num_external_computed_tokens assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
@@ -96,8 +97,9 @@ def test_async_load_failure(
min_invalid_block_idx = min(invalid_block_idxs) min_invalid_block_idx = min(invalid_block_idxs)
assert len(scheduler.waiting) == 3 assert len(scheduler.waiting) == 0
for request in scheduler.waiting: assert len(scheduler.skipped_waiting) == 3
for request in scheduler.skipped_waiting:
if request.request_id == request2.request_id: if request.request_id == request2.request_id:
assert request.num_computed_tokens == ( assert request.num_computed_tokens == (
min_invalid_block_idx * scheduler.block_size min_invalid_block_idx * scheduler.block_size
@@ -303,8 +305,9 @@ def test_async_progressive_load_failure(
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 0
assert scheduler.waiting.peek_request().request_id == request.request_id assert len(scheduler.skipped_waiting) == 1
assert scheduler.skipped_waiting.peek_request().request_id == request.request_id
assert request.num_computed_tokens == num_external_computed_tokens assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1 assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
@@ -325,8 +328,9 @@ def test_async_progressive_load_failure(
min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx) min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)
assert len(scheduler.waiting) == 1 assert len(scheduler.waiting) == 0
assert scheduler.waiting.peek_request().request_id == request.request_id assert len(scheduler.skipped_waiting) == 1
assert scheduler.skipped_waiting.peek_request().request_id == request.request_id
assert request.num_computed_tokens == ( assert request.num_computed_tokens == (
min_invalid_block_idx * scheduler.block_size min_invalid_block_idx * scheduler.block_size
) )

View File

@@ -18,6 +18,10 @@ from .utils import (
pytestmark = pytest.mark.cpu_test pytestmark = pytest.mark.cpu_test
def _num_waiting_requests(scheduler) -> int:
return len(scheduler.waiting) + len(scheduler.skipped_waiting)
def test_basic_lifecycle(): def test_basic_lifecycle():
"""Test lifecycle of a remote prefill.""" """Test lifecycle of a remote prefill."""
@@ -54,8 +58,8 @@ def test_basic_lifecycle():
assert scheduler_output.total_num_scheduled_tokens == 0 assert scheduler_output.total_num_scheduled_tokens == 0
# Req waiting for KVs with no computed/scheduled toks ... # Req waiting for KVs with no computed/scheduled toks ...
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
assert request in scheduler.waiting assert request in scheduler.skipped_waiting
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == NUM_TOKENS assert request.num_computed_tokens == NUM_TOKENS
@@ -81,7 +85,7 @@ def test_basic_lifecycle():
# STEP (2): # STEP (2):
# (2a): schedule(): nothing happens! # (2a): schedule(): nothing happens!
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
# (2b): forward(): request finishes recv. # (2b): forward(): request finishes recv.
@@ -94,7 +98,7 @@ def test_basic_lifecycle():
engine_core_outputs = scheduler.update_from_output( engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output scheduler_output, model_runner_output
) )
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
assert request_id in scheduler.finished_recving_kv_req_ids assert request_id in scheduler.finished_recving_kv_req_ids
# STEP (3): # STEP (3):
@@ -180,7 +184,7 @@ def test_interleaved_lifecycle():
scheduler.add_request(request_remote) scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1
assert scheduler_output.scheduled_cached_reqs.num_reqs == 1 assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
@@ -190,7 +194,7 @@ def test_interleaved_lifecycle():
# STEP 3: continue running, KVs not arrived yet. # STEP 3: continue running, KVs not arrived yet.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
@@ -199,14 +203,14 @@ def test_interleaved_lifecycle():
) )
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
# STEP 4: KVs arrive. # STEP 4: KVs arrive.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0 assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
@@ -218,7 +222,7 @@ def test_interleaved_lifecycle():
# STEP 5: RECVed KVs are sent to ModelRunner. # STEP 5: RECVed KVs are sent to ModelRunner.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 3 assert len(scheduler.running) == 3
assert len(scheduler.waiting) == 0 assert _num_waiting_requests(scheduler) == 0
assert len(scheduler_output.scheduled_new_reqs) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
@@ -279,14 +283,14 @@ def test_no_spurious_prefix_caching():
scheduler.add_request(request_remote) scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
# Schedule the local prefill request. This should # Schedule the local prefill request. This should
# cause blocks to be cached, but separately from # cause blocks to be cached, but separately from
scheduler.add_request(request_local) scheduler.add_request(request_local)
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0 0
@@ -348,7 +352,7 @@ def test_full_block_prompt():
finished_recving={request_id} finished_recving={request_id}
) )
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
assert request_id in scheduler.finished_recving_kv_req_ids assert request_id in scheduler.finished_recving_kv_req_ids
# # STEP (3): Run as usual. # # STEP (3): Run as usual.
@@ -418,7 +422,7 @@ def test_cannot_schedule_after_recv():
model_runner_output = create_model_runner_output(reqs=[request_normal]) model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0 assert _num_waiting_requests(scheduler) == 0
# Step 2: 5 blocks are in use (2 new for remote blocks). # Step 2: 5 blocks are in use (2 new for remote blocks).
scheduler.add_request(request_remote) scheduler.add_request(request_remote)
@@ -426,7 +430,7 @@ def test_cannot_schedule_after_recv():
model_runner_output = create_model_runner_output(reqs=[request_normal]) model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
# Step 3: finish recving (5 blocks in use) # Step 3: finish recving (5 blocks in use)
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
@@ -435,7 +439,7 @@ def test_cannot_schedule_after_recv():
) )
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
# Step 4: try to schedule, remote request is put to running list # Step 4: try to schedule, remote request is put to running list
# because the transfer is completed. # because the transfer is completed.
@@ -445,7 +449,7 @@ def test_cannot_schedule_after_recv():
) )
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 2 assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 0 assert _num_waiting_requests(scheduler) == 0
# Step 5: Remote request will be put back to waiting list # Step 5: Remote request will be put back to waiting list
# because it needs new block to hold generated token. # because it needs new block to hold generated token.
@@ -453,7 +457,7 @@ def test_cannot_schedule_after_recv():
model_runner_output = create_model_runner_output(reqs=[request_normal]) model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
# Step 6: finish the request, free it. # Step 6: finish the request, free it.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
@@ -462,7 +466,7 @@ def test_cannot_schedule_after_recv():
) )
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
# Step 7: now we can schedule (with 2 blocks computed), # Step 7: now we can schedule (with 2 blocks computed),
# request is retrieved from preempted list. # request is retrieved from preempted list.
@@ -474,7 +478,7 @@ def test_cannot_schedule_after_recv():
) )
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0 assert _num_waiting_requests(scheduler) == 0
# Step 8: free everything. # Step 8: free everything.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
@@ -521,7 +525,7 @@ def test_cannot_recv():
model_runner_output = create_model_runner_output(reqs=[request_normal]) model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0 assert _num_waiting_requests(scheduler) == 0
# Step 2: 3 blocks are in use, # Step 2: 3 blocks are in use,
# need 3 new for remote blocks but only 2 are available. # need 3 new for remote blocks but only 2 are available.
@@ -530,7 +534,7 @@ def test_cannot_recv():
model_runner_output = create_model_runner_output(reqs=[request_normal]) model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
# Should not have KV transfer in progress. # Should not have KV transfer in progress.
assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS
@@ -541,14 +545,14 @@ def test_cannot_recv():
) )
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
# Step 4: now we can initiate KV transfer (with 2 blocks computed). # Step 4: now we can initiate KV transfer (with 2 blocks computed).
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[]) model_runner_output = create_model_runner_output(reqs=[])
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS
# Step 5: finish recving (5 blocks in use) # Step 5: finish recving (5 blocks in use)
@@ -558,14 +562,14 @@ def test_cannot_recv():
) )
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0 assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1 assert _num_waiting_requests(scheduler) == 1
# Step 6: schedule remote request # Step 6: schedule remote request
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote]) model_runner_output = create_model_runner_output(reqs=[request_remote])
scheduler.update_from_output(scheduler_output, model_runner_output) scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1 assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0 assert _num_waiting_requests(scheduler) == 0
# Step 7: free everything. # Step 7: free everything.
scheduler_output = scheduler.schedule() scheduler_output = scheduler.schedule()

View File

@@ -45,7 +45,11 @@ from vllm.v1.core.sched.output import (
NewRequestData, NewRequestData,
SchedulerOutput, SchedulerOutput,
) )
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue from vllm.v1.core.sched.request_queue import (
RequestQueue,
SchedulingPolicy,
create_request_queue,
)
from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -160,6 +164,8 @@ class Scheduler(SchedulerInterface):
) from e ) from e
# Priority queues for requests. # Priority queues for requests.
self.waiting = create_request_queue(self.policy) self.waiting = create_request_queue(self.policy)
# Requests skipped in the waiting flow due to async dependencies or constraints.
self.skipped_waiting = create_request_queue(self.policy)
self.running: list[Request] = [] self.running: list[Request] = []
# The request IDs that are finished in between the previous and the # The request IDs that are finished in between the previous and the
@@ -531,52 +537,29 @@ class Scheduler(SchedulerInterface):
# Next, schedule the WAITING requests. # Next, schedule the WAITING requests.
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED: if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
# Use a temporary RequestQueue to collect requests that need to be step_skipped_waiting = create_request_queue(self.policy)
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
while self.waiting and token_budget > 0: while (self.waiting or self.skipped_waiting) and token_budget > 0:
if len(self.running) == self.max_num_running_reqs: if len(self.running) == self.max_num_running_reqs:
break break
request = self.waiting.peek_request() request_queue = self._select_waiting_queue_for_scheduling()
assert request_queue is not None
request = request_queue.peek_request()
request_id = request.request_id request_id = request.request_id
# KVTransfer: skip request if still waiting for remote kvs. # try to promote blocked statuses while traversing skipped queue.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: if self._is_blocked_waiting_status(
is_ready = self._update_waiting_for_remote_kv(request) request.status
if is_ready: ) and not self._try_promote_blocked_waiting_request(request):
if request.num_preemptions: if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
# We must be loading for a resumed preemption
# rather than a new request.
request.status = RequestStatus.PREEMPTED
else:
request.status = RequestStatus.WAITING
else:
logger.debug( logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.", "%s is still in WAITING_FOR_REMOTE_KVS state.",
request_id, request_id,
) )
self.waiting.pop_request() request_queue.pop_request()
skipped_waiting_requests.prepend_request(request) step_skipped_waiting.prepend_request(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Streaming: skip request if still waiting for next streaming req.
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
assert not request.streaming_queue
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue continue
# Check that adding the request still respects the max_loras # Check that adding the request still respects the max_loras
@@ -590,8 +573,8 @@ class Scheduler(SchedulerInterface):
) )
): ):
# Scheduling would exceed max_loras, skip. # Scheduling would exceed max_loras, skip.
self.waiting.pop_request() request_queue.pop_request()
skipped_waiting_requests.prepend_request(request) step_skipped_waiting.prepend_request(request)
continue continue
num_external_computed_tokens = 0 num_external_computed_tokens = 0
@@ -617,8 +600,8 @@ class Scheduler(SchedulerInterface):
# The request cannot be scheduled because # The request cannot be scheduled because
# the KVConnector couldn't determine # the KVConnector couldn't determine
# the number of matched tokens. # the number of matched tokens.
self.waiting.pop_request() request_queue.pop_request()
skipped_waiting_requests.prepend_request(request) step_skipped_waiting.prepend_request(request)
continue continue
request.num_external_computed_tokens = ext_tokens request.num_external_computed_tokens = ext_tokens
@@ -761,14 +744,12 @@ class Scheduler(SchedulerInterface):
preempted=request.num_preemptions > 0, preempted=request.num_preemptions > 0,
) )
# Request was already popped from self.waiting request = request_queue.pop_request()
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
if load_kv_async: if load_kv_async:
# If loading async, allocate memory and put request # If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state. # into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
step_skipped_waiting.prepend_request(request)
# Set num_computed_tokens even though KVs are not yet loaded. # Set num_computed_tokens even though KVs are not yet loaded.
# request.num_computed_tokens will not be used anywhere until # request.num_computed_tokens will not be used anywhere until
# the request finished the KV transfer. # the request finished the KV transfer.
@@ -825,9 +806,9 @@ class Scheduler(SchedulerInterface):
if self.ec_connector is not None: if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i) self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue # re-queue requests skipped in this pass ahead of older skipped items.
if skipped_waiting_requests: if step_skipped_waiting:
self.waiting.prepend_requests(skipped_waiting_requests) self.skipped_waiting.prepend_requests(step_skipped_waiting)
# Check if the scheduling constraints are satisfied. # Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -1531,6 +1512,32 @@ class Scheduler(SchedulerInterface):
return engine_core_outputs return engine_core_outputs
@staticmethod
def _is_blocked_waiting_status(status: RequestStatus) -> bool:
    """Return True if *status* is one of the blocked waiting states.

    The blocked states are WAITING_FOR_FSM, WAITING_FOR_REMOTE_KVS and
    WAITING_FOR_STREAMING_REQ.
    """
    blocked_states = (
        RequestStatus.WAITING_FOR_FSM,
        RequestStatus.WAITING_FOR_REMOTE_KVS,
        RequestStatus.WAITING_FOR_STREAMING_REQ,
    )
    return status in blocked_states
def _enqueue_waiting_request(self, request: Request) -> None:
    """Add *request* to the queue that matches its current status.

    Requests in a blocked waiting status go to ``self.skipped_waiting``;
    every other request goes to ``self.waiting``.
    """
    target_queue = (
        self.skipped_waiting
        if self._is_blocked_waiting_status(request.status)
        else self.waiting
    )
    target_queue.add_request(request)
def _select_waiting_queue_for_scheduling(self) -> RequestQueue | None:
    """Pick the queue whose head should be considered next.

    Returns ``None`` when both ``self.waiting`` and
    ``self.skipped_waiting`` are empty.
    """
    if self.policy == SchedulingPolicy.FCFS:
        # FCFS: the skipped queue is preferred whenever it is non-empty.
        return self.skipped_waiting or self.waiting or None

    # PRIORITY mode: only one queue (or neither) may have entries.
    if not (self.waiting and self.skipped_waiting):
        return self.waiting or self.skipped_waiting or None

    # Both queues are non-empty: compare their heads and take the queue
    # whose head orders first; the skipped queue is chosen otherwise.
    head_of_waiting = self.waiting.peek_request()
    head_of_skipped = self.skipped_waiting.peek_request()
    if head_of_waiting < head_of_skipped:
        return self.waiting
    return self.skipped_waiting
def _handle_stopped_request(self, request: Request) -> bool: def _handle_stopped_request(self, request: Request) -> bool:
"""Return True if finished (can be False for resumable requests).""" """Return True if finished (can be False for resumable requests)."""
if not request.resumable: if not request.resumable:
@@ -1546,7 +1553,7 @@ class Scheduler(SchedulerInterface):
request.status = RequestStatus.WAITING_FOR_STREAMING_REQ request.status = RequestStatus.WAITING_FOR_STREAMING_REQ
self.num_waiting_for_streaming_input += 1 self.num_waiting_for_streaming_input += 1
self.waiting.add_request(request) self._enqueue_waiting_request(request)
return False return False
def _get_routed_experts(self, request: Request) -> np.ndarray | None: def _get_routed_experts(self, request: Request) -> np.ndarray | None:
@@ -1677,7 +1684,7 @@ class Scheduler(SchedulerInterface):
def get_request_counts(self) -> tuple[int, int]: def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs).""" """Returns (num_running_reqs, num_waiting_reqs)."""
return len(self.running), len(self.waiting) return len(self.running), len(self.waiting) + len(self.skipped_waiting)
def add_request(self, request: Request) -> None: def add_request(self, request: Request) -> None:
existing = self.requests.get(request.request_id) existing = self.requests.get(request.request_id)
@@ -1696,7 +1703,7 @@ class Scheduler(SchedulerInterface):
else: else:
if request.resumable: if request.resumable:
request.streaming_queue = deque() request.streaming_queue = deque()
self.waiting.add_request(request) self._enqueue_waiting_request(request)
self.requests[request.request_id] = request self.requests[request.request_id] = request
if self.log_stats: if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED) request.record_event(EngineCoreEventType.QUEUED)
@@ -1747,6 +1754,7 @@ class Scheduler(SchedulerInterface):
self.running = remove_all(self.running, running_requests_to_remove) self.running = remove_all(self.running, running_requests_to_remove)
if waiting_requests_to_remove: if waiting_requests_to_remove:
self.waiting.remove_requests(waiting_requests_to_remove) self.waiting.remove_requests(waiting_requests_to_remove)
self.skipped_waiting.remove_requests(waiting_requests_to_remove)
# Second pass: set status and free requests # Second pass: set status and free requests
for request in valid_requests: for request in valid_requests:
@@ -1798,7 +1806,11 @@ class Scheduler(SchedulerInterface):
return 0 return 0
if self._pause_state == PauseState.PAUSED_NEW: if self._pause_state == PauseState.PAUSED_NEW:
return len(self.running) return len(self.running)
num_waiting = len(self.waiting) - self.num_waiting_for_streaming_input num_waiting = (
len(self.waiting)
+ len(self.skipped_waiting)
- self.num_waiting_for_streaming_input
)
return num_waiting + len(self.running) return num_waiting + len(self.running)
def has_finished_requests(self) -> bool: def has_finished_requests(self) -> bool:
@@ -1898,7 +1910,7 @@ class Scheduler(SchedulerInterface):
) )
return SchedulerStats( return SchedulerStats(
num_running_reqs=len(self.running), num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting), num_waiting_reqs=len(self.waiting) + len(self.skipped_waiting),
kv_cache_usage=self.kv_cache_manager.usage, kv_cache_usage=self.kv_cache_manager.usage,
encoder_cache_usage=self._get_encoder_cache_usage(), encoder_cache_usage=self._get_encoder_cache_usage(),
prefix_cache_stats=prefix_cache_stats, prefix_cache_stats=prefix_cache_stats,
@@ -1981,21 +1993,15 @@ class Scheduler(SchedulerInterface):
return self.connector.request_finished_all_groups(request, block_ids) return self.connector.request_finished_all_groups(request, block_ids)
def _update_waiting_for_remote_kv(self, request: Request) -> bool: def _update_waiting_for_remote_kv(self, request: Request) -> None:
""" """
KV Connector: check if the request_id is finished_recving. KV Connector: update request state after async recv is finished.
The finished_recving_kv_req_ids list is populated
on the previous steps()'s update_from_output based
on the worker side connector.
When the kv transfer is ready, we cache the blocks When the kv transfer is ready, we cache the blocks
and the request state will be moved back to WAITING from and the request state will be moved back to WAITING from
WAITING_FOR_REMOTE_KV. WAITING_FOR_REMOTE_KV.
""" """
assert self.connector is not None assert self.connector is not None
if request.request_id not in self.finished_recving_kv_req_ids:
return False
if request.request_id in self.failed_recving_kv_req_ids: if request.request_id in self.failed_recving_kv_req_ids:
# Request had KV load failures; num_computed_tokens was already # Request had KV load failures; num_computed_tokens was already
@@ -2023,9 +2029,40 @@ class Scheduler(SchedulerInterface):
if request.num_cached_tokens < 0: if request.num_cached_tokens < 0:
request.num_cached_tokens = request.num_computed_tokens request.num_cached_tokens = request.num_computed_tokens
# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id) self.finished_recving_kv_req_ids.remove(request.request_id)
return True
def _try_promote_blocked_waiting_request(self, request: Request) -> bool:
    """
    Attempt to move a blocked waiting request back to a schedulable status.

    Returns True when the request's status was updated, and False when it
    must stay blocked. Raises AssertionError if the request's status is
    not one of the blocked waiting statuses handled here.
    """
    blocked_status = request.status

    if blocked_status == RequestStatus.WAITING_FOR_REMOTE_KVS:
        # finished_recving_kv_req_ids is populated during
        # update_from_output(), based on worker-side connector signals
        # in KVConnectorOutput.finished_recving
        if request.request_id not in self.finished_recving_kv_req_ids:
            return False
        self._update_waiting_for_remote_kv(request)
        # A request that has been preempted before resumes as PREEMPTED.
        request.status = (
            RequestStatus.PREEMPTED
            if request.num_preemptions
            else RequestStatus.WAITING
        )
        return True

    if blocked_status == RequestStatus.WAITING_FOR_FSM:
        so_request = request.structured_output_request
        # Stay blocked until a grammar is available on the request.
        if not so_request or not so_request.grammar:
            return False
        request.status = RequestStatus.WAITING
        return True

    if blocked_status == RequestStatus.WAITING_FOR_STREAMING_REQ:
        # Never promoted from here; the streaming queue must be empty.
        assert not request.streaming_queue
        return False

    raise AssertionError(
        "Unexpected blocked waiting status in promotion: "
        f"{request.status.name} for request {request.request_id}"
    )
def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
""" """
@@ -2172,7 +2209,7 @@ class Scheduler(SchedulerInterface):
# handle async KV loads (not cached yet, evict_blocks=False) # handle async KV loads (not cached yet, evict_blocks=False)
async_load_reqs = ( async_load_reqs = (
req req
for req in self.waiting for req in self.skipped_waiting
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
) )
async_failed_req_ids, num_failed_tokens, _ = ( async_failed_req_ids, num_failed_tokens, _ = (