diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index bbeca6ef7..2fe452421 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1115,12 +1115,16 @@ def _step_until_done( all_finished = all_done +def _num_waiting_requests(scheduler: Scheduler) -> int: + return len(scheduler.waiting) + len(scheduler.skipped_waiting) + + def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]): """Cycle requests through a KV transfer cycle.""" # Requests should first transition to WAITING_FOR_REMOTE_KVS output = scheduler.schedule() - assert len(scheduler.waiting) == len(req_ids) + assert _num_waiting_requests(scheduler) == len(req_ids) assert len(scheduler.running) == 0 assert len(output.scheduled_new_reqs) == 0 for req in scheduler.requests.values(): @@ -1139,7 +1143,7 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]): # Simulate KV transfer completion using KVConnectorOutput.finished_recving output = scheduler.schedule() - assert len(scheduler.waiting) == len(req_ids) + assert _num_waiting_requests(scheduler) == len(req_ids) assert len(scheduler.running) == 0 MODEL_RUNNER_OUTPUT = ModelRunnerOutput( @@ -1546,7 +1550,7 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role): # All can be scheduled - 1st token. output = scheduler.schedule() if is_async: - assert len(scheduler.waiting) == 2 + assert _num_waiting_requests(scheduler) == 2 assert scheduler.running == [] _step_until_kv_transfer_finished(scheduler, req_ids) output = scheduler.schedule() @@ -1604,7 +1608,11 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role): # This will have a local and remote cache hit. 
output = scheduler.schedule() if is_async: - waiting_req_ids = [req.request_id for req in scheduler.waiting] + waiting_req_ids = [ + req.request_id + for req in scheduler.skipped_waiting + if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS + ] assert len(waiting_req_ids) == 1 _step_until_kv_transfer_finished(scheduler, waiting_req_ids) output = scheduler.schedule() @@ -2439,7 +2447,8 @@ def test_schedule_skip_tokenizer_init_structured_output_request(): output = scheduler.schedule() assert len(output.scheduled_new_reqs) == 0 assert len(scheduler.running) == 0 - assert len(scheduler.waiting) == 1 + assert len(scheduler.waiting) == 0 + assert len(scheduler.skipped_waiting) == 1 @pytest.mark.parametrize( @@ -3626,6 +3635,9 @@ def test_prepend_skipped_requests_order(): # simulate first 2 waiting requests are waiting for remote KVs for req in expected_waiting_reqs[:2]: req.status = RequestStatus.WAITING_FOR_REMOTE_KVS + scheduler.waiting.remove_requests(expected_waiting_reqs[:2]) + for req in expected_waiting_reqs[:2]: + scheduler.skipped_waiting.add_request(req) # schedule step # expect the first 2 waiting to be skipped, the third running, @@ -3636,7 +3648,87 @@ def test_prepend_skipped_requests_order(): expected_waiting_reqs.pop(2) # verify waiting order is preserved - assert list(scheduler.waiting) == expected_waiting_reqs + waiting_reqs = list(scheduler.skipped_waiting) + list(scheduler.waiting) + assert waiting_reqs == expected_waiting_reqs + + +def test_remote_kv_promotion_keeps_fcfs_with_fsm_prefix(): + scheduler = create_scheduler(max_num_seqs=1) + scheduler.connector = Mock() + scheduler.connector.get_num_new_matched_tokens.return_value = (0, False) + + requests = create_requests(num_requests=4) + for request in requests: + scheduler.add_request(request) + + req_fsm_1, req_fsm_2, req_remote, req_tail = list(scheduler.waiting) + + # simulate two FSM requests at the waiting head that become ready now. 
+ req_fsm_1.status = RequestStatus.WAITING_FOR_FSM + req_fsm_1.structured_output_request = Mock(grammar=object()) + req_fsm_2.status = RequestStatus.WAITING_FOR_FSM + req_fsm_2.structured_output_request = Mock(grammar=object()) + + # simulate a remote-KV request that is ready to be promoted now. + req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS + scheduler.waiting.remove_requests([req_fsm_1, req_fsm_2, req_remote]) + scheduler.skipped_waiting.add_request(req_fsm_1) + scheduler.skipped_waiting.add_request(req_fsm_2) + scheduler.skipped_waiting.add_request(req_remote) + scheduler.finished_recving_kv_req_ids.add(req_remote.request_id) + scheduler._update_waiting_for_remote_kv = Mock() + + output = scheduler.schedule() + + assert output.scheduled_new_reqs + assert output.scheduled_new_reqs[0].req_id == req_fsm_1.request_id + waiting_req_ids = [ + req.request_id + for req in list(scheduler.skipped_waiting) + list(scheduler.waiting) + ] + assert waiting_req_ids == [ + req_fsm_2.request_id, + req_remote.request_id, + req_tail.request_id, + ] + + +def test_fcfs_mixed_skipped_waiting_types_keep_order(): + scheduler = create_scheduler(max_num_batched_tokens=20) + scheduler._update_waiting_for_remote_kv = Mock() + + mk_req = lambda req_id, num_tokens=1: create_requests( # noqa: E731 + num_requests=1, num_tokens=num_tokens, req_ids=[req_id] + )[0] + req_fsm, req_remote, req_stream = mk_req("fsm"), mk_req("remote"), mk_req("stream") + req_regular, req_tail = mk_req("regular", 20), mk_req("tail") + req_fsm.status = RequestStatus.WAITING_FOR_FSM + req_fsm.structured_output_request = Mock(grammar=None) + req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS + req_stream.status = RequestStatus.WAITING_FOR_STREAMING_REQ + + for req in (req_fsm, req_remote, req_stream, req_regular, req_tail): + scheduler.add_request(req) + scheduler.schedule() + assert list(scheduler.skipped_waiting) == [req_fsm, req_remote, req_stream] + + 
scheduler.finish_requests(req_regular.request_id, RequestStatus.FINISHED_ABORTED) + assert not scheduler.running + + req_fsm.structured_output_request = Mock(grammar=object()) + scheduler.finished_recving_kv_req_ids.add(req_remote.request_id) + req_stream.status = RequestStatus.WAITING + + second_output = scheduler.schedule() + expected_order = [ + req_fsm.request_id, + req_remote.request_id, + req_stream.request_id, + req_tail.request_id, + ] + assert [req.req_id for req in second_output.scheduled_new_reqs] == expected_order + assert [req.request_id for req in scheduler.running] == expected_order + scheduler._update_waiting_for_remote_kv.assert_called_once_with(req_remote) def test_abort_request_waiting_for_remote_kvs(): diff --git a/tests/v1/kv_connector/unit/test_error_propagation.py b/tests/v1/kv_connector/unit/test_error_propagation.py index 11286611e..a07364cd3 100644 --- a/tests/v1/kv_connector/unit/test_error_propagation.py +++ b/tests/v1/kv_connector/unit/test_error_propagation.py @@ -119,7 +119,7 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler): scheduler_output = fail_scheduler.schedule() - assert len(fail_scheduler.waiting) == 1 + assert len(fail_scheduler.skipped_waiting) == 1 assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.num_computed_tokens == num_external_computed_tokens @@ -145,3 +145,4 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler): assert output.finish_reason == FinishReason.ERROR assert len(fail_scheduler.waiting) == 0 + assert len(fail_scheduler.skipped_waiting) == 0 diff --git a/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py b/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py index 53fe59984..77d629729 100644 --- a/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py +++ b/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py @@ -337,7 +337,7 @@ def test_async_recompute_blocks_not_cached_when_invalid( scheduler_output = 
recompute_scheduler.schedule() # request should be waiting for remote KVs - assert len(recompute_scheduler.waiting) == 1 + assert len(recompute_scheduler.skipped_waiting) == 1 assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.num_computed_tokens == num_external_computed_tokens diff --git a/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py index fcdb2869d..4f35527b0 100644 --- a/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py +++ b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py @@ -76,8 +76,9 @@ def test_async_load_failure( scheduler_output = scheduler.schedule() - assert len(scheduler.waiting) == 3 - for request in scheduler.waiting: + assert len(scheduler.waiting) == 0 + assert len(scheduler.skipped_waiting) == 3 + for request in scheduler.skipped_waiting: assert request.num_computed_tokens == num_external_computed_tokens assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 @@ -96,8 +97,9 @@ def test_async_load_failure( min_invalid_block_idx = min(invalid_block_idxs) - assert len(scheduler.waiting) == 3 - for request in scheduler.waiting: + assert len(scheduler.waiting) == 0 + assert len(scheduler.skipped_waiting) == 3 + for request in scheduler.skipped_waiting: if request.request_id == request2.request_id: assert request.num_computed_tokens == ( min_invalid_block_idx * scheduler.block_size @@ -303,8 +305,9 @@ def test_async_progressive_load_failure( scheduler_output = scheduler.schedule() - assert len(scheduler.waiting) == 1 - assert scheduler.waiting.peek_request().request_id == request.request_id + assert len(scheduler.waiting) == 0 + assert len(scheduler.skipped_waiting) == 1 + assert scheduler.skipped_waiting.peek_request().request_id == request.request_id assert request.num_computed_tokens == num_external_computed_tokens assert request.status == 
RequestStatus.WAITING_FOR_REMOTE_KVS assert scheduler.connector.get_num_new_matched_tokens.call_count == 1 @@ -325,8 +328,9 @@ def test_async_progressive_load_failure( min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx) - assert len(scheduler.waiting) == 1 - assert scheduler.waiting.peek_request().request_id == request.request_id + assert len(scheduler.waiting) == 0 + assert len(scheduler.skipped_waiting) == 1 + assert scheduler.skipped_waiting.peek_request().request_id == request.request_id assert request.num_computed_tokens == ( min_invalid_block_idx * scheduler.block_size ) diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index f0ff216be..f48dc0fff 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -18,6 +18,10 @@ from .utils import ( pytestmark = pytest.mark.cpu_test +def _num_waiting_requests(scheduler) -> int: + return len(scheduler.waiting) + len(scheduler.skipped_waiting) + + def test_basic_lifecycle(): """Test lifecycle of a remote prefill.""" @@ -54,8 +58,8 @@ def test_basic_lifecycle(): assert scheduler_output.total_num_scheduled_tokens == 0 # Req waiting for KVs with no computed/scheduled toks ... - assert len(scheduler.waiting) == 1 - assert request in scheduler.waiting + assert _num_waiting_requests(scheduler) == 1 + assert request in scheduler.skipped_waiting assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert request.num_computed_tokens == NUM_TOKENS @@ -81,7 +85,7 @@ def test_basic_lifecycle(): # STEP (2): # (2a): schedule(): nothing happens! scheduler_output = scheduler.schedule() - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 assert len(scheduler.running) == 0 # (2b): forward(): request finishes recv. 
@@ -94,7 +98,7 @@ def test_basic_lifecycle(): engine_core_outputs = scheduler.update_from_output( scheduler_output, model_runner_output ) - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 assert request_id in scheduler.finished_recving_kv_req_ids # STEP (3): @@ -180,7 +184,7 @@ def test_interleaved_lifecycle(): scheduler.add_request(request_remote) scheduler_output = scheduler.schedule() assert len(scheduler.running) == 2 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 assert len(scheduler_output.scheduled_new_reqs) == 1 assert scheduler_output.scheduled_cached_reqs.num_reqs == 1 @@ -190,7 +194,7 @@ def test_interleaved_lifecycle(): # STEP 3: continue running, KVs not arrived yet. scheduler_output = scheduler.schedule() assert len(scheduler.running) == 2 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 assert len(scheduler_output.scheduled_new_reqs) == 0 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 @@ -199,14 +203,14 @@ def test_interleaved_lifecycle(): ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 2 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 assert len(scheduler_output.scheduled_new_reqs) == 0 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 # STEP 4: KVs arrive. scheduler_output = scheduler.schedule() assert len(scheduler.running) == 2 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 assert len(scheduler_output.scheduled_new_reqs) == 0 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 @@ -218,7 +222,7 @@ def test_interleaved_lifecycle(): # STEP 5: RECVed KVs are sent to ModelRunner. 
scheduler_output = scheduler.schedule() assert len(scheduler.running) == 3 - assert len(scheduler.waiting) == 0 + assert _num_waiting_requests(scheduler) == 0 assert len(scheduler_output.scheduled_new_reqs) == 1 assert scheduler_output.scheduled_cached_reqs.num_reqs == 2 @@ -279,14 +283,14 @@ def test_no_spurious_prefix_caching(): scheduler.add_request(request_remote) scheduler_output = scheduler.schedule() scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 # Schedule the local prefill request. This should # cause blocks to be cached, but separately from scheduler.add_request(request_local) scheduler_output = scheduler.schedule() assert len(scheduler.running) == 1 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ 0 @@ -348,7 +352,7 @@ def test_full_block_prompt(): finished_recving={request_id} ) scheduler.update_from_output(scheduler_output, model_runner_output) - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 assert request_id in scheduler.finished_recving_kv_req_ids # # STEP (3): Run as usual. @@ -418,7 +422,7 @@ def test_cannot_schedule_after_recv(): model_runner_output = create_model_runner_output(reqs=[request_normal]) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 - assert len(scheduler.waiting) == 0 + assert _num_waiting_requests(scheduler) == 0 # Step 2: 5 blocks are in use (2 new for remote blocks). 
scheduler.add_request(request_remote) @@ -426,7 +430,7 @@ def test_cannot_schedule_after_recv(): model_runner_output = create_model_runner_output(reqs=[request_normal]) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 # Step 3: finish recving (5 blocks in use) scheduler_output = scheduler.schedule() @@ -435,7 +439,7 @@ def test_cannot_schedule_after_recv(): ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 # Step 4: try to schedule, remote request is put to running list # because the transfer is completed. @@ -445,7 +449,7 @@ def test_cannot_schedule_after_recv(): ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 2 - assert len(scheduler.waiting) == 0 + assert _num_waiting_requests(scheduler) == 0 # Step 5: Remote request will be put back to waiting list # because it needs new block to hold generated token. @@ -453,7 +457,7 @@ def test_cannot_schedule_after_recv(): model_runner_output = create_model_runner_output(reqs=[request_normal]) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 # Step 6: finish the request, free it. scheduler_output = scheduler.schedule() @@ -462,7 +466,7 @@ def test_cannot_schedule_after_recv(): ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 0 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 # Step 7: now we can schedule (with 2 blocks computed), # request is retrieved from preempted list. 
@@ -474,7 +478,7 @@ def test_cannot_schedule_after_recv(): ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 - assert len(scheduler.waiting) == 0 + assert _num_waiting_requests(scheduler) == 0 # Step 8: free everything. scheduler_output = scheduler.schedule() @@ -521,7 +525,7 @@ def test_cannot_recv(): model_runner_output = create_model_runner_output(reqs=[request_normal]) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 - assert len(scheduler.waiting) == 0 + assert _num_waiting_requests(scheduler) == 0 # Step 2: 3 blocks are in use, # need 3 new for remote blocks but only 2 are available. @@ -530,7 +534,7 @@ def test_cannot_recv(): model_runner_output = create_model_runner_output(reqs=[request_normal]) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 # Should not have KV transfer in progress. assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS @@ -541,14 +545,14 @@ def test_cannot_recv(): ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 0 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 # Step 4: now we can initiate KV transfer (with 2 blocks computed). 
scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[]) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 0 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS # Step 5: finish recving (5 blocks in use) @@ -558,14 +562,14 @@ def test_cannot_recv(): ) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 0 - assert len(scheduler.waiting) == 1 + assert _num_waiting_requests(scheduler) == 1 # Step 6: schedule remote request scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_remote]) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 - assert len(scheduler.waiting) == 0 + assert _num_waiting_requests(scheduler) == 0 # Step 7: free everything. scheduler_output = scheduler.schedule() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 3487fe308..4628e6344 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -45,7 +45,11 @@ from vllm.v1.core.sched.output import ( NewRequestData, SchedulerOutput, ) -from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue +from vllm.v1.core.sched.request_queue import ( + RequestQueue, + SchedulingPolicy, + create_request_queue, +) from vllm.v1.core.sched.utils import check_stop, remove_all from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs from vllm.v1.kv_cache_interface import KVCacheConfig @@ -160,6 +164,8 @@ class Scheduler(SchedulerInterface): ) from e # Priority queues for requests. self.waiting = create_request_queue(self.policy) + # Requests skipped during waiting-queue scheduling due to async deps or constraints. 
+ self.skipped_waiting = create_request_queue(self.policy) self.running: list[Request] = [] # The request IDs that are finished in between the previous and the @@ -531,52 +537,29 @@ class Scheduler(SchedulerInterface): # Next, schedule the WAITING requests. if not preempted_reqs and self._pause_state == PauseState.UNPAUSED: - # Use a temporary RequestQueue to collect requests that need to be - # skipped and put back at the head of the waiting queue later - skipped_waiting_requests = create_request_queue(self.policy) + step_skipped_waiting = create_request_queue(self.policy) - while self.waiting and token_budget > 0: + while (self.waiting or self.skipped_waiting) and token_budget > 0: if len(self.running) == self.max_num_running_reqs: break - request = self.waiting.peek_request() + request_queue = self._select_waiting_queue_for_scheduling() + assert request_queue is not None + + request = request_queue.peek_request() request_id = request.request_id - # KVTransfer: skip request if still waiting for remote kvs. - if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: - is_ready = self._update_waiting_for_remote_kv(request) - if is_ready: - if request.num_preemptions: - # We must be loading for a resumed preemption - # rather than a new request. - request.status = RequestStatus.PREEMPTED - else: - request.status = RequestStatus.WAITING - else: + # try to promote blocked statuses while traversing skipped queue. + if self._is_blocked_waiting_status( + request.status + ) and not self._try_promote_blocked_waiting_request(request): + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: logger.debug( "%s is still in WAITING_FOR_REMOTE_KVS state.", request_id, ) - self.waiting.pop_request() - skipped_waiting_requests.prepend_request(request) - continue - - # Skip request if the structured output request is still waiting - # for FSM compilation. 
- if request.status == RequestStatus.WAITING_FOR_FSM: - structured_output_req = request.structured_output_request - if structured_output_req and structured_output_req.grammar: - request.status = RequestStatus.WAITING - else: - self.waiting.pop_request() - skipped_waiting_requests.prepend_request(request) - continue - - # Streaming: skip request if still waiting for next streaming req. - if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ: - assert not request.streaming_queue - self.waiting.pop_request() - skipped_waiting_requests.prepend_request(request) + request_queue.pop_request() + step_skipped_waiting.prepend_request(request) continue # Check that adding the request still respects the max_loras @@ -590,8 +573,8 @@ class Scheduler(SchedulerInterface): ) ): # Scheduling would exceed max_loras, skip. - self.waiting.pop_request() - skipped_waiting_requests.prepend_request(request) + request_queue.pop_request() + step_skipped_waiting.prepend_request(request) continue num_external_computed_tokens = 0 @@ -617,8 +600,8 @@ class Scheduler(SchedulerInterface): # The request cannot be scheduled because # the KVConnector couldn't determine # the number of matched tokens. - self.waiting.pop_request() - skipped_waiting_requests.prepend_request(request) + request_queue.pop_request() + step_skipped_waiting.prepend_request(request) continue request.num_external_computed_tokens = ext_tokens @@ -761,14 +744,12 @@ class Scheduler(SchedulerInterface): preempted=request.num_preemptions > 0, ) - # Request was already popped from self.waiting - # unless it was re-added above due to new_blocks being None. - request = self.waiting.pop_request() + request = request_queue.pop_request() if load_kv_async: # If loading async, allocate memory and put request # into the WAITING_FOR_REMOTE_KV state. 
- skipped_waiting_requests.prepend_request(request) request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + step_skipped_waiting.prepend_request(request) # Set num_computed_tokens even though KVs are not yet loaded. # request.num_computed_tokens will not be used anywhere until # the request finished the KV transfer. @@ -825,9 +806,9 @@ class Scheduler(SchedulerInterface): if self.ec_connector is not None: self.ec_connector.update_state_after_alloc(request, i) - # Put back any skipped requests at the head of the waiting queue - if skipped_waiting_requests: - self.waiting.prepend_requests(skipped_waiting_requests) + # re-queue requests skipped in this pass ahead of older skipped items. + if step_skipped_waiting: + self.skipped_waiting.prepend_requests(step_skipped_waiting) # Check if the scheduling constraints are satisfied. total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) @@ -1531,6 +1512,32 @@ class Scheduler(SchedulerInterface): return engine_core_outputs + @staticmethod + def _is_blocked_waiting_status(status: RequestStatus) -> bool: + return status in ( + RequestStatus.WAITING_FOR_FSM, + RequestStatus.WAITING_FOR_REMOTE_KVS, + RequestStatus.WAITING_FOR_STREAMING_REQ, + ) + + def _enqueue_waiting_request(self, request: Request) -> None: + if self._is_blocked_waiting_status(request.status): + self.skipped_waiting.add_request(request) + else: + self.waiting.add_request(request) + + def _select_waiting_queue_for_scheduling(self) -> RequestQueue | None: + if self.policy == SchedulingPolicy.FCFS: + return self.skipped_waiting or self.waiting or None + + # PRIORITY mode: compare queue heads when both queues are non-empty. 
+ if self.waiting and self.skipped_waiting: + waiting_req = self.waiting.peek_request() + skipped_req = self.skipped_waiting.peek_request() + return self.waiting if waiting_req < skipped_req else self.skipped_waiting + + return self.waiting or self.skipped_waiting or None + def _handle_stopped_request(self, request: Request) -> bool: """Return True if finished (can be False for resumable requests).""" if not request.resumable: @@ -1546,7 +1553,7 @@ class Scheduler(SchedulerInterface): request.status = RequestStatus.WAITING_FOR_STREAMING_REQ self.num_waiting_for_streaming_input += 1 - self.waiting.add_request(request) + self._enqueue_waiting_request(request) return False def _get_routed_experts(self, request: Request) -> np.ndarray | None: @@ -1677,7 +1684,7 @@ class Scheduler(SchedulerInterface): def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" - return len(self.running), len(self.waiting) + return len(self.running), len(self.waiting) + len(self.skipped_waiting) def add_request(self, request: Request) -> None: existing = self.requests.get(request.request_id) @@ -1696,7 +1703,7 @@ class Scheduler(SchedulerInterface): else: if request.resumable: request.streaming_queue = deque() - self.waiting.add_request(request) + self._enqueue_waiting_request(request) self.requests[request.request_id] = request if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) @@ -1747,6 +1754,7 @@ class Scheduler(SchedulerInterface): self.running = remove_all(self.running, running_requests_to_remove) if waiting_requests_to_remove: self.waiting.remove_requests(waiting_requests_to_remove) + self.skipped_waiting.remove_requests(waiting_requests_to_remove) # Second pass: set status and free requests for request in valid_requests: @@ -1798,7 +1806,11 @@ class Scheduler(SchedulerInterface): return 0 if self._pause_state == PauseState.PAUSED_NEW: return len(self.running) - num_waiting = len(self.waiting) - 
self.num_waiting_for_streaming_input + num_waiting = ( + len(self.waiting) + + len(self.skipped_waiting) + - self.num_waiting_for_streaming_input + ) return num_waiting + len(self.running) def has_finished_requests(self) -> bool: @@ -1898,7 +1910,7 @@ class Scheduler(SchedulerInterface): ) return SchedulerStats( num_running_reqs=len(self.running), - num_waiting_reqs=len(self.waiting), + num_waiting_reqs=len(self.waiting) + len(self.skipped_waiting), kv_cache_usage=self.kv_cache_manager.usage, encoder_cache_usage=self._get_encoder_cache_usage(), prefix_cache_stats=prefix_cache_stats, @@ -1981,21 +1993,15 @@ class Scheduler(SchedulerInterface): return self.connector.request_finished_all_groups(request, block_ids) - def _update_waiting_for_remote_kv(self, request: Request) -> bool: + def _update_waiting_for_remote_kv(self, request: Request) -> None: """ - KV Connector: check if the request_id is finished_recving. - - The finished_recving_kv_req_ids list is populated - on the previous steps()'s update_from_output based - on the worker side connector. + KV Connector: update request state after async recv is finished. When the kv transfer is ready, we cache the blocks and the request state will be moved back to WAITING from WAITING_FOR_REMOTE_KV. """ assert self.connector is not None - if request.request_id not in self.finished_recving_kv_req_ids: - return False if request.request_id in self.failed_recving_kv_req_ids: # Request had KV load failures; num_computed_tokens was already @@ -2023,9 +2029,40 @@ class Scheduler(SchedulerInterface): if request.num_cached_tokens < 0: request.num_cached_tokens = request.num_computed_tokens - # Return that we are ready. self.finished_recving_kv_req_ids.remove(request.request_id) - return True + + def _try_promote_blocked_waiting_request(self, request: Request) -> bool: + """ + Try to promote a blocked waiting request back to schedulable states. 
+ """ + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + # finished_recving_kv_req_ids is populated during + # update_from_output(), based on worker-side connector signals + # in KVConnectorOutput.finished_recving + if request.request_id not in self.finished_recving_kv_req_ids: + return False + self._update_waiting_for_remote_kv(request) + if request.num_preemptions: + request.status = RequestStatus.PREEMPTED + else: + request.status = RequestStatus.WAITING + return True + + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if not (structured_output_req and structured_output_req.grammar): + return False + request.status = RequestStatus.WAITING + return True + + if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ: + assert not request.streaming_queue + return False + + raise AssertionError( + "Unexpected blocked waiting status in promotion: " + f"{request.status.name} for request {request.request_id}" + ) def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): """ @@ -2172,7 +2209,7 @@ class Scheduler(SchedulerInterface): # handle async KV loads (not cached yet, evict_blocks=False) async_load_reqs = ( req - for req in self.waiting + for req in self.skipped_waiting if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS ) async_failed_req_ids, num_failed_tokens, _ = (