[Perf] Optimize scheduler overhead for PD disaggregation, around 5% E2E perf improvement (#35781)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -1115,12 +1115,16 @@ def _step_until_done(
|
||||
all_finished = all_done
|
||||
|
||||
|
||||
def _num_waiting_requests(scheduler: Scheduler) -> int:
    """Return how many requests the scheduler holds in its two waiting queues.

    The count is the combined length of ``scheduler.waiting`` and
    ``scheduler.skipped_waiting``.
    """
    queues = (scheduler.waiting, scheduler.skipped_waiting)
    return sum(len(queue) for queue in queues)
|
||||
|
||||
|
||||
def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
|
||||
"""Cycle requests through a KV transfer cycle."""
|
||||
|
||||
# Requests should first transition to WAITING_FOR_REMOTE_KVS
|
||||
output = scheduler.schedule()
|
||||
assert len(scheduler.waiting) == len(req_ids)
|
||||
assert _num_waiting_requests(scheduler) == len(req_ids)
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(output.scheduled_new_reqs) == 0
|
||||
for req in scheduler.requests.values():
|
||||
@@ -1139,7 +1143,7 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
|
||||
|
||||
# Simulate KV transfer completion using KVConnectorOutput.finished_recving
|
||||
output = scheduler.schedule()
|
||||
assert len(scheduler.waiting) == len(req_ids)
|
||||
assert _num_waiting_requests(scheduler) == len(req_ids)
|
||||
assert len(scheduler.running) == 0
|
||||
|
||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||
@@ -1546,7 +1550,7 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
|
||||
# All can be scheduled - 1st token.
|
||||
output = scheduler.schedule()
|
||||
if is_async:
|
||||
assert len(scheduler.waiting) == 2
|
||||
assert _num_waiting_requests(scheduler) == 2
|
||||
assert scheduler.running == []
|
||||
_step_until_kv_transfer_finished(scheduler, req_ids)
|
||||
output = scheduler.schedule()
|
||||
@@ -1604,7 +1608,11 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
|
||||
# This will have a local and remote cache hit.
|
||||
output = scheduler.schedule()
|
||||
if is_async:
|
||||
waiting_req_ids = [req.request_id for req in scheduler.waiting]
|
||||
waiting_req_ids = [
|
||||
req.request_id
|
||||
for req in scheduler.skipped_waiting
|
||||
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
]
|
||||
assert len(waiting_req_ids) == 1
|
||||
_step_until_kv_transfer_finished(scheduler, waiting_req_ids)
|
||||
output = scheduler.schedule()
|
||||
@@ -2439,7 +2447,8 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
|
||||
output = scheduler.schedule()
|
||||
assert len(output.scheduled_new_reqs) == 0
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.skipped_waiting) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -3626,6 +3635,9 @@ def test_prepend_skipped_requests_order():
|
||||
# simulate first 2 waiting requests are waiting for remote KVs
|
||||
for req in expected_waiting_reqs[:2]:
|
||||
req.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
scheduler.waiting.remove_requests(expected_waiting_reqs[:2])
|
||||
for req in expected_waiting_reqs[:2]:
|
||||
scheduler.skipped_waiting.add_request(req)
|
||||
|
||||
# schedule step
|
||||
# expect the first 2 waiting to be skipped, the third running,
|
||||
@@ -3636,7 +3648,87 @@ def test_prepend_skipped_requests_order():
|
||||
expected_waiting_reqs.pop(2)
|
||||
|
||||
# verify waiting order is preserved
|
||||
assert list(scheduler.waiting) == expected_waiting_reqs
|
||||
waiting_reqs = list(scheduler.skipped_waiting) + list(scheduler.waiting)
|
||||
assert waiting_reqs == expected_waiting_reqs
|
||||
|
||||
|
||||
def test_remote_kv_promotion_keeps_fcfs_with_fsm_prefix():
    """A ready remote-KV request must not overtake FSM requests ahead of it.

    Two FSM requests and one remote-KV request are parked in
    ``skipped_waiting`` (in that order) and a fourth request stays in
    ``waiting``. After one ``schedule()`` call the first FSM request must be
    the first newly scheduled request, and the remaining three must still be
    queued in their original relative order.
    """
    scheduler = create_scheduler(max_num_seqs=1)
    scheduler.connector = Mock()
    scheduler.connector.get_num_new_matched_tokens.return_value = (0, False)

    for new_req in create_requests(num_requests=4):
        scheduler.add_request(new_req)

    fsm_head, fsm_next, remote_req, tail_req = scheduler.waiting

    # Two FSM requests at the head of the queue; each already has a grammar
    # object attached to its structured-output request.
    for fsm_req in (fsm_head, fsm_next):
        fsm_req.status = RequestStatus.WAITING_FOR_FSM
        fsm_req.structured_output_request = Mock(grammar=object())

    # One remote-KV request whose transfer is reported as finished.
    remote_req.status = RequestStatus.WAITING_FOR_REMOTE_KVS

    # Move the three special requests from `waiting` to `skipped_waiting`,
    # keeping their arrival order.
    parked = (fsm_head, fsm_next, remote_req)
    scheduler.waiting.remove_requests(list(parked))
    for parked_req in parked:
        scheduler.skipped_waiting.add_request(parked_req)
    scheduler.finished_recving_kv_req_ids.add(remote_req.request_id)
    # Stub out the remote-KV bookkeeping hook for this test.
    scheduler._update_waiting_for_remote_kv = Mock()

    step_output = scheduler.schedule()

    newly_scheduled = step_output.scheduled_new_reqs
    assert newly_scheduled
    assert newly_scheduled[0].req_id == fsm_head.request_id

    # Everything else is still queued, skipped requests first.
    still_queued = [*scheduler.skipped_waiting, *scheduler.waiting]
    assert [queued.request_id for queued in still_queued] == [
        fsm_next.request_id,
        remote_req.request_id,
        tail_req.request_id,
    ]
|
||||
|
||||
|
||||
def test_fcfs_mixed_skipped_waiting_types_keep_order():
    """Skipped requests of different kinds are scheduled in arrival order.

    An FSM request, a remote-KV request and a streaming request are added
    ahead of a 20-token regular request and a tail request. The first
    ``schedule()`` must leave the three special requests in
    ``skipped_waiting``. Once the regular request is aborted and the three
    are unblocked, the second ``schedule()`` must schedule them, followed by
    the tail request, in the order they were added.
    """
    scheduler = create_scheduler(max_num_batched_tokens=20)
    scheduler._update_waiting_for_remote_kv = Mock()

    def make_request(req_id, num_tokens=1):
        # Build exactly one request with the given id and prompt length.
        return create_requests(
            num_requests=1, num_tokens=num_tokens, req_ids=[req_id]
        )[0]

    req_fsm = make_request("fsm")
    req_remote = make_request("remote")
    req_stream = make_request("stream")
    req_regular = make_request("regular", 20)
    req_tail = make_request("tail")

    # Put the first three requests into three different blocked states; the
    # FSM request starts without a grammar.
    req_fsm.status = RequestStatus.WAITING_FOR_FSM
    req_fsm.structured_output_request = Mock(grammar=None)
    req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS
    req_stream.status = RequestStatus.WAITING_FOR_STREAMING_REQ

    arrival_order = (req_fsm, req_remote, req_stream, req_regular, req_tail)
    for queued_req in arrival_order:
        scheduler.add_request(queued_req)
    scheduler.schedule()
    assert [*scheduler.skipped_waiting] == [req_fsm, req_remote, req_stream]

    # Abort the regular request; nothing may be left running afterwards.
    scheduler.finish_requests(req_regular.request_id, RequestStatus.FINISHED_ABORTED)
    assert not scheduler.running

    # Unblock all three skipped requests before the next scheduling step.
    req_fsm.structured_output_request = Mock(grammar=object())
    scheduler.finished_recving_kv_req_ids.add(req_remote.request_id)
    req_stream.status = RequestStatus.WAITING

    second_output = scheduler.schedule()

    expected_order = [
        unblocked.request_id
        for unblocked in (req_fsm, req_remote, req_stream, req_tail)
    ]
    scheduled_ids = [new_req.req_id for new_req in second_output.scheduled_new_reqs]
    running_ids = [running_req.request_id for running_req in scheduler.running]
    assert scheduled_ids == expected_order
    assert running_ids == expected_order
    scheduler._update_waiting_for_remote_kv.assert_called_once_with(req_remote)
|
||||
|
||||
|
||||
def test_abort_request_waiting_for_remote_kvs():
|
||||
|
||||
Reference in New Issue
Block a user