[Perf] Optimize scheduler overhead for PD disaggregation, around 5% E2E perf improvement (#35781)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Wentao Ye
2026-03-11 00:25:30 -04:00
committed by GitHub
parent 42fadebecb
commit a8ff2cca92
6 changed files with 243 additions and 105 deletions

View File

@@ -1115,12 +1115,16 @@ def _step_until_done(
all_finished = all_done
def _num_waiting_requests(scheduler: Scheduler) -> int:
return len(scheduler.waiting) + len(scheduler.skipped_waiting)
def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
"""Cycle requests through a KV transfer cycle."""
# Requests should first transition to WAITING_FOR_REMOTE_KVS
output = scheduler.schedule()
assert len(scheduler.waiting) == len(req_ids)
assert _num_waiting_requests(scheduler) == len(req_ids)
assert len(scheduler.running) == 0
assert len(output.scheduled_new_reqs) == 0
for req in scheduler.requests.values():
@@ -1139,7 +1143,7 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
# Simulate KV transfer completion using KVConnectorOutput.finished_recving
output = scheduler.schedule()
assert len(scheduler.waiting) == len(req_ids)
assert _num_waiting_requests(scheduler) == len(req_ids)
assert len(scheduler.running) == 0
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
@@ -1546,7 +1550,7 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
# All can be scheduled - 1st token.
output = scheduler.schedule()
if is_async:
assert len(scheduler.waiting) == 2
assert _num_waiting_requests(scheduler) == 2
assert scheduler.running == []
_step_until_kv_transfer_finished(scheduler, req_ids)
output = scheduler.schedule()
@@ -1604,7 +1608,11 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
# This will have a local and remote cache hit.
output = scheduler.schedule()
if is_async:
waiting_req_ids = [req.request_id for req in scheduler.waiting]
waiting_req_ids = [
req.request_id
for req in scheduler.skipped_waiting
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
]
assert len(waiting_req_ids) == 1
_step_until_kv_transfer_finished(scheduler, waiting_req_ids)
output = scheduler.schedule()
@@ -2439,7 +2447,8 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.skipped_waiting) == 1
@pytest.mark.parametrize(
@@ -3626,6 +3635,9 @@ def test_prepend_skipped_requests_order():
# simulate first 2 waiting requests are waiting for remote KVs
for req in expected_waiting_reqs[:2]:
req.status = RequestStatus.WAITING_FOR_REMOTE_KVS
scheduler.waiting.remove_requests(expected_waiting_reqs[:2])
for req in expected_waiting_reqs[:2]:
scheduler.skipped_waiting.add_request(req)
# schedule step
# expect the first 2 waiting to be skipped, the third running,
@@ -3636,7 +3648,87 @@ def test_prepend_skipped_requests_order():
expected_waiting_reqs.pop(2)
# verify waiting order is preserved
assert list(scheduler.waiting) == expected_waiting_reqs
waiting_reqs = list(scheduler.skipped_waiting) + list(scheduler.waiting)
assert waiting_reqs == expected_waiting_reqs
def test_remote_kv_promotion_keeps_fcfs_with_fsm_prefix():
scheduler = create_scheduler(max_num_seqs=1)
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.return_value = (0, False)
requests = create_requests(num_requests=4)
for request in requests:
scheduler.add_request(request)
req_fsm_1, req_fsm_2, req_remote, req_tail = list(scheduler.waiting)
# simulate two FSM requests at the waiting head that become ready now.
req_fsm_1.status = RequestStatus.WAITING_FOR_FSM
req_fsm_1.structured_output_request = Mock(grammar=object())
req_fsm_2.status = RequestStatus.WAITING_FOR_FSM
req_fsm_2.structured_output_request = Mock(grammar=object())
# simulate a remote-KV request that is ready to be promoted now.
req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS
scheduler.waiting.remove_requests([req_fsm_1, req_fsm_2, req_remote])
scheduler.skipped_waiting.add_request(req_fsm_1)
scheduler.skipped_waiting.add_request(req_fsm_2)
scheduler.skipped_waiting.add_request(req_remote)
scheduler.finished_recving_kv_req_ids.add(req_remote.request_id)
scheduler._update_waiting_for_remote_kv = Mock()
output = scheduler.schedule()
assert output.scheduled_new_reqs
assert output.scheduled_new_reqs[0].req_id == req_fsm_1.request_id
waiting_req_ids = [
req.request_id
for req in list(scheduler.skipped_waiting) + list(scheduler.waiting)
]
assert waiting_req_ids == [
req_fsm_2.request_id,
req_remote.request_id,
req_tail.request_id,
]
def test_fcfs_mixed_skipped_waiting_types_keep_order():
scheduler = create_scheduler(max_num_batched_tokens=20)
scheduler._update_waiting_for_remote_kv = Mock()
mk_req = lambda req_id, num_tokens=1: create_requests( # noqa: E731
num_requests=1, num_tokens=num_tokens, req_ids=[req_id]
)[0]
req_fsm, req_remote, req_stream = mk_req("fsm"), mk_req("remote"), mk_req("stream")
req_regular, req_tail = mk_req("regular", 20), mk_req("tail")
req_fsm.status = RequestStatus.WAITING_FOR_FSM
req_fsm.structured_output_request = Mock(grammar=None)
req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS
req_stream.status = RequestStatus.WAITING_FOR_STREAMING_REQ
for req in (req_fsm, req_remote, req_stream, req_regular, req_tail):
scheduler.add_request(req)
scheduler.schedule()
assert list(scheduler.skipped_waiting) == [req_fsm, req_remote, req_stream]
scheduler.finish_requests(req_regular.request_id, RequestStatus.FINISHED_ABORTED)
assert not scheduler.running
req_fsm.structured_output_request = Mock(grammar=object())
scheduler.finished_recving_kv_req_ids.add(req_remote.request_id)
req_stream.status = RequestStatus.WAITING
second_output = scheduler.schedule()
expected_order = [
req_fsm.request_id,
req_remote.request_id,
req_stream.request_id,
req_tail.request_id,
]
assert [req.req_id for req in second_output.scheduled_new_reqs] == expected_order
assert [req.request_id for req in scheduler.running] == expected_order
scheduler._update_waiting_for_remote_kv.assert_called_once_with(req_remote)
def test_abort_request_waiting_for_remote_kvs():

View File

@@ -119,7 +119,7 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler):
scheduler_output = fail_scheduler.schedule()
assert len(fail_scheduler.waiting) == 1
assert len(fail_scheduler.skipped_waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == num_external_computed_tokens
@@ -145,3 +145,4 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler):
assert output.finish_reason == FinishReason.ERROR
assert len(fail_scheduler.waiting) == 0
assert len(fail_scheduler.skipped_waiting) == 0

View File

@@ -337,7 +337,7 @@ def test_async_recompute_blocks_not_cached_when_invalid(
scheduler_output = recompute_scheduler.schedule()
# request should be waiting for remote KVs
assert len(recompute_scheduler.waiting) == 1
assert len(recompute_scheduler.skipped_waiting) == 1
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == num_external_computed_tokens

View File

@@ -76,8 +76,9 @@ def test_async_load_failure(
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 3
for request in scheduler.waiting:
assert len(scheduler.waiting) == 0
assert len(scheduler.skipped_waiting) == 3
for request in scheduler.skipped_waiting:
assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
@@ -96,8 +97,9 @@ def test_async_load_failure(
min_invalid_block_idx = min(invalid_block_idxs)
assert len(scheduler.waiting) == 3
for request in scheduler.waiting:
assert len(scheduler.waiting) == 0
assert len(scheduler.skipped_waiting) == 3
for request in scheduler.skipped_waiting:
if request.request_id == request2.request_id:
assert request.num_computed_tokens == (
min_invalid_block_idx * scheduler.block_size
@@ -303,8 +305,9 @@ def test_async_progressive_load_failure(
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 1
assert scheduler.waiting.peek_request().request_id == request.request_id
assert len(scheduler.waiting) == 0
assert len(scheduler.skipped_waiting) == 1
assert scheduler.skipped_waiting.peek_request().request_id == request.request_id
assert request.num_computed_tokens == num_external_computed_tokens
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
@@ -325,8 +328,9 @@ def test_async_progressive_load_failure(
min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)
assert len(scheduler.waiting) == 1
assert scheduler.waiting.peek_request().request_id == request.request_id
assert len(scheduler.waiting) == 0
assert len(scheduler.skipped_waiting) == 1
assert scheduler.skipped_waiting.peek_request().request_id == request.request_id
assert request.num_computed_tokens == (
min_invalid_block_idx * scheduler.block_size
)

View File

@@ -18,6 +18,10 @@ from .utils import (
pytestmark = pytest.mark.cpu_test
def _num_waiting_requests(scheduler) -> int:
return len(scheduler.waiting) + len(scheduler.skipped_waiting)
def test_basic_lifecycle():
"""Test lifecycle of a remote prefill."""
@@ -54,8 +58,8 @@ def test_basic_lifecycle():
assert scheduler_output.total_num_scheduled_tokens == 0
# Req waiting for KVs with no computed/scheduled toks ...
assert len(scheduler.waiting) == 1
assert request in scheduler.waiting
assert _num_waiting_requests(scheduler) == 1
assert request in scheduler.skipped_waiting
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
assert request.num_computed_tokens == NUM_TOKENS
@@ -81,7 +85,7 @@ def test_basic_lifecycle():
# STEP (2):
# (2a): schedule(): nothing happens!
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
assert len(scheduler.running) == 0
# (2b): forward(): request finishes recv.
@@ -94,7 +98,7 @@ def test_basic_lifecycle():
engine_core_outputs = scheduler.update_from_output(
scheduler_output, model_runner_output
)
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
assert request_id in scheduler.finished_recving_kv_req_ids
# STEP (3):
@@ -180,7 +184,7 @@ def test_interleaved_lifecycle():
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
@@ -190,7 +194,7 @@ def test_interleaved_lifecycle():
# STEP 3: continue running, KVs not arrived yet.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
@@ -199,14 +203,14 @@ def test_interleaved_lifecycle():
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
# STEP 4: KVs arrive.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
@@ -218,7 +222,7 @@ def test_interleaved_lifecycle():
# STEP 5: RECVed KVs are sent to ModelRunner.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(scheduler.waiting) == 0
assert _num_waiting_requests(scheduler) == 0
assert len(scheduler_output.scheduled_new_reqs) == 1
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
@@ -279,14 +283,14 @@ def test_no_spurious_prefix_caching():
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
# Schedule the local prefill request. This should
# cause blocks to be cached, but separately from
scheduler.add_request(request_local)
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0
@@ -348,7 +352,7 @@ def test_full_block_prompt():
finished_recving={request_id}
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
assert request_id in scheduler.finished_recving_kv_req_ids
# # STEP (3): Run as usual.
@@ -418,7 +422,7 @@ def test_cannot_schedule_after_recv():
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
assert _num_waiting_requests(scheduler) == 0
# Step 2: 5 blocks are in use (2 new for remote blocks).
scheduler.add_request(request_remote)
@@ -426,7 +430,7 @@ def test_cannot_schedule_after_recv():
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
# Step 3: finish recving (5 blocks in use)
scheduler_output = scheduler.schedule()
@@ -435,7 +439,7 @@ def test_cannot_schedule_after_recv():
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
# Step 4: try to schedule, remote request is put to running list
# because the transfer is completed.
@@ -445,7 +449,7 @@ def test_cannot_schedule_after_recv():
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 2
assert len(scheduler.waiting) == 0
assert _num_waiting_requests(scheduler) == 0
# Step 5: Remote request will be put back to waiting list
# because it needs new block to hold generated token.
@@ -453,7 +457,7 @@ def test_cannot_schedule_after_recv():
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
# Step 6: finish the request, free it.
scheduler_output = scheduler.schedule()
@@ -462,7 +466,7 @@ def test_cannot_schedule_after_recv():
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
# Step 7: now we can schedule (with 2 blocks computed),
# request is retrieved from preempted list.
@@ -474,7 +478,7 @@ def test_cannot_schedule_after_recv():
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
assert _num_waiting_requests(scheduler) == 0
# Step 8: free everything.
scheduler_output = scheduler.schedule()
@@ -521,7 +525,7 @@ def test_cannot_recv():
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
assert _num_waiting_requests(scheduler) == 0
# Step 2: 3 blocks are in use,
# need 3 new for remote blocks but only 2 are available.
@@ -530,7 +534,7 @@ def test_cannot_recv():
model_runner_output = create_model_runner_output(reqs=[request_normal])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
# Should not have KV transfer in progress.
assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS
@@ -541,14 +545,14 @@ def test_cannot_recv():
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
# Step 4: now we can initiate KV transfer (with 2 blocks computed).
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS
# Step 5: finish recving (5 blocks in use)
@@ -558,14 +562,14 @@ def test_cannot_recv():
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
assert _num_waiting_requests(scheduler) == 1
# Step 6: schedule remote request
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.running) == 1
assert len(scheduler.waiting) == 0
assert _num_waiting_requests(scheduler) == 0
# Step 7: free everything.
scheduler_output = scheduler.schedule()

View File

@@ -45,7 +45,11 @@ from vllm.v1.core.sched.output import (
NewRequestData,
SchedulerOutput,
)
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.request_queue import (
RequestQueue,
SchedulingPolicy,
create_request_queue,
)
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -160,6 +164,8 @@ class Scheduler(SchedulerInterface):
) from e
# Priority queues for requests.
self.waiting = create_request_queue(self.policy)
# requests skipped in waiting flow due async deps or constraints.
self.skipped_waiting = create_request_queue(self.policy)
self.running: list[Request] = []
# The request IDs that are finished in between the previous and the
@@ -531,52 +537,29 @@ class Scheduler(SchedulerInterface):
# Next, schedule the WAITING requests.
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
step_skipped_waiting = create_request_queue(self.policy)
while self.waiting and token_budget > 0:
while (self.waiting or self.skipped_waiting) and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
request = self.waiting.peek_request()
request_queue = self._select_waiting_queue_for_scheduling()
assert request_queue is not None
request = request_queue.peek_request()
request_id = request.request_id
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
if request.num_preemptions:
# We must be loading for a resumed preemption
# rather than a new request.
request.status = RequestStatus.PREEMPTED
else:
request.status = RequestStatus.WAITING
else:
# try to promote blocked statuses while traversing skipped queue.
if self._is_blocked_waiting_status(
request.status
) and not self._try_promote_blocked_waiting_request(request):
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request_id,
)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Streaming: skip request if still waiting for next streaming req.
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
assert not request.streaming_queue
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
request_queue.pop_request()
step_skipped_waiting.prepend_request(request)
continue
# Check that adding the request still respects the max_loras
@@ -590,8 +573,8 @@ class Scheduler(SchedulerInterface):
)
):
# Scheduling would exceed max_loras, skip.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
request_queue.pop_request()
step_skipped_waiting.prepend_request(request)
continue
num_external_computed_tokens = 0
@@ -617,8 +600,8 @@ class Scheduler(SchedulerInterface):
# The request cannot be scheduled because
# the KVConnector couldn't determine
# the number of matched tokens.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
request_queue.pop_request()
step_skipped_waiting.prepend_request(request)
continue
request.num_external_computed_tokens = ext_tokens
@@ -761,14 +744,12 @@ class Scheduler(SchedulerInterface):
preempted=request.num_preemptions > 0,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
request = request_queue.pop_request()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
step_skipped_waiting.prepend_request(request)
# Set num_computed_tokens even though KVs are not yet loaded.
# request.num_computed_tokens will not be used anywhere until
# the request finished the KV transfer.
@@ -825,9 +806,9 @@ class Scheduler(SchedulerInterface):
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
# re-queue requests skipped in this pass ahead of older skipped items.
if step_skipped_waiting:
self.skipped_waiting.prepend_requests(step_skipped_waiting)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -1531,6 +1512,32 @@ class Scheduler(SchedulerInterface):
return engine_core_outputs
@staticmethod
def _is_blocked_waiting_status(status: RequestStatus) -> bool:
return status in (
RequestStatus.WAITING_FOR_FSM,
RequestStatus.WAITING_FOR_REMOTE_KVS,
RequestStatus.WAITING_FOR_STREAMING_REQ,
)
def _enqueue_waiting_request(self, request: Request) -> None:
if self._is_blocked_waiting_status(request.status):
self.skipped_waiting.add_request(request)
else:
self.waiting.add_request(request)
def _select_waiting_queue_for_scheduling(self) -> RequestQueue | None:
if self.policy == SchedulingPolicy.FCFS:
return self.skipped_waiting or self.waiting or None
# PRIORITY mode: compare queue heads when both queues are non-empty.
if self.waiting and self.skipped_waiting:
waiting_req = self.waiting.peek_request()
skipped_req = self.skipped_waiting.peek_request()
return self.waiting if waiting_req < skipped_req else self.skipped_waiting
return self.waiting or self.skipped_waiting or None
def _handle_stopped_request(self, request: Request) -> bool:
"""Return True if finished (can be False for resumable requests)."""
if not request.resumable:
@@ -1546,7 +1553,7 @@ class Scheduler(SchedulerInterface):
request.status = RequestStatus.WAITING_FOR_STREAMING_REQ
self.num_waiting_for_streaming_input += 1
self.waiting.add_request(request)
self._enqueue_waiting_request(request)
return False
def _get_routed_experts(self, request: Request) -> np.ndarray | None:
@@ -1677,7 +1684,7 @@ class Scheduler(SchedulerInterface):
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
return len(self.running), len(self.waiting)
return len(self.running), len(self.waiting) + len(self.skipped_waiting)
def add_request(self, request: Request) -> None:
existing = self.requests.get(request.request_id)
@@ -1696,7 +1703,7 @@ class Scheduler(SchedulerInterface):
else:
if request.resumable:
request.streaming_queue = deque()
self.waiting.add_request(request)
self._enqueue_waiting_request(request)
self.requests[request.request_id] = request
if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED)
@@ -1747,6 +1754,7 @@ class Scheduler(SchedulerInterface):
self.running = remove_all(self.running, running_requests_to_remove)
if waiting_requests_to_remove:
self.waiting.remove_requests(waiting_requests_to_remove)
self.skipped_waiting.remove_requests(waiting_requests_to_remove)
# Second pass: set status and free requests
for request in valid_requests:
@@ -1798,7 +1806,11 @@ class Scheduler(SchedulerInterface):
return 0
if self._pause_state == PauseState.PAUSED_NEW:
return len(self.running)
num_waiting = len(self.waiting) - self.num_waiting_for_streaming_input
num_waiting = (
len(self.waiting)
+ len(self.skipped_waiting)
- self.num_waiting_for_streaming_input
)
return num_waiting + len(self.running)
def has_finished_requests(self) -> bool:
@@ -1898,7 +1910,7 @@ class Scheduler(SchedulerInterface):
)
return SchedulerStats(
num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting),
num_waiting_reqs=len(self.waiting) + len(self.skipped_waiting),
kv_cache_usage=self.kv_cache_manager.usage,
encoder_cache_usage=self._get_encoder_cache_usage(),
prefix_cache_stats=prefix_cache_stats,
@@ -1981,21 +1993,15 @@ class Scheduler(SchedulerInterface):
return self.connector.request_finished_all_groups(request, block_ids)
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
def _update_waiting_for_remote_kv(self, request: Request) -> None:
"""
KV Connector: check if the request_id is finished_recving.
The finished_recving_kv_req_ids list is populated
on the previous steps()'s update_from_output based
on the worker side connector.
KV Connector: update request state after async recv is finished.
When the kv transfer is ready, we cache the blocks
and the request state will be moved back to WAITING from
WAITING_FOR_REMOTE_KV.
"""
assert self.connector is not None
if request.request_id not in self.finished_recving_kv_req_ids:
return False
if request.request_id in self.failed_recving_kv_req_ids:
# Request had KV load failures; num_computed_tokens was already
@@ -2023,9 +2029,40 @@ class Scheduler(SchedulerInterface):
if request.num_cached_tokens < 0:
request.num_cached_tokens = request.num_computed_tokens
# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
return True
def _try_promote_blocked_waiting_request(self, request: Request) -> bool:
"""
Try to promote a blocked waiting request back to schedulable states.
"""
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
# finished_recving_kv_req_ids is populated during
# update_from_output(), based on worker-side connector signals
# in KVConnectorOutput.finished_recving
if request.request_id not in self.finished_recving_kv_req_ids:
return False
self._update_waiting_for_remote_kv(request)
if request.num_preemptions:
request.status = RequestStatus.PREEMPTED
else:
request.status = RequestStatus.WAITING
return True
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if not (structured_output_req and structured_output_req.grammar):
return False
request.status = RequestStatus.WAITING
return True
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
assert not request.streaming_queue
return False
raise AssertionError(
"Unexpected blocked waiting status in promotion: "
f"{request.status.name} for request {request.request_id}"
)
def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
"""
@@ -2172,7 +2209,7 @@ class Scheduler(SchedulerInterface):
# handle async KV loads (not cached yet, evict_blocks=False)
async_load_reqs = (
req
for req in self.waiting
for req in self.skipped_waiting
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
)
async_failed_req_ids, num_failed_tokens, _ = (