[Perf] Optimize scheduler overhead for PD disaggregation, around 5% E2E perf improvement (#35781)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Wentao Ye
2026-03-11 00:25:30 -04:00
committed by GitHub
parent 42fadebecb
commit a8ff2cca92
6 changed files with 243 additions and 105 deletions

View File

@@ -1115,12 +1115,16 @@ def _step_until_done(
all_finished = all_done
def _num_waiting_requests(scheduler: Scheduler) -> int:
return len(scheduler.waiting) + len(scheduler.skipped_waiting)
def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
"""Cycle requests through a KV transfer cycle."""
# Requests should first transition to WAITING_FOR_REMOTE_KVS
output = scheduler.schedule()
assert len(scheduler.waiting) == len(req_ids)
assert _num_waiting_requests(scheduler) == len(req_ids)
assert len(scheduler.running) == 0
assert len(output.scheduled_new_reqs) == 0
for req in scheduler.requests.values():
@@ -1139,7 +1143,7 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
# Simulate KV transfer completion using KVConnectorOutput.finished_recving
output = scheduler.schedule()
assert len(scheduler.waiting) == len(req_ids)
assert _num_waiting_requests(scheduler) == len(req_ids)
assert len(scheduler.running) == 0
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
@@ -1546,7 +1550,7 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
# All can be scheduled - 1st token.
output = scheduler.schedule()
if is_async:
assert len(scheduler.waiting) == 2
assert _num_waiting_requests(scheduler) == 2
assert scheduler.running == []
_step_until_kv_transfer_finished(scheduler, req_ids)
output = scheduler.schedule()
@@ -1604,7 +1608,11 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
# This will have a local and remote cache hit.
output = scheduler.schedule()
if is_async:
waiting_req_ids = [req.request_id for req in scheduler.waiting]
waiting_req_ids = [
req.request_id
for req in scheduler.skipped_waiting
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
]
assert len(waiting_req_ids) == 1
_step_until_kv_transfer_finished(scheduler, waiting_req_ids)
output = scheduler.schedule()
@@ -2439,7 +2447,8 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1
assert len(scheduler.waiting) == 0
assert len(scheduler.skipped_waiting) == 1
@pytest.mark.parametrize(
@@ -3626,6 +3635,9 @@ def test_prepend_skipped_requests_order():
# simulate first 2 waiting requests are waiting for remote KVs
for req in expected_waiting_reqs[:2]:
req.status = RequestStatus.WAITING_FOR_REMOTE_KVS
scheduler.waiting.remove_requests(expected_waiting_reqs[:2])
for req in expected_waiting_reqs[:2]:
scheduler.skipped_waiting.add_request(req)
# schedule step
# expect the first 2 waiting to be skipped, the third running,
@@ -3636,7 +3648,87 @@ def test_prepend_skipped_requests_order():
expected_waiting_reqs.pop(2)
# verify waiting order is preserved
assert list(scheduler.waiting) == expected_waiting_reqs
waiting_reqs = list(scheduler.skipped_waiting) + list(scheduler.waiting)
assert waiting_reqs == expected_waiting_reqs
def test_remote_kv_promotion_keeps_fcfs_with_fsm_prefix():
scheduler = create_scheduler(max_num_seqs=1)
scheduler.connector = Mock()
scheduler.connector.get_num_new_matched_tokens.return_value = (0, False)
requests = create_requests(num_requests=4)
for request in requests:
scheduler.add_request(request)
req_fsm_1, req_fsm_2, req_remote, req_tail = list(scheduler.waiting)
# simulate two FSM requests at the waiting head that become ready now.
req_fsm_1.status = RequestStatus.WAITING_FOR_FSM
req_fsm_1.structured_output_request = Mock(grammar=object())
req_fsm_2.status = RequestStatus.WAITING_FOR_FSM
req_fsm_2.structured_output_request = Mock(grammar=object())
# simulate a remote-KV request that is ready to be promoted now.
req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS
scheduler.waiting.remove_requests([req_fsm_1, req_fsm_2, req_remote])
scheduler.skipped_waiting.add_request(req_fsm_1)
scheduler.skipped_waiting.add_request(req_fsm_2)
scheduler.skipped_waiting.add_request(req_remote)
scheduler.finished_recving_kv_req_ids.add(req_remote.request_id)
scheduler._update_waiting_for_remote_kv = Mock()
output = scheduler.schedule()
assert output.scheduled_new_reqs
assert output.scheduled_new_reqs[0].req_id == req_fsm_1.request_id
waiting_req_ids = [
req.request_id
for req in list(scheduler.skipped_waiting) + list(scheduler.waiting)
]
assert waiting_req_ids == [
req_fsm_2.request_id,
req_remote.request_id,
req_tail.request_id,
]
def test_fcfs_mixed_skipped_waiting_types_keep_order():
scheduler = create_scheduler(max_num_batched_tokens=20)
scheduler._update_waiting_for_remote_kv = Mock()
mk_req = lambda req_id, num_tokens=1: create_requests( # noqa: E731
num_requests=1, num_tokens=num_tokens, req_ids=[req_id]
)[0]
req_fsm, req_remote, req_stream = mk_req("fsm"), mk_req("remote"), mk_req("stream")
req_regular, req_tail = mk_req("regular", 20), mk_req("tail")
req_fsm.status = RequestStatus.WAITING_FOR_FSM
req_fsm.structured_output_request = Mock(grammar=None)
req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS
req_stream.status = RequestStatus.WAITING_FOR_STREAMING_REQ
for req in (req_fsm, req_remote, req_stream, req_regular, req_tail):
scheduler.add_request(req)
scheduler.schedule()
assert list(scheduler.skipped_waiting) == [req_fsm, req_remote, req_stream]
scheduler.finish_requests(req_regular.request_id, RequestStatus.FINISHED_ABORTED)
assert not scheduler.running
req_fsm.structured_output_request = Mock(grammar=object())
scheduler.finished_recving_kv_req_ids.add(req_remote.request_id)
req_stream.status = RequestStatus.WAITING
second_output = scheduler.schedule()
expected_order = [
req_fsm.request_id,
req_remote.request_id,
req_stream.request_id,
req_tail.request_id,
]
assert [req.req_id for req in second_output.scheduled_new_reqs] == expected_order
assert [req.request_id for req in scheduler.running] == expected_order
scheduler._update_waiting_for_remote_kv.assert_called_once_with(req_remote)
def test_abort_request_waiting_for_remote_kvs():