[Perf] Optimize scheduler overhead for PD disaggregation, around 5% E2E perf improvement (#35781)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -1115,12 +1115,16 @@ def _step_until_done(
|
||||
all_finished = all_done
|
||||
|
||||
|
||||
def _num_waiting_requests(scheduler: Scheduler) -> int:
|
||||
return len(scheduler.waiting) + len(scheduler.skipped_waiting)
|
||||
|
||||
|
||||
def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
|
||||
"""Cycle requests through a KV transfer cycle."""
|
||||
|
||||
# Requests should first transition to WAITING_FOR_REMOTE_KVS
|
||||
output = scheduler.schedule()
|
||||
assert len(scheduler.waiting) == len(req_ids)
|
||||
assert _num_waiting_requests(scheduler) == len(req_ids)
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(output.scheduled_new_reqs) == 0
|
||||
for req in scheduler.requests.values():
|
||||
@@ -1139,7 +1143,7 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
|
||||
|
||||
# Simulate KV transfer completion using KVConnectorOutput.finished_recving
|
||||
output = scheduler.schedule()
|
||||
assert len(scheduler.waiting) == len(req_ids)
|
||||
assert _num_waiting_requests(scheduler) == len(req_ids)
|
||||
assert len(scheduler.running) == 0
|
||||
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
@@ -1546,7 +1550,7 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
|
||||
# All can be scheduled - 1st token.
|
||||
output = scheduler.schedule()
|
||||
if is_async:
|
||||
assert len(scheduler.waiting) == 2
|
||||
assert _num_waiting_requests(scheduler) == 2
|
||||
assert scheduler.running == []
|
||||
_step_until_kv_transfer_finished(scheduler, req_ids)
|
||||
output = scheduler.schedule()
|
||||
@@ -1604,7 +1608,11 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
|
||||
# This will have a local and remote cache hit.
|
||||
output = scheduler.schedule()
|
||||
if is_async:
|
||||
waiting_req_ids = [req.request_id for req in scheduler.waiting]
|
||||
waiting_req_ids = [
|
||||
req.request_id
|
||||
for req in scheduler.skipped_waiting
|
||||
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
]
|
||||
assert len(waiting_req_ids) == 1
|
||||
_step_until_kv_transfer_finished(scheduler, waiting_req_ids)
|
||||
output = scheduler.schedule()
|
||||
@@ -2439,7 +2447,8 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
|
||||
output = scheduler.schedule()
|
||||
assert len(output.scheduled_new_reqs) == 0
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.skipped_waiting) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -3626,6 +3635,9 @@ def test_prepend_skipped_requests_order():
|
||||
# simulate first 2 waiting requests are waiting for remote KVs
|
||||
for req in expected_waiting_reqs[:2]:
|
||||
req.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
scheduler.waiting.remove_requests(expected_waiting_reqs[:2])
|
||||
for req in expected_waiting_reqs[:2]:
|
||||
scheduler.skipped_waiting.add_request(req)
|
||||
|
||||
# schedule step
|
||||
# expect the first 2 waiting to be skipped, the third running,
|
||||
@@ -3636,7 +3648,87 @@ def test_prepend_skipped_requests_order():
|
||||
expected_waiting_reqs.pop(2)
|
||||
|
||||
# verify waiting order is preserved
|
||||
assert list(scheduler.waiting) == expected_waiting_reqs
|
||||
waiting_reqs = list(scheduler.skipped_waiting) + list(scheduler.waiting)
|
||||
assert waiting_reqs == expected_waiting_reqs
|
||||
|
||||
|
||||
def test_remote_kv_promotion_keeps_fcfs_with_fsm_prefix():
|
||||
scheduler = create_scheduler(max_num_seqs=1)
|
||||
scheduler.connector = Mock()
|
||||
scheduler.connector.get_num_new_matched_tokens.return_value = (0, False)
|
||||
|
||||
requests = create_requests(num_requests=4)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
req_fsm_1, req_fsm_2, req_remote, req_tail = list(scheduler.waiting)
|
||||
|
||||
# simulate two FSM requests at the waiting head that become ready now.
|
||||
req_fsm_1.status = RequestStatus.WAITING_FOR_FSM
|
||||
req_fsm_1.structured_output_request = Mock(grammar=object())
|
||||
req_fsm_2.status = RequestStatus.WAITING_FOR_FSM
|
||||
req_fsm_2.structured_output_request = Mock(grammar=object())
|
||||
|
||||
# simulate a remote-KV request that is ready to be promoted now.
|
||||
req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
scheduler.waiting.remove_requests([req_fsm_1, req_fsm_2, req_remote])
|
||||
scheduler.skipped_waiting.add_request(req_fsm_1)
|
||||
scheduler.skipped_waiting.add_request(req_fsm_2)
|
||||
scheduler.skipped_waiting.add_request(req_remote)
|
||||
scheduler.finished_recving_kv_req_ids.add(req_remote.request_id)
|
||||
scheduler._update_waiting_for_remote_kv = Mock()
|
||||
|
||||
output = scheduler.schedule()
|
||||
|
||||
assert output.scheduled_new_reqs
|
||||
assert output.scheduled_new_reqs[0].req_id == req_fsm_1.request_id
|
||||
waiting_req_ids = [
|
||||
req.request_id
|
||||
for req in list(scheduler.skipped_waiting) + list(scheduler.waiting)
|
||||
]
|
||||
assert waiting_req_ids == [
|
||||
req_fsm_2.request_id,
|
||||
req_remote.request_id,
|
||||
req_tail.request_id,
|
||||
]
|
||||
|
||||
|
||||
def test_fcfs_mixed_skipped_waiting_types_keep_order():
|
||||
scheduler = create_scheduler(max_num_batched_tokens=20)
|
||||
scheduler._update_waiting_for_remote_kv = Mock()
|
||||
|
||||
mk_req = lambda req_id, num_tokens=1: create_requests( # noqa: E731
|
||||
num_requests=1, num_tokens=num_tokens, req_ids=[req_id]
|
||||
)[0]
|
||||
req_fsm, req_remote, req_stream = mk_req("fsm"), mk_req("remote"), mk_req("stream")
|
||||
req_regular, req_tail = mk_req("regular", 20), mk_req("tail")
|
||||
req_fsm.status = RequestStatus.WAITING_FOR_FSM
|
||||
req_fsm.structured_output_request = Mock(grammar=None)
|
||||
req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
req_stream.status = RequestStatus.WAITING_FOR_STREAMING_REQ
|
||||
|
||||
for req in (req_fsm, req_remote, req_stream, req_regular, req_tail):
|
||||
scheduler.add_request(req)
|
||||
scheduler.schedule()
|
||||
assert list(scheduler.skipped_waiting) == [req_fsm, req_remote, req_stream]
|
||||
|
||||
scheduler.finish_requests(req_regular.request_id, RequestStatus.FINISHED_ABORTED)
|
||||
assert not scheduler.running
|
||||
|
||||
req_fsm.structured_output_request = Mock(grammar=object())
|
||||
scheduler.finished_recving_kv_req_ids.add(req_remote.request_id)
|
||||
req_stream.status = RequestStatus.WAITING
|
||||
|
||||
second_output = scheduler.schedule()
|
||||
expected_order = [
|
||||
req_fsm.request_id,
|
||||
req_remote.request_id,
|
||||
req_stream.request_id,
|
||||
req_tail.request_id,
|
||||
]
|
||||
assert [req.req_id for req in second_output.scheduled_new_reqs] == expected_order
|
||||
assert [req.request_id for req in scheduler.running] == expected_order
|
||||
scheduler._update_waiting_for_remote_kv.assert_called_once_with(req_remote)
|
||||
|
||||
|
||||
def test_abort_request_waiting_for_remote_kvs():
|
||||
|
||||
@@ -119,7 +119,7 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler):
|
||||
|
||||
scheduler_output = fail_scheduler.schedule()
|
||||
|
||||
assert len(fail_scheduler.waiting) == 1
|
||||
assert len(fail_scheduler.skipped_waiting) == 1
|
||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
assert request.num_computed_tokens == num_external_computed_tokens
|
||||
|
||||
@@ -145,3 +145,4 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler):
|
||||
assert output.finish_reason == FinishReason.ERROR
|
||||
|
||||
assert len(fail_scheduler.waiting) == 0
|
||||
assert len(fail_scheduler.skipped_waiting) == 0
|
||||
|
||||
@@ -337,7 +337,7 @@ def test_async_recompute_blocks_not_cached_when_invalid(
|
||||
scheduler_output = recompute_scheduler.schedule()
|
||||
|
||||
# request should be waiting for remote KVs
|
||||
assert len(recompute_scheduler.waiting) == 1
|
||||
assert len(recompute_scheduler.skipped_waiting) == 1
|
||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
assert request.num_computed_tokens == num_external_computed_tokens
|
||||
|
||||
|
||||
@@ -76,8 +76,9 @@ def test_async_load_failure(
|
||||
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
assert len(scheduler.waiting) == 3
|
||||
for request in scheduler.waiting:
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.skipped_waiting) == 3
|
||||
for request in scheduler.skipped_waiting:
|
||||
assert request.num_computed_tokens == num_external_computed_tokens
|
||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
|
||||
@@ -96,8 +97,9 @@ def test_async_load_failure(
|
||||
|
||||
min_invalid_block_idx = min(invalid_block_idxs)
|
||||
|
||||
assert len(scheduler.waiting) == 3
|
||||
for request in scheduler.waiting:
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.skipped_waiting) == 3
|
||||
for request in scheduler.skipped_waiting:
|
||||
if request.request_id == request2.request_id:
|
||||
assert request.num_computed_tokens == (
|
||||
min_invalid_block_idx * scheduler.block_size
|
||||
@@ -303,8 +305,9 @@ def test_async_progressive_load_failure(
|
||||
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert scheduler.waiting.peek_request().request_id == request.request_id
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.skipped_waiting) == 1
|
||||
assert scheduler.skipped_waiting.peek_request().request_id == request.request_id
|
||||
assert request.num_computed_tokens == num_external_computed_tokens
|
||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
|
||||
@@ -325,8 +328,9 @@ def test_async_progressive_load_failure(
|
||||
|
||||
min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)
|
||||
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert scheduler.waiting.peek_request().request_id == request.request_id
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.skipped_waiting) == 1
|
||||
assert scheduler.skipped_waiting.peek_request().request_id == request.request_id
|
||||
assert request.num_computed_tokens == (
|
||||
min_invalid_block_idx * scheduler.block_size
|
||||
)
|
||||
|
||||
@@ -18,6 +18,10 @@ from .utils import (
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def _num_waiting_requests(scheduler) -> int:
|
||||
return len(scheduler.waiting) + len(scheduler.skipped_waiting)
|
||||
|
||||
|
||||
def test_basic_lifecycle():
|
||||
"""Test lifecycle of a remote prefill."""
|
||||
|
||||
@@ -54,8 +58,8 @@ def test_basic_lifecycle():
|
||||
assert scheduler_output.total_num_scheduled_tokens == 0
|
||||
|
||||
# Req waiting for KVs with no computed/scheduled toks ...
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert request in scheduler.waiting
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
assert request in scheduler.skipped_waiting
|
||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
assert request.num_computed_tokens == NUM_TOKENS
|
||||
|
||||
@@ -81,7 +85,7 @@ def test_basic_lifecycle():
|
||||
# STEP (2):
|
||||
# (2a): schedule(): nothing happens!
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
assert len(scheduler.running) == 0
|
||||
|
||||
# (2b): forward(): request finishes recv.
|
||||
@@ -94,7 +98,7 @@ def test_basic_lifecycle():
|
||||
engine_core_outputs = scheduler.update_from_output(
|
||||
scheduler_output, model_runner_output
|
||||
)
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
assert request_id in scheduler.finished_recving_kv_req_ids
|
||||
|
||||
# STEP (3):
|
||||
@@ -180,7 +184,7 @@ def test_interleaved_lifecycle():
|
||||
scheduler.add_request(request_remote)
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 2
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 1
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
|
||||
|
||||
@@ -190,7 +194,7 @@ def test_interleaved_lifecycle():
|
||||
# STEP 3: continue running, KVs not arrived yet.
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 2
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
||||
|
||||
@@ -199,14 +203,14 @@ def test_interleaved_lifecycle():
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 2
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
||||
|
||||
# STEP 4: KVs arrive.
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 2
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
||||
|
||||
@@ -218,7 +222,7 @@ def test_interleaved_lifecycle():
|
||||
# STEP 5: RECVed KVs are sent to ModelRunner.
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 3
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert _num_waiting_requests(scheduler) == 0
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 1
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
||||
|
||||
@@ -279,14 +283,14 @@ def test_no_spurious_prefix_caching():
|
||||
scheduler.add_request(request_remote)
|
||||
scheduler_output = scheduler.schedule()
|
||||
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
|
||||
# Schedule the local prefill request. This should
|
||||
# cause blocks to be cached, but separately from
|
||||
scheduler.add_request(request_local)
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
|
||||
local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0
|
||||
@@ -348,7 +352,7 @@ def test_full_block_prompt():
|
||||
finished_recving={request_id}
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
assert request_id in scheduler.finished_recving_kv_req_ids
|
||||
|
||||
# # STEP (3): Run as usual.
|
||||
@@ -418,7 +422,7 @@ def test_cannot_schedule_after_recv():
|
||||
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert _num_waiting_requests(scheduler) == 0
|
||||
|
||||
# Step 2: 5 blocks are in use (2 new for remote blocks).
|
||||
scheduler.add_request(request_remote)
|
||||
@@ -426,7 +430,7 @@ def test_cannot_schedule_after_recv():
|
||||
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
|
||||
# Step 3: finish recving (5 blocks in use)
|
||||
scheduler_output = scheduler.schedule()
|
||||
@@ -435,7 +439,7 @@ def test_cannot_schedule_after_recv():
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
|
||||
# Step 4: try to schedule, remote request is put to running list
|
||||
# because the transfer is completed.
|
||||
@@ -445,7 +449,7 @@ def test_cannot_schedule_after_recv():
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 2
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert _num_waiting_requests(scheduler) == 0
|
||||
|
||||
# Step 5: Remote request will be put back to waiting list
|
||||
# because it needs new block to hold generated token.
|
||||
@@ -453,7 +457,7 @@ def test_cannot_schedule_after_recv():
|
||||
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
|
||||
# Step 6: finish the request, free it.
|
||||
scheduler_output = scheduler.schedule()
|
||||
@@ -462,7 +466,7 @@ def test_cannot_schedule_after_recv():
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
|
||||
# Step 7: now we can schedule (with 2 blocks computed),
|
||||
# request is retrieved from preempted list.
|
||||
@@ -474,7 +478,7 @@ def test_cannot_schedule_after_recv():
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert _num_waiting_requests(scheduler) == 0
|
||||
|
||||
# Step 8: free everything.
|
||||
scheduler_output = scheduler.schedule()
|
||||
@@ -521,7 +525,7 @@ def test_cannot_recv():
|
||||
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert _num_waiting_requests(scheduler) == 0
|
||||
|
||||
# Step 2: 3 blocks are in use,
|
||||
# need 3 new for remote blocks but only 2 are available.
|
||||
@@ -530,7 +534,7 @@ def test_cannot_recv():
|
||||
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
# Should not have KV transfer in progress.
|
||||
assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
|
||||
@@ -541,14 +545,14 @@ def test_cannot_recv():
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
|
||||
# Step 4: now we can initiate KV transfer (with 2 blocks computed).
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output(reqs=[])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
|
||||
# Step 5: finish recving (5 blocks in use)
|
||||
@@ -558,14 +562,14 @@ def test_cannot_recv():
|
||||
)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert _num_waiting_requests(scheduler) == 1
|
||||
|
||||
# Step 6: schedule remote request
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output(reqs=[request_remote])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert _num_waiting_requests(scheduler) == 0
|
||||
|
||||
# Step 7: free everything.
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
@@ -45,7 +45,11 @@ from vllm.v1.core.sched.output import (
|
||||
NewRequestData,
|
||||
SchedulerOutput,
|
||||
)
|
||||
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
||||
from vllm.v1.core.sched.request_queue import (
|
||||
RequestQueue,
|
||||
SchedulingPolicy,
|
||||
create_request_queue,
|
||||
)
|
||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
@@ -160,6 +164,8 @@ class Scheduler(SchedulerInterface):
|
||||
) from e
|
||||
# Priority queues for requests.
|
||||
self.waiting = create_request_queue(self.policy)
|
||||
# requests skipped in waiting flow due async deps or constraints.
|
||||
self.skipped_waiting = create_request_queue(self.policy)
|
||||
self.running: list[Request] = []
|
||||
|
||||
# The request IDs that are finished in between the previous and the
|
||||
@@ -531,52 +537,29 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
# Next, schedule the WAITING requests.
|
||||
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
|
||||
# Use a temporary RequestQueue to collect requests that need to be
|
||||
# skipped and put back at the head of the waiting queue later
|
||||
skipped_waiting_requests = create_request_queue(self.policy)
|
||||
step_skipped_waiting = create_request_queue(self.policy)
|
||||
|
||||
while self.waiting and token_budget > 0:
|
||||
while (self.waiting or self.skipped_waiting) and token_budget > 0:
|
||||
if len(self.running) == self.max_num_running_reqs:
|
||||
break
|
||||
|
||||
request = self.waiting.peek_request()
|
||||
request_queue = self._select_waiting_queue_for_scheduling()
|
||||
assert request_queue is not None
|
||||
|
||||
request = request_queue.peek_request()
|
||||
request_id = request.request_id
|
||||
|
||||
# KVTransfer: skip request if still waiting for remote kvs.
|
||||
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||
is_ready = self._update_waiting_for_remote_kv(request)
|
||||
if is_ready:
|
||||
if request.num_preemptions:
|
||||
# We must be loading for a resumed preemption
|
||||
# rather than a new request.
|
||||
request.status = RequestStatus.PREEMPTED
|
||||
else:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
# try to promote blocked statuses while traversing skipped queue.
|
||||
if self._is_blocked_waiting_status(
|
||||
request.status
|
||||
) and not self._try_promote_blocked_waiting_request(request):
|
||||
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||
logger.debug(
|
||||
"%s is still in WAITING_FOR_REMOTE_KVS state.",
|
||||
request_id,
|
||||
)
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
# Skip request if the structured output request is still waiting
|
||||
# for FSM compilation.
|
||||
if request.status == RequestStatus.WAITING_FOR_FSM:
|
||||
structured_output_req = request.structured_output_request
|
||||
if structured_output_req and structured_output_req.grammar:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
# Streaming: skip request if still waiting for next streaming req.
|
||||
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
|
||||
assert not request.streaming_queue
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
request_queue.pop_request()
|
||||
step_skipped_waiting.prepend_request(request)
|
||||
continue
|
||||
|
||||
# Check that adding the request still respects the max_loras
|
||||
@@ -590,8 +573,8 @@ class Scheduler(SchedulerInterface):
|
||||
)
|
||||
):
|
||||
# Scheduling would exceed max_loras, skip.
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
request_queue.pop_request()
|
||||
step_skipped_waiting.prepend_request(request)
|
||||
continue
|
||||
|
||||
num_external_computed_tokens = 0
|
||||
@@ -617,8 +600,8 @@ class Scheduler(SchedulerInterface):
|
||||
# The request cannot be scheduled because
|
||||
# the KVConnector couldn't determine
|
||||
# the number of matched tokens.
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
request_queue.pop_request()
|
||||
step_skipped_waiting.prepend_request(request)
|
||||
continue
|
||||
|
||||
request.num_external_computed_tokens = ext_tokens
|
||||
@@ -761,14 +744,12 @@ class Scheduler(SchedulerInterface):
|
||||
preempted=request.num_preemptions > 0,
|
||||
)
|
||||
|
||||
# Request was already popped from self.waiting
|
||||
# unless it was re-added above due to new_blocks being None.
|
||||
request = self.waiting.pop_request()
|
||||
request = request_queue.pop_request()
|
||||
if load_kv_async:
|
||||
# If loading async, allocate memory and put request
|
||||
# into the WAITING_FOR_REMOTE_KV state.
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
step_skipped_waiting.prepend_request(request)
|
||||
# Set num_computed_tokens even though KVs are not yet loaded.
|
||||
# request.num_computed_tokens will not be used anywhere until
|
||||
# the request finished the KV transfer.
|
||||
@@ -825,9 +806,9 @@ class Scheduler(SchedulerInterface):
|
||||
if self.ec_connector is not None:
|
||||
self.ec_connector.update_state_after_alloc(request, i)
|
||||
|
||||
# Put back any skipped requests at the head of the waiting queue
|
||||
if skipped_waiting_requests:
|
||||
self.waiting.prepend_requests(skipped_waiting_requests)
|
||||
# re-queue requests skipped in this pass ahead of older skipped items.
|
||||
if step_skipped_waiting:
|
||||
self.skipped_waiting.prepend_requests(step_skipped_waiting)
|
||||
|
||||
# Check if the scheduling constraints are satisfied.
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
||||
@@ -1531,6 +1512,32 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
return engine_core_outputs
|
||||
|
||||
@staticmethod
|
||||
def _is_blocked_waiting_status(status: RequestStatus) -> bool:
|
||||
return status in (
|
||||
RequestStatus.WAITING_FOR_FSM,
|
||||
RequestStatus.WAITING_FOR_REMOTE_KVS,
|
||||
RequestStatus.WAITING_FOR_STREAMING_REQ,
|
||||
)
|
||||
|
||||
def _enqueue_waiting_request(self, request: Request) -> None:
|
||||
if self._is_blocked_waiting_status(request.status):
|
||||
self.skipped_waiting.add_request(request)
|
||||
else:
|
||||
self.waiting.add_request(request)
|
||||
|
||||
def _select_waiting_queue_for_scheduling(self) -> RequestQueue | None:
|
||||
if self.policy == SchedulingPolicy.FCFS:
|
||||
return self.skipped_waiting or self.waiting or None
|
||||
|
||||
# PRIORITY mode: compare queue heads when both queues are non-empty.
|
||||
if self.waiting and self.skipped_waiting:
|
||||
waiting_req = self.waiting.peek_request()
|
||||
skipped_req = self.skipped_waiting.peek_request()
|
||||
return self.waiting if waiting_req < skipped_req else self.skipped_waiting
|
||||
|
||||
return self.waiting or self.skipped_waiting or None
|
||||
|
||||
def _handle_stopped_request(self, request: Request) -> bool:
|
||||
"""Return True if finished (can be False for resumable requests)."""
|
||||
if not request.resumable:
|
||||
@@ -1546,7 +1553,7 @@ class Scheduler(SchedulerInterface):
|
||||
request.status = RequestStatus.WAITING_FOR_STREAMING_REQ
|
||||
self.num_waiting_for_streaming_input += 1
|
||||
|
||||
self.waiting.add_request(request)
|
||||
self._enqueue_waiting_request(request)
|
||||
return False
|
||||
|
||||
def _get_routed_experts(self, request: Request) -> np.ndarray | None:
|
||||
@@ -1677,7 +1684,7 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
def get_request_counts(self) -> tuple[int, int]:
|
||||
"""Returns (num_running_reqs, num_waiting_reqs)."""
|
||||
return len(self.running), len(self.waiting)
|
||||
return len(self.running), len(self.waiting) + len(self.skipped_waiting)
|
||||
|
||||
def add_request(self, request: Request) -> None:
|
||||
existing = self.requests.get(request.request_id)
|
||||
@@ -1696,7 +1703,7 @@ class Scheduler(SchedulerInterface):
|
||||
else:
|
||||
if request.resumable:
|
||||
request.streaming_queue = deque()
|
||||
self.waiting.add_request(request)
|
||||
self._enqueue_waiting_request(request)
|
||||
self.requests[request.request_id] = request
|
||||
if self.log_stats:
|
||||
request.record_event(EngineCoreEventType.QUEUED)
|
||||
@@ -1747,6 +1754,7 @@ class Scheduler(SchedulerInterface):
|
||||
self.running = remove_all(self.running, running_requests_to_remove)
|
||||
if waiting_requests_to_remove:
|
||||
self.waiting.remove_requests(waiting_requests_to_remove)
|
||||
self.skipped_waiting.remove_requests(waiting_requests_to_remove)
|
||||
|
||||
# Second pass: set status and free requests
|
||||
for request in valid_requests:
|
||||
@@ -1798,7 +1806,11 @@ class Scheduler(SchedulerInterface):
|
||||
return 0
|
||||
if self._pause_state == PauseState.PAUSED_NEW:
|
||||
return len(self.running)
|
||||
num_waiting = len(self.waiting) - self.num_waiting_for_streaming_input
|
||||
num_waiting = (
|
||||
len(self.waiting)
|
||||
+ len(self.skipped_waiting)
|
||||
- self.num_waiting_for_streaming_input
|
||||
)
|
||||
return num_waiting + len(self.running)
|
||||
|
||||
def has_finished_requests(self) -> bool:
|
||||
@@ -1898,7 +1910,7 @@ class Scheduler(SchedulerInterface):
|
||||
)
|
||||
return SchedulerStats(
|
||||
num_running_reqs=len(self.running),
|
||||
num_waiting_reqs=len(self.waiting),
|
||||
num_waiting_reqs=len(self.waiting) + len(self.skipped_waiting),
|
||||
kv_cache_usage=self.kv_cache_manager.usage,
|
||||
encoder_cache_usage=self._get_encoder_cache_usage(),
|
||||
prefix_cache_stats=prefix_cache_stats,
|
||||
@@ -1981,21 +1993,15 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
return self.connector.request_finished_all_groups(request, block_ids)
|
||||
|
||||
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
|
||||
def _update_waiting_for_remote_kv(self, request: Request) -> None:
|
||||
"""
|
||||
KV Connector: check if the request_id is finished_recving.
|
||||
|
||||
The finished_recving_kv_req_ids list is populated
|
||||
on the previous steps()'s update_from_output based
|
||||
on the worker side connector.
|
||||
KV Connector: update request state after async recv is finished.
|
||||
|
||||
When the kv transfer is ready, we cache the blocks
|
||||
and the request state will be moved back to WAITING from
|
||||
WAITING_FOR_REMOTE_KV.
|
||||
"""
|
||||
assert self.connector is not None
|
||||
if request.request_id not in self.finished_recving_kv_req_ids:
|
||||
return False
|
||||
|
||||
if request.request_id in self.failed_recving_kv_req_ids:
|
||||
# Request had KV load failures; num_computed_tokens was already
|
||||
@@ -2023,9 +2029,40 @@ class Scheduler(SchedulerInterface):
|
||||
if request.num_cached_tokens < 0:
|
||||
request.num_cached_tokens = request.num_computed_tokens
|
||||
|
||||
# Return that we are ready.
|
||||
self.finished_recving_kv_req_ids.remove(request.request_id)
|
||||
return True
|
||||
|
||||
def _try_promote_blocked_waiting_request(self, request: Request) -> bool:
|
||||
"""
|
||||
Try to promote a blocked waiting request back to schedulable states.
|
||||
"""
|
||||
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||
# finished_recving_kv_req_ids is populated during
|
||||
# update_from_output(), based on worker-side connector signals
|
||||
# in KVConnectorOutput.finished_recving
|
||||
if request.request_id not in self.finished_recving_kv_req_ids:
|
||||
return False
|
||||
self._update_waiting_for_remote_kv(request)
|
||||
if request.num_preemptions:
|
||||
request.status = RequestStatus.PREEMPTED
|
||||
else:
|
||||
request.status = RequestStatus.WAITING
|
||||
return True
|
||||
|
||||
if request.status == RequestStatus.WAITING_FOR_FSM:
|
||||
structured_output_req = request.structured_output_request
|
||||
if not (structured_output_req and structured_output_req.grammar):
|
||||
return False
|
||||
request.status = RequestStatus.WAITING
|
||||
return True
|
||||
|
||||
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
|
||||
assert not request.streaming_queue
|
||||
return False
|
||||
|
||||
raise AssertionError(
|
||||
"Unexpected blocked waiting status in promotion: "
|
||||
f"{request.status.name} for request {request.request_id}"
|
||||
)
|
||||
|
||||
def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
|
||||
"""
|
||||
@@ -2172,7 +2209,7 @@ class Scheduler(SchedulerInterface):
|
||||
# handle async KV loads (not cached yet, evict_blocks=False)
|
||||
async_load_reqs = (
|
||||
req
|
||||
for req in self.waiting
|
||||
for req in self.skipped_waiting
|
||||
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
)
|
||||
async_failed_req_ids, num_failed_tokens, _ = (
|
||||
|
||||
Reference in New Issue
Block a user