[Perf] Optimize scheduler overhead for PD disaggregation, around 5% E2E perf improvement (#35781)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -1115,12 +1115,16 @@ def _step_until_done(
|
|||||||
all_finished = all_done
|
all_finished = all_done
|
||||||
|
|
||||||
|
|
||||||
|
def _num_waiting_requests(scheduler: Scheduler) -> int:
|
||||||
|
return len(scheduler.waiting) + len(scheduler.skipped_waiting)
|
||||||
|
|
||||||
|
|
||||||
def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
|
def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
|
||||||
"""Cycle requests through a KV transfer cycle."""
|
"""Cycle requests through a KV transfer cycle."""
|
||||||
|
|
||||||
# Requests should first transition to WAITING_FOR_REMOTE_KVS
|
# Requests should first transition to WAITING_FOR_REMOTE_KVS
|
||||||
output = scheduler.schedule()
|
output = scheduler.schedule()
|
||||||
assert len(scheduler.waiting) == len(req_ids)
|
assert _num_waiting_requests(scheduler) == len(req_ids)
|
||||||
assert len(scheduler.running) == 0
|
assert len(scheduler.running) == 0
|
||||||
assert len(output.scheduled_new_reqs) == 0
|
assert len(output.scheduled_new_reqs) == 0
|
||||||
for req in scheduler.requests.values():
|
for req in scheduler.requests.values():
|
||||||
@@ -1139,7 +1143,7 @@ def _step_until_kv_transfer_finished(scheduler: Scheduler, req_ids: list[str]):
|
|||||||
|
|
||||||
# Simulate KV transfer completion using KVConnectorOutput.finished_recving
|
# Simulate KV transfer completion using KVConnectorOutput.finished_recving
|
||||||
output = scheduler.schedule()
|
output = scheduler.schedule()
|
||||||
assert len(scheduler.waiting) == len(req_ids)
|
assert _num_waiting_requests(scheduler) == len(req_ids)
|
||||||
assert len(scheduler.running) == 0
|
assert len(scheduler.running) == 0
|
||||||
|
|
||||||
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
|
||||||
@@ -1546,7 +1550,7 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
|
|||||||
# All can be scheduled - 1st token.
|
# All can be scheduled - 1st token.
|
||||||
output = scheduler.schedule()
|
output = scheduler.schedule()
|
||||||
if is_async:
|
if is_async:
|
||||||
assert len(scheduler.waiting) == 2
|
assert _num_waiting_requests(scheduler) == 2
|
||||||
assert scheduler.running == []
|
assert scheduler.running == []
|
||||||
_step_until_kv_transfer_finished(scheduler, req_ids)
|
_step_until_kv_transfer_finished(scheduler, req_ids)
|
||||||
output = scheduler.schedule()
|
output = scheduler.schedule()
|
||||||
@@ -1604,7 +1608,11 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
|
|||||||
# This will have a local and remote cache hit.
|
# This will have a local and remote cache hit.
|
||||||
output = scheduler.schedule()
|
output = scheduler.schedule()
|
||||||
if is_async:
|
if is_async:
|
||||||
waiting_req_ids = [req.request_id for req in scheduler.waiting]
|
waiting_req_ids = [
|
||||||
|
req.request_id
|
||||||
|
for req in scheduler.skipped_waiting
|
||||||
|
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
|
]
|
||||||
assert len(waiting_req_ids) == 1
|
assert len(waiting_req_ids) == 1
|
||||||
_step_until_kv_transfer_finished(scheduler, waiting_req_ids)
|
_step_until_kv_transfer_finished(scheduler, waiting_req_ids)
|
||||||
output = scheduler.schedule()
|
output = scheduler.schedule()
|
||||||
@@ -2439,7 +2447,8 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
|
|||||||
output = scheduler.schedule()
|
output = scheduler.schedule()
|
||||||
assert len(output.scheduled_new_reqs) == 0
|
assert len(output.scheduled_new_reqs) == 0
|
||||||
assert len(scheduler.running) == 0
|
assert len(scheduler.running) == 0
|
||||||
assert len(scheduler.waiting) == 1
|
assert len(scheduler.waiting) == 0
|
||||||
|
assert len(scheduler.skipped_waiting) == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -3626,6 +3635,9 @@ def test_prepend_skipped_requests_order():
|
|||||||
# simulate first 2 waiting requests are waiting for remote KVs
|
# simulate first 2 waiting requests are waiting for remote KVs
|
||||||
for req in expected_waiting_reqs[:2]:
|
for req in expected_waiting_reqs[:2]:
|
||||||
req.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
req.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
|
scheduler.waiting.remove_requests(expected_waiting_reqs[:2])
|
||||||
|
for req in expected_waiting_reqs[:2]:
|
||||||
|
scheduler.skipped_waiting.add_request(req)
|
||||||
|
|
||||||
# schedule step
|
# schedule step
|
||||||
# expect the first 2 waiting to be skipped, the third running,
|
# expect the first 2 waiting to be skipped, the third running,
|
||||||
@@ -3636,7 +3648,87 @@ def test_prepend_skipped_requests_order():
|
|||||||
expected_waiting_reqs.pop(2)
|
expected_waiting_reqs.pop(2)
|
||||||
|
|
||||||
# verify waiting order is preserved
|
# verify waiting order is preserved
|
||||||
assert list(scheduler.waiting) == expected_waiting_reqs
|
waiting_reqs = list(scheduler.skipped_waiting) + list(scheduler.waiting)
|
||||||
|
assert waiting_reqs == expected_waiting_reqs
|
||||||
|
|
||||||
|
|
||||||
|
def test_remote_kv_promotion_keeps_fcfs_with_fsm_prefix():
|
||||||
|
scheduler = create_scheduler(max_num_seqs=1)
|
||||||
|
scheduler.connector = Mock()
|
||||||
|
scheduler.connector.get_num_new_matched_tokens.return_value = (0, False)
|
||||||
|
|
||||||
|
requests = create_requests(num_requests=4)
|
||||||
|
for request in requests:
|
||||||
|
scheduler.add_request(request)
|
||||||
|
|
||||||
|
req_fsm_1, req_fsm_2, req_remote, req_tail = list(scheduler.waiting)
|
||||||
|
|
||||||
|
# simulate two FSM requests at the waiting head that become ready now.
|
||||||
|
req_fsm_1.status = RequestStatus.WAITING_FOR_FSM
|
||||||
|
req_fsm_1.structured_output_request = Mock(grammar=object())
|
||||||
|
req_fsm_2.status = RequestStatus.WAITING_FOR_FSM
|
||||||
|
req_fsm_2.structured_output_request = Mock(grammar=object())
|
||||||
|
|
||||||
|
# simulate a remote-KV request that is ready to be promoted now.
|
||||||
|
req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
|
scheduler.waiting.remove_requests([req_fsm_1, req_fsm_2, req_remote])
|
||||||
|
scheduler.skipped_waiting.add_request(req_fsm_1)
|
||||||
|
scheduler.skipped_waiting.add_request(req_fsm_2)
|
||||||
|
scheduler.skipped_waiting.add_request(req_remote)
|
||||||
|
scheduler.finished_recving_kv_req_ids.add(req_remote.request_id)
|
||||||
|
scheduler._update_waiting_for_remote_kv = Mock()
|
||||||
|
|
||||||
|
output = scheduler.schedule()
|
||||||
|
|
||||||
|
assert output.scheduled_new_reqs
|
||||||
|
assert output.scheduled_new_reqs[0].req_id == req_fsm_1.request_id
|
||||||
|
waiting_req_ids = [
|
||||||
|
req.request_id
|
||||||
|
for req in list(scheduler.skipped_waiting) + list(scheduler.waiting)
|
||||||
|
]
|
||||||
|
assert waiting_req_ids == [
|
||||||
|
req_fsm_2.request_id,
|
||||||
|
req_remote.request_id,
|
||||||
|
req_tail.request_id,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_fcfs_mixed_skipped_waiting_types_keep_order():
|
||||||
|
scheduler = create_scheduler(max_num_batched_tokens=20)
|
||||||
|
scheduler._update_waiting_for_remote_kv = Mock()
|
||||||
|
|
||||||
|
mk_req = lambda req_id, num_tokens=1: create_requests( # noqa: E731
|
||||||
|
num_requests=1, num_tokens=num_tokens, req_ids=[req_id]
|
||||||
|
)[0]
|
||||||
|
req_fsm, req_remote, req_stream = mk_req("fsm"), mk_req("remote"), mk_req("stream")
|
||||||
|
req_regular, req_tail = mk_req("regular", 20), mk_req("tail")
|
||||||
|
req_fsm.status = RequestStatus.WAITING_FOR_FSM
|
||||||
|
req_fsm.structured_output_request = Mock(grammar=None)
|
||||||
|
req_remote.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
|
req_stream.status = RequestStatus.WAITING_FOR_STREAMING_REQ
|
||||||
|
|
||||||
|
for req in (req_fsm, req_remote, req_stream, req_regular, req_tail):
|
||||||
|
scheduler.add_request(req)
|
||||||
|
scheduler.schedule()
|
||||||
|
assert list(scheduler.skipped_waiting) == [req_fsm, req_remote, req_stream]
|
||||||
|
|
||||||
|
scheduler.finish_requests(req_regular.request_id, RequestStatus.FINISHED_ABORTED)
|
||||||
|
assert not scheduler.running
|
||||||
|
|
||||||
|
req_fsm.structured_output_request = Mock(grammar=object())
|
||||||
|
scheduler.finished_recving_kv_req_ids.add(req_remote.request_id)
|
||||||
|
req_stream.status = RequestStatus.WAITING
|
||||||
|
|
||||||
|
second_output = scheduler.schedule()
|
||||||
|
expected_order = [
|
||||||
|
req_fsm.request_id,
|
||||||
|
req_remote.request_id,
|
||||||
|
req_stream.request_id,
|
||||||
|
req_tail.request_id,
|
||||||
|
]
|
||||||
|
assert [req.req_id for req in second_output.scheduled_new_reqs] == expected_order
|
||||||
|
assert [req.request_id for req in scheduler.running] == expected_order
|
||||||
|
scheduler._update_waiting_for_remote_kv.assert_called_once_with(req_remote)
|
||||||
|
|
||||||
|
|
||||||
def test_abort_request_waiting_for_remote_kvs():
|
def test_abort_request_waiting_for_remote_kvs():
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler):
|
|||||||
|
|
||||||
scheduler_output = fail_scheduler.schedule()
|
scheduler_output = fail_scheduler.schedule()
|
||||||
|
|
||||||
assert len(fail_scheduler.waiting) == 1
|
assert len(fail_scheduler.skipped_waiting) == 1
|
||||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
assert request.num_computed_tokens == num_external_computed_tokens
|
assert request.num_computed_tokens == num_external_computed_tokens
|
||||||
|
|
||||||
@@ -145,3 +145,4 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler):
|
|||||||
assert output.finish_reason == FinishReason.ERROR
|
assert output.finish_reason == FinishReason.ERROR
|
||||||
|
|
||||||
assert len(fail_scheduler.waiting) == 0
|
assert len(fail_scheduler.waiting) == 0
|
||||||
|
assert len(fail_scheduler.skipped_waiting) == 0
|
||||||
|
|||||||
@@ -337,7 +337,7 @@ def test_async_recompute_blocks_not_cached_when_invalid(
|
|||||||
scheduler_output = recompute_scheduler.schedule()
|
scheduler_output = recompute_scheduler.schedule()
|
||||||
|
|
||||||
# request should be waiting for remote KVs
|
# request should be waiting for remote KVs
|
||||||
assert len(recompute_scheduler.waiting) == 1
|
assert len(recompute_scheduler.skipped_waiting) == 1
|
||||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
assert request.num_computed_tokens == num_external_computed_tokens
|
assert request.num_computed_tokens == num_external_computed_tokens
|
||||||
|
|
||||||
|
|||||||
@@ -76,8 +76,9 @@ def test_async_load_failure(
|
|||||||
|
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
|
|
||||||
assert len(scheduler.waiting) == 3
|
assert len(scheduler.waiting) == 0
|
||||||
for request in scheduler.waiting:
|
assert len(scheduler.skipped_waiting) == 3
|
||||||
|
for request in scheduler.skipped_waiting:
|
||||||
assert request.num_computed_tokens == num_external_computed_tokens
|
assert request.num_computed_tokens == num_external_computed_tokens
|
||||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
|
assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
|
||||||
@@ -96,8 +97,9 @@ def test_async_load_failure(
|
|||||||
|
|
||||||
min_invalid_block_idx = min(invalid_block_idxs)
|
min_invalid_block_idx = min(invalid_block_idxs)
|
||||||
|
|
||||||
assert len(scheduler.waiting) == 3
|
assert len(scheduler.waiting) == 0
|
||||||
for request in scheduler.waiting:
|
assert len(scheduler.skipped_waiting) == 3
|
||||||
|
for request in scheduler.skipped_waiting:
|
||||||
if request.request_id == request2.request_id:
|
if request.request_id == request2.request_id:
|
||||||
assert request.num_computed_tokens == (
|
assert request.num_computed_tokens == (
|
||||||
min_invalid_block_idx * scheduler.block_size
|
min_invalid_block_idx * scheduler.block_size
|
||||||
@@ -303,8 +305,9 @@ def test_async_progressive_load_failure(
|
|||||||
|
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
|
|
||||||
assert len(scheduler.waiting) == 1
|
assert len(scheduler.waiting) == 0
|
||||||
assert scheduler.waiting.peek_request().request_id == request.request_id
|
assert len(scheduler.skipped_waiting) == 1
|
||||||
|
assert scheduler.skipped_waiting.peek_request().request_id == request.request_id
|
||||||
assert request.num_computed_tokens == num_external_computed_tokens
|
assert request.num_computed_tokens == num_external_computed_tokens
|
||||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
|
assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
|
||||||
@@ -325,8 +328,9 @@ def test_async_progressive_load_failure(
|
|||||||
|
|
||||||
min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)
|
min_invalid_block_idx = min(min_invalid_block_idx, invalid_block_idx)
|
||||||
|
|
||||||
assert len(scheduler.waiting) == 1
|
assert len(scheduler.waiting) == 0
|
||||||
assert scheduler.waiting.peek_request().request_id == request.request_id
|
assert len(scheduler.skipped_waiting) == 1
|
||||||
|
assert scheduler.skipped_waiting.peek_request().request_id == request.request_id
|
||||||
assert request.num_computed_tokens == (
|
assert request.num_computed_tokens == (
|
||||||
min_invalid_block_idx * scheduler.block_size
|
min_invalid_block_idx * scheduler.block_size
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,6 +18,10 @@ from .utils import (
|
|||||||
pytestmark = pytest.mark.cpu_test
|
pytestmark = pytest.mark.cpu_test
|
||||||
|
|
||||||
|
|
||||||
|
def _num_waiting_requests(scheduler) -> int:
|
||||||
|
return len(scheduler.waiting) + len(scheduler.skipped_waiting)
|
||||||
|
|
||||||
|
|
||||||
def test_basic_lifecycle():
|
def test_basic_lifecycle():
|
||||||
"""Test lifecycle of a remote prefill."""
|
"""Test lifecycle of a remote prefill."""
|
||||||
|
|
||||||
@@ -54,8 +58,8 @@ def test_basic_lifecycle():
|
|||||||
assert scheduler_output.total_num_scheduled_tokens == 0
|
assert scheduler_output.total_num_scheduled_tokens == 0
|
||||||
|
|
||||||
# Req waiting for KVs with no computed/scheduled toks ...
|
# Req waiting for KVs with no computed/scheduled toks ...
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
assert request in scheduler.waiting
|
assert request in scheduler.skipped_waiting
|
||||||
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
assert request.num_computed_tokens == NUM_TOKENS
|
assert request.num_computed_tokens == NUM_TOKENS
|
||||||
|
|
||||||
@@ -81,7 +85,7 @@ def test_basic_lifecycle():
|
|||||||
# STEP (2):
|
# STEP (2):
|
||||||
# (2a): schedule(): nothing happens!
|
# (2a): schedule(): nothing happens!
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
assert len(scheduler.running) == 0
|
assert len(scheduler.running) == 0
|
||||||
|
|
||||||
# (2b): forward(): request finishes recv.
|
# (2b): forward(): request finishes recv.
|
||||||
@@ -94,7 +98,7 @@ def test_basic_lifecycle():
|
|||||||
engine_core_outputs = scheduler.update_from_output(
|
engine_core_outputs = scheduler.update_from_output(
|
||||||
scheduler_output, model_runner_output
|
scheduler_output, model_runner_output
|
||||||
)
|
)
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
assert request_id in scheduler.finished_recving_kv_req_ids
|
assert request_id in scheduler.finished_recving_kv_req_ids
|
||||||
|
|
||||||
# STEP (3):
|
# STEP (3):
|
||||||
@@ -180,7 +184,7 @@ def test_interleaved_lifecycle():
|
|||||||
scheduler.add_request(request_remote)
|
scheduler.add_request(request_remote)
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
assert len(scheduler.running) == 2
|
assert len(scheduler.running) == 2
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
assert len(scheduler_output.scheduled_new_reqs) == 1
|
assert len(scheduler_output.scheduled_new_reqs) == 1
|
||||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
|
assert scheduler_output.scheduled_cached_reqs.num_reqs == 1
|
||||||
|
|
||||||
@@ -190,7 +194,7 @@ def test_interleaved_lifecycle():
|
|||||||
# STEP 3: continue running, KVs not arrived yet.
|
# STEP 3: continue running, KVs not arrived yet.
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
assert len(scheduler.running) == 2
|
assert len(scheduler.running) == 2
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
||||||
|
|
||||||
@@ -199,14 +203,14 @@ def test_interleaved_lifecycle():
|
|||||||
)
|
)
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 2
|
assert len(scheduler.running) == 2
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
||||||
|
|
||||||
# STEP 4: KVs arrive.
|
# STEP 4: KVs arrive.
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
assert len(scheduler.running) == 2
|
assert len(scheduler.running) == 2
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
||||||
|
|
||||||
@@ -218,7 +222,7 @@ def test_interleaved_lifecycle():
|
|||||||
# STEP 5: RECVed KVs are sent to ModelRunner.
|
# STEP 5: RECVed KVs are sent to ModelRunner.
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
assert len(scheduler.running) == 3
|
assert len(scheduler.running) == 3
|
||||||
assert len(scheduler.waiting) == 0
|
assert _num_waiting_requests(scheduler) == 0
|
||||||
assert len(scheduler_output.scheduled_new_reqs) == 1
|
assert len(scheduler_output.scheduled_new_reqs) == 1
|
||||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
assert scheduler_output.scheduled_cached_reqs.num_reqs == 2
|
||||||
|
|
||||||
@@ -279,14 +283,14 @@ def test_no_spurious_prefix_caching():
|
|||||||
scheduler.add_request(request_remote)
|
scheduler.add_request(request_remote)
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
|
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
|
|
||||||
# Schedule the local prefill request. This should
|
# Schedule the local prefill request. This should
|
||||||
# cause blocks to be cached, but separately from
|
# cause blocks to be cached, but separately from
|
||||||
scheduler.add_request(request_local)
|
scheduler.add_request(request_local)
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
assert len(scheduler.running) == 1
|
assert len(scheduler.running) == 1
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
|
|
||||||
local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
local_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||||
0
|
0
|
||||||
@@ -348,7 +352,7 @@ def test_full_block_prompt():
|
|||||||
finished_recving={request_id}
|
finished_recving={request_id}
|
||||||
)
|
)
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
assert request_id in scheduler.finished_recving_kv_req_ids
|
assert request_id in scheduler.finished_recving_kv_req_ids
|
||||||
|
|
||||||
# # STEP (3): Run as usual.
|
# # STEP (3): Run as usual.
|
||||||
@@ -418,7 +422,7 @@ def test_cannot_schedule_after_recv():
|
|||||||
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 1
|
assert len(scheduler.running) == 1
|
||||||
assert len(scheduler.waiting) == 0
|
assert _num_waiting_requests(scheduler) == 0
|
||||||
|
|
||||||
# Step 2: 5 blocks are in use (2 new for remote blocks).
|
# Step 2: 5 blocks are in use (2 new for remote blocks).
|
||||||
scheduler.add_request(request_remote)
|
scheduler.add_request(request_remote)
|
||||||
@@ -426,7 +430,7 @@ def test_cannot_schedule_after_recv():
|
|||||||
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 1
|
assert len(scheduler.running) == 1
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
|
|
||||||
# Step 3: finish recving (5 blocks in use)
|
# Step 3: finish recving (5 blocks in use)
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
@@ -435,7 +439,7 @@ def test_cannot_schedule_after_recv():
|
|||||||
)
|
)
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 1
|
assert len(scheduler.running) == 1
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
|
|
||||||
# Step 4: try to schedule, remote request is put to running list
|
# Step 4: try to schedule, remote request is put to running list
|
||||||
# because the transfer is completed.
|
# because the transfer is completed.
|
||||||
@@ -445,7 +449,7 @@ def test_cannot_schedule_after_recv():
|
|||||||
)
|
)
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 2
|
assert len(scheduler.running) == 2
|
||||||
assert len(scheduler.waiting) == 0
|
assert _num_waiting_requests(scheduler) == 0
|
||||||
|
|
||||||
# Step 5: Remote request will be put back to waiting list
|
# Step 5: Remote request will be put back to waiting list
|
||||||
# because it needs new block to hold generated token.
|
# because it needs new block to hold generated token.
|
||||||
@@ -453,7 +457,7 @@ def test_cannot_schedule_after_recv():
|
|||||||
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 1
|
assert len(scheduler.running) == 1
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
|
|
||||||
# Step 6: finish the request, free it.
|
# Step 6: finish the request, free it.
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
@@ -462,7 +466,7 @@ def test_cannot_schedule_after_recv():
|
|||||||
)
|
)
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 0
|
assert len(scheduler.running) == 0
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
|
|
||||||
# Step 7: now we can schedule (with 2 blocks computed),
|
# Step 7: now we can schedule (with 2 blocks computed),
|
||||||
# request is retrieved from preempted list.
|
# request is retrieved from preempted list.
|
||||||
@@ -474,7 +478,7 @@ def test_cannot_schedule_after_recv():
|
|||||||
)
|
)
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 1
|
assert len(scheduler.running) == 1
|
||||||
assert len(scheduler.waiting) == 0
|
assert _num_waiting_requests(scheduler) == 0
|
||||||
|
|
||||||
# Step 8: free everything.
|
# Step 8: free everything.
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
@@ -521,7 +525,7 @@ def test_cannot_recv():
|
|||||||
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 1
|
assert len(scheduler.running) == 1
|
||||||
assert len(scheduler.waiting) == 0
|
assert _num_waiting_requests(scheduler) == 0
|
||||||
|
|
||||||
# Step 2: 3 blocks are in use,
|
# Step 2: 3 blocks are in use,
|
||||||
# need 3 new for remote blocks but only 2 are available.
|
# need 3 new for remote blocks but only 2 are available.
|
||||||
@@ -530,7 +534,7 @@ def test_cannot_recv():
|
|||||||
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
model_runner_output = create_model_runner_output(reqs=[request_normal])
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 1
|
assert len(scheduler.running) == 1
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
# Should not have KV transfer in progress.
|
# Should not have KV transfer in progress.
|
||||||
assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS
|
assert request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
|
|
||||||
@@ -541,14 +545,14 @@ def test_cannot_recv():
|
|||||||
)
|
)
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 0
|
assert len(scheduler.running) == 0
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
|
|
||||||
# Step 4: now we can initiate KV transfer (with 2 blocks computed).
|
# Step 4: now we can initiate KV transfer (with 2 blocks computed).
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
model_runner_output = create_model_runner_output(reqs=[])
|
model_runner_output = create_model_runner_output(reqs=[])
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 0
|
assert len(scheduler.running) == 0
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
assert request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
|
|
||||||
# Step 5: finish recving (5 blocks in use)
|
# Step 5: finish recving (5 blocks in use)
|
||||||
@@ -558,14 +562,14 @@ def test_cannot_recv():
|
|||||||
)
|
)
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 0
|
assert len(scheduler.running) == 0
|
||||||
assert len(scheduler.waiting) == 1
|
assert _num_waiting_requests(scheduler) == 1
|
||||||
|
|
||||||
# Step 6: schedule remote request
|
# Step 6: schedule remote request
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
model_runner_output = create_model_runner_output(reqs=[request_remote])
|
model_runner_output = create_model_runner_output(reqs=[request_remote])
|
||||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
assert len(scheduler.running) == 1
|
assert len(scheduler.running) == 1
|
||||||
assert len(scheduler.waiting) == 0
|
assert _num_waiting_requests(scheduler) == 0
|
||||||
|
|
||||||
# Step 7: free everything.
|
# Step 7: free everything.
|
||||||
scheduler_output = scheduler.schedule()
|
scheduler_output = scheduler.schedule()
|
||||||
|
|||||||
@@ -45,7 +45,11 @@ from vllm.v1.core.sched.output import (
|
|||||||
NewRequestData,
|
NewRequestData,
|
||||||
SchedulerOutput,
|
SchedulerOutput,
|
||||||
)
|
)
|
||||||
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
from vllm.v1.core.sched.request_queue import (
|
||||||
|
RequestQueue,
|
||||||
|
SchedulingPolicy,
|
||||||
|
create_request_queue,
|
||||||
|
)
|
||||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
@@ -160,6 +164,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
) from e
|
) from e
|
||||||
# Priority queues for requests.
|
# Priority queues for requests.
|
||||||
self.waiting = create_request_queue(self.policy)
|
self.waiting = create_request_queue(self.policy)
|
||||||
|
# requests skipped in waiting flow due async deps or constraints.
|
||||||
|
self.skipped_waiting = create_request_queue(self.policy)
|
||||||
self.running: list[Request] = []
|
self.running: list[Request] = []
|
||||||
|
|
||||||
# The request IDs that are finished in between the previous and the
|
# The request IDs that are finished in between the previous and the
|
||||||
@@ -531,52 +537,29 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
# Next, schedule the WAITING requests.
|
# Next, schedule the WAITING requests.
|
||||||
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
|
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
|
||||||
# Use a temporary RequestQueue to collect requests that need to be
|
step_skipped_waiting = create_request_queue(self.policy)
|
||||||
# skipped and put back at the head of the waiting queue later
|
|
||||||
skipped_waiting_requests = create_request_queue(self.policy)
|
|
||||||
|
|
||||||
while self.waiting and token_budget > 0:
|
while (self.waiting or self.skipped_waiting) and token_budget > 0:
|
||||||
if len(self.running) == self.max_num_running_reqs:
|
if len(self.running) == self.max_num_running_reqs:
|
||||||
break
|
break
|
||||||
|
|
||||||
request = self.waiting.peek_request()
|
request_queue = self._select_waiting_queue_for_scheduling()
|
||||||
|
assert request_queue is not None
|
||||||
|
|
||||||
|
request = request_queue.peek_request()
|
||||||
request_id = request.request_id
|
request_id = request.request_id
|
||||||
|
|
||||||
# KVTransfer: skip request if still waiting for remote kvs.
|
# try to promote blocked statuses while traversing skipped queue.
|
||||||
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
if self._is_blocked_waiting_status(
|
||||||
is_ready = self._update_waiting_for_remote_kv(request)
|
request.status
|
||||||
if is_ready:
|
) and not self._try_promote_blocked_waiting_request(request):
|
||||||
if request.num_preemptions:
|
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||||
# We must be loading for a resumed preemption
|
|
||||||
# rather than a new request.
|
|
||||||
request.status = RequestStatus.PREEMPTED
|
|
||||||
else:
|
|
||||||
request.status = RequestStatus.WAITING
|
|
||||||
else:
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"%s is still in WAITING_FOR_REMOTE_KVS state.",
|
"%s is still in WAITING_FOR_REMOTE_KVS state.",
|
||||||
request_id,
|
request_id,
|
||||||
)
|
)
|
||||||
self.waiting.pop_request()
|
request_queue.pop_request()
|
||||||
skipped_waiting_requests.prepend_request(request)
|
step_skipped_waiting.prepend_request(request)
|
||||||
continue
|
|
||||||
|
|
||||||
# Skip request if the structured output request is still waiting
|
|
||||||
# for FSM compilation.
|
|
||||||
if request.status == RequestStatus.WAITING_FOR_FSM:
|
|
||||||
structured_output_req = request.structured_output_request
|
|
||||||
if structured_output_req and structured_output_req.grammar:
|
|
||||||
request.status = RequestStatus.WAITING
|
|
||||||
else:
|
|
||||||
self.waiting.pop_request()
|
|
||||||
skipped_waiting_requests.prepend_request(request)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Streaming: skip request if still waiting for next streaming req.
|
|
||||||
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
|
|
||||||
assert not request.streaming_queue
|
|
||||||
self.waiting.pop_request()
|
|
||||||
skipped_waiting_requests.prepend_request(request)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check that adding the request still respects the max_loras
|
# Check that adding the request still respects the max_loras
|
||||||
@@ -590,8 +573,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
# Scheduling would exceed max_loras, skip.
|
# Scheduling would exceed max_loras, skip.
|
||||||
self.waiting.pop_request()
|
request_queue.pop_request()
|
||||||
skipped_waiting_requests.prepend_request(request)
|
step_skipped_waiting.prepend_request(request)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
num_external_computed_tokens = 0
|
num_external_computed_tokens = 0
|
||||||
@@ -617,8 +600,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
# The request cannot be scheduled because
|
# The request cannot be scheduled because
|
||||||
# the KVConnector couldn't determine
|
# the KVConnector couldn't determine
|
||||||
# the number of matched tokens.
|
# the number of matched tokens.
|
||||||
self.waiting.pop_request()
|
request_queue.pop_request()
|
||||||
skipped_waiting_requests.prepend_request(request)
|
step_skipped_waiting.prepend_request(request)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
request.num_external_computed_tokens = ext_tokens
|
request.num_external_computed_tokens = ext_tokens
|
||||||
@@ -761,14 +744,12 @@ class Scheduler(SchedulerInterface):
|
|||||||
preempted=request.num_preemptions > 0,
|
preempted=request.num_preemptions > 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Request was already popped from self.waiting
|
request = request_queue.pop_request()
|
||||||
# unless it was re-added above due to new_blocks being None.
|
|
||||||
request = self.waiting.pop_request()
|
|
||||||
if load_kv_async:
|
if load_kv_async:
|
||||||
# If loading async, allocate memory and put request
|
# If loading async, allocate memory and put request
|
||||||
# into the WAITING_FOR_REMOTE_KV state.
|
# into the WAITING_FOR_REMOTE_KV state.
|
||||||
skipped_waiting_requests.prepend_request(request)
|
|
||||||
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
|
step_skipped_waiting.prepend_request(request)
|
||||||
# Set num_computed_tokens even though KVs are not yet loaded.
|
# Set num_computed_tokens even though KVs are not yet loaded.
|
||||||
# request.num_computed_tokens will not be used anywhere until
|
# request.num_computed_tokens will not be used anywhere until
|
||||||
# the request finished the KV transfer.
|
# the request finished the KV transfer.
|
||||||
@@ -825,9 +806,9 @@ class Scheduler(SchedulerInterface):
|
|||||||
if self.ec_connector is not None:
|
if self.ec_connector is not None:
|
||||||
self.ec_connector.update_state_after_alloc(request, i)
|
self.ec_connector.update_state_after_alloc(request, i)
|
||||||
|
|
||||||
# Put back any skipped requests at the head of the waiting queue
|
# re-queue requests skipped in this pass ahead of older skipped items.
|
||||||
if skipped_waiting_requests:
|
if step_skipped_waiting:
|
||||||
self.waiting.prepend_requests(skipped_waiting_requests)
|
self.skipped_waiting.prepend_requests(step_skipped_waiting)
|
||||||
|
|
||||||
# Check if the scheduling constraints are satisfied.
|
# Check if the scheduling constraints are satisfied.
|
||||||
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
||||||
@@ -1531,6 +1512,32 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
return engine_core_outputs
|
return engine_core_outputs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_blocked_waiting_status(status: RequestStatus) -> bool:
|
||||||
|
return status in (
|
||||||
|
RequestStatus.WAITING_FOR_FSM,
|
||||||
|
RequestStatus.WAITING_FOR_REMOTE_KVS,
|
||||||
|
RequestStatus.WAITING_FOR_STREAMING_REQ,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _enqueue_waiting_request(self, request: Request) -> None:
|
||||||
|
if self._is_blocked_waiting_status(request.status):
|
||||||
|
self.skipped_waiting.add_request(request)
|
||||||
|
else:
|
||||||
|
self.waiting.add_request(request)
|
||||||
|
|
||||||
|
def _select_waiting_queue_for_scheduling(self) -> RequestQueue | None:
|
||||||
|
if self.policy == SchedulingPolicy.FCFS:
|
||||||
|
return self.skipped_waiting or self.waiting or None
|
||||||
|
|
||||||
|
# PRIORITY mode: compare queue heads when both queues are non-empty.
|
||||||
|
if self.waiting and self.skipped_waiting:
|
||||||
|
waiting_req = self.waiting.peek_request()
|
||||||
|
skipped_req = self.skipped_waiting.peek_request()
|
||||||
|
return self.waiting if waiting_req < skipped_req else self.skipped_waiting
|
||||||
|
|
||||||
|
return self.waiting or self.skipped_waiting or None
|
||||||
|
|
||||||
def _handle_stopped_request(self, request: Request) -> bool:
|
def _handle_stopped_request(self, request: Request) -> bool:
|
||||||
"""Return True if finished (can be False for resumable requests)."""
|
"""Return True if finished (can be False for resumable requests)."""
|
||||||
if not request.resumable:
|
if not request.resumable:
|
||||||
@@ -1546,7 +1553,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
request.status = RequestStatus.WAITING_FOR_STREAMING_REQ
|
request.status = RequestStatus.WAITING_FOR_STREAMING_REQ
|
||||||
self.num_waiting_for_streaming_input += 1
|
self.num_waiting_for_streaming_input += 1
|
||||||
|
|
||||||
self.waiting.add_request(request)
|
self._enqueue_waiting_request(request)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _get_routed_experts(self, request: Request) -> np.ndarray | None:
|
def _get_routed_experts(self, request: Request) -> np.ndarray | None:
|
||||||
@@ -1677,7 +1684,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
def get_request_counts(self) -> tuple[int, int]:
|
def get_request_counts(self) -> tuple[int, int]:
|
||||||
"""Returns (num_running_reqs, num_waiting_reqs)."""
|
"""Returns (num_running_reqs, num_waiting_reqs)."""
|
||||||
return len(self.running), len(self.waiting)
|
return len(self.running), len(self.waiting) + len(self.skipped_waiting)
|
||||||
|
|
||||||
def add_request(self, request: Request) -> None:
|
def add_request(self, request: Request) -> None:
|
||||||
existing = self.requests.get(request.request_id)
|
existing = self.requests.get(request.request_id)
|
||||||
@@ -1696,7 +1703,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
else:
|
else:
|
||||||
if request.resumable:
|
if request.resumable:
|
||||||
request.streaming_queue = deque()
|
request.streaming_queue = deque()
|
||||||
self.waiting.add_request(request)
|
self._enqueue_waiting_request(request)
|
||||||
self.requests[request.request_id] = request
|
self.requests[request.request_id] = request
|
||||||
if self.log_stats:
|
if self.log_stats:
|
||||||
request.record_event(EngineCoreEventType.QUEUED)
|
request.record_event(EngineCoreEventType.QUEUED)
|
||||||
@@ -1747,6 +1754,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
self.running = remove_all(self.running, running_requests_to_remove)
|
self.running = remove_all(self.running, running_requests_to_remove)
|
||||||
if waiting_requests_to_remove:
|
if waiting_requests_to_remove:
|
||||||
self.waiting.remove_requests(waiting_requests_to_remove)
|
self.waiting.remove_requests(waiting_requests_to_remove)
|
||||||
|
self.skipped_waiting.remove_requests(waiting_requests_to_remove)
|
||||||
|
|
||||||
# Second pass: set status and free requests
|
# Second pass: set status and free requests
|
||||||
for request in valid_requests:
|
for request in valid_requests:
|
||||||
@@ -1798,7 +1806,11 @@ class Scheduler(SchedulerInterface):
|
|||||||
return 0
|
return 0
|
||||||
if self._pause_state == PauseState.PAUSED_NEW:
|
if self._pause_state == PauseState.PAUSED_NEW:
|
||||||
return len(self.running)
|
return len(self.running)
|
||||||
num_waiting = len(self.waiting) - self.num_waiting_for_streaming_input
|
num_waiting = (
|
||||||
|
len(self.waiting)
|
||||||
|
+ len(self.skipped_waiting)
|
||||||
|
- self.num_waiting_for_streaming_input
|
||||||
|
)
|
||||||
return num_waiting + len(self.running)
|
return num_waiting + len(self.running)
|
||||||
|
|
||||||
def has_finished_requests(self) -> bool:
|
def has_finished_requests(self) -> bool:
|
||||||
@@ -1898,7 +1910,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
)
|
)
|
||||||
return SchedulerStats(
|
return SchedulerStats(
|
||||||
num_running_reqs=len(self.running),
|
num_running_reqs=len(self.running),
|
||||||
num_waiting_reqs=len(self.waiting),
|
num_waiting_reqs=len(self.waiting) + len(self.skipped_waiting),
|
||||||
kv_cache_usage=self.kv_cache_manager.usage,
|
kv_cache_usage=self.kv_cache_manager.usage,
|
||||||
encoder_cache_usage=self._get_encoder_cache_usage(),
|
encoder_cache_usage=self._get_encoder_cache_usage(),
|
||||||
prefix_cache_stats=prefix_cache_stats,
|
prefix_cache_stats=prefix_cache_stats,
|
||||||
@@ -1981,21 +1993,15 @@ class Scheduler(SchedulerInterface):
|
|||||||
|
|
||||||
return self.connector.request_finished_all_groups(request, block_ids)
|
return self.connector.request_finished_all_groups(request, block_ids)
|
||||||
|
|
||||||
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
|
def _update_waiting_for_remote_kv(self, request: Request) -> None:
|
||||||
"""
|
"""
|
||||||
KV Connector: check if the request_id is finished_recving.
|
KV Connector: update request state after async recv is finished.
|
||||||
|
|
||||||
The finished_recving_kv_req_ids list is populated
|
|
||||||
on the previous steps()'s update_from_output based
|
|
||||||
on the worker side connector.
|
|
||||||
|
|
||||||
When the kv transfer is ready, we cache the blocks
|
When the kv transfer is ready, we cache the blocks
|
||||||
and the request state will be moved back to WAITING from
|
and the request state will be moved back to WAITING from
|
||||||
WAITING_FOR_REMOTE_KV.
|
WAITING_FOR_REMOTE_KV.
|
||||||
"""
|
"""
|
||||||
assert self.connector is not None
|
assert self.connector is not None
|
||||||
if request.request_id not in self.finished_recving_kv_req_ids:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if request.request_id in self.failed_recving_kv_req_ids:
|
if request.request_id in self.failed_recving_kv_req_ids:
|
||||||
# Request had KV load failures; num_computed_tokens was already
|
# Request had KV load failures; num_computed_tokens was already
|
||||||
@@ -2023,9 +2029,40 @@ class Scheduler(SchedulerInterface):
|
|||||||
if request.num_cached_tokens < 0:
|
if request.num_cached_tokens < 0:
|
||||||
request.num_cached_tokens = request.num_computed_tokens
|
request.num_cached_tokens = request.num_computed_tokens
|
||||||
|
|
||||||
# Return that we are ready.
|
|
||||||
self.finished_recving_kv_req_ids.remove(request.request_id)
|
self.finished_recving_kv_req_ids.remove(request.request_id)
|
||||||
return True
|
|
||||||
|
def _try_promote_blocked_waiting_request(self, request: Request) -> bool:
|
||||||
|
"""
|
||||||
|
Try to promote a blocked waiting request back to schedulable states.
|
||||||
|
"""
|
||||||
|
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||||
|
# finished_recving_kv_req_ids is populated during
|
||||||
|
# update_from_output(), based on worker-side connector signals
|
||||||
|
# in KVConnectorOutput.finished_recving
|
||||||
|
if request.request_id not in self.finished_recving_kv_req_ids:
|
||||||
|
return False
|
||||||
|
self._update_waiting_for_remote_kv(request)
|
||||||
|
if request.num_preemptions:
|
||||||
|
request.status = RequestStatus.PREEMPTED
|
||||||
|
else:
|
||||||
|
request.status = RequestStatus.WAITING
|
||||||
|
return True
|
||||||
|
|
||||||
|
if request.status == RequestStatus.WAITING_FOR_FSM:
|
||||||
|
structured_output_req = request.structured_output_request
|
||||||
|
if not (structured_output_req and structured_output_req.grammar):
|
||||||
|
return False
|
||||||
|
request.status = RequestStatus.WAITING
|
||||||
|
return True
|
||||||
|
|
||||||
|
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
|
||||||
|
assert not request.streaming_queue
|
||||||
|
return False
|
||||||
|
|
||||||
|
raise AssertionError(
|
||||||
|
"Unexpected blocked waiting status in promotion: "
|
||||||
|
f"{request.status.name} for request {request.request_id}"
|
||||||
|
)
|
||||||
|
|
||||||
def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
|
def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
|
||||||
"""
|
"""
|
||||||
@@ -2172,7 +2209,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
# handle async KV loads (not cached yet, evict_blocks=False)
|
# handle async KV loads (not cached yet, evict_blocks=False)
|
||||||
async_load_reqs = (
|
async_load_reqs = (
|
||||||
req
|
req
|
||||||
for req in self.waiting
|
for req in self.skipped_waiting
|
||||||
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||||
)
|
)
|
||||||
async_failed_req_ids, num_failed_tokens, _ = (
|
async_failed_req_ids, num_failed_tokens, _ = (
|
||||||
|
|||||||
Reference in New Issue
Block a user