[Perf] Optimize scheduler overhead for PD disaggregation, around 5% E2E perf improvement (#35781)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Wentao Ye
2026-03-11 00:25:30 -04:00
committed by GitHub
parent 42fadebecb
commit a8ff2cca92
6 changed files with 243 additions and 105 deletions

View File

@@ -45,7 +45,11 @@ from vllm.v1.core.sched.output import (
NewRequestData,
SchedulerOutput,
)
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.request_queue import (
RequestQueue,
SchedulingPolicy,
create_request_queue,
)
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -160,6 +164,8 @@ class Scheduler(SchedulerInterface):
) from e
# Priority queues for requests.
self.waiting = create_request_queue(self.policy)
# requests skipped in waiting flow due async deps or constraints.
self.skipped_waiting = create_request_queue(self.policy)
self.running: list[Request] = []
# The request IDs that are finished in between the previous and the
@@ -531,52 +537,29 @@ class Scheduler(SchedulerInterface):
# Next, schedule the WAITING requests.
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
step_skipped_waiting = create_request_queue(self.policy)
while self.waiting and token_budget > 0:
while (self.waiting or self.skipped_waiting) and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
request = self.waiting.peek_request()
request_queue = self._select_waiting_queue_for_scheduling()
assert request_queue is not None
request = request_queue.peek_request()
request_id = request.request_id
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
if request.num_preemptions:
# We must be loading for a resumed preemption
# rather than a new request.
request.status = RequestStatus.PREEMPTED
else:
request.status = RequestStatus.WAITING
else:
# try to promote blocked statuses while traversing skipped queue.
if self._is_blocked_waiting_status(
request.status
) and not self._try_promote_blocked_waiting_request(request):
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request_id,
)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Streaming: skip request if still waiting for next streaming req.
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
assert not request.streaming_queue
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
request_queue.pop_request()
step_skipped_waiting.prepend_request(request)
continue
# Check that adding the request still respects the max_loras
@@ -590,8 +573,8 @@ class Scheduler(SchedulerInterface):
)
):
# Scheduling would exceed max_loras, skip.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
request_queue.pop_request()
step_skipped_waiting.prepend_request(request)
continue
num_external_computed_tokens = 0
@@ -617,8 +600,8 @@ class Scheduler(SchedulerInterface):
# The request cannot be scheduled because
# the KVConnector couldn't determine
# the number of matched tokens.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
request_queue.pop_request()
step_skipped_waiting.prepend_request(request)
continue
request.num_external_computed_tokens = ext_tokens
@@ -761,14 +744,12 @@ class Scheduler(SchedulerInterface):
preempted=request.num_preemptions > 0,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
request = request_queue.pop_request()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
step_skipped_waiting.prepend_request(request)
# Set num_computed_tokens even though KVs are not yet loaded.
# request.num_computed_tokens will not be used anywhere until
# the request finished the KV transfer.
@@ -825,9 +806,9 @@ class Scheduler(SchedulerInterface):
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
# re-queue requests skipped in this pass ahead of older skipped items.
if step_skipped_waiting:
self.skipped_waiting.prepend_requests(step_skipped_waiting)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
@@ -1531,6 +1512,32 @@ class Scheduler(SchedulerInterface):
return engine_core_outputs
@staticmethod
def _is_blocked_waiting_status(status: RequestStatus) -> bool:
return status in (
RequestStatus.WAITING_FOR_FSM,
RequestStatus.WAITING_FOR_REMOTE_KVS,
RequestStatus.WAITING_FOR_STREAMING_REQ,
)
def _enqueue_waiting_request(self, request: Request) -> None:
if self._is_blocked_waiting_status(request.status):
self.skipped_waiting.add_request(request)
else:
self.waiting.add_request(request)
def _select_waiting_queue_for_scheduling(self) -> RequestQueue | None:
if self.policy == SchedulingPolicy.FCFS:
return self.skipped_waiting or self.waiting or None
# PRIORITY mode: compare queue heads when both queues are non-empty.
if self.waiting and self.skipped_waiting:
waiting_req = self.waiting.peek_request()
skipped_req = self.skipped_waiting.peek_request()
return self.waiting if waiting_req < skipped_req else self.skipped_waiting
return self.waiting or self.skipped_waiting or None
def _handle_stopped_request(self, request: Request) -> bool:
"""Return True if finished (can be False for resumable requests)."""
if not request.resumable:
@@ -1546,7 +1553,7 @@ class Scheduler(SchedulerInterface):
request.status = RequestStatus.WAITING_FOR_STREAMING_REQ
self.num_waiting_for_streaming_input += 1
self.waiting.add_request(request)
self._enqueue_waiting_request(request)
return False
def _get_routed_experts(self, request: Request) -> np.ndarray | None:
@@ -1677,7 +1684,7 @@ class Scheduler(SchedulerInterface):
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
return len(self.running), len(self.waiting)
return len(self.running), len(self.waiting) + len(self.skipped_waiting)
def add_request(self, request: Request) -> None:
existing = self.requests.get(request.request_id)
@@ -1696,7 +1703,7 @@ class Scheduler(SchedulerInterface):
else:
if request.resumable:
request.streaming_queue = deque()
self.waiting.add_request(request)
self._enqueue_waiting_request(request)
self.requests[request.request_id] = request
if self.log_stats:
request.record_event(EngineCoreEventType.QUEUED)
@@ -1747,6 +1754,7 @@ class Scheduler(SchedulerInterface):
self.running = remove_all(self.running, running_requests_to_remove)
if waiting_requests_to_remove:
self.waiting.remove_requests(waiting_requests_to_remove)
self.skipped_waiting.remove_requests(waiting_requests_to_remove)
# Second pass: set status and free requests
for request in valid_requests:
@@ -1798,7 +1806,11 @@ class Scheduler(SchedulerInterface):
return 0
if self._pause_state == PauseState.PAUSED_NEW:
return len(self.running)
num_waiting = len(self.waiting) - self.num_waiting_for_streaming_input
num_waiting = (
len(self.waiting)
+ len(self.skipped_waiting)
- self.num_waiting_for_streaming_input
)
return num_waiting + len(self.running)
def has_finished_requests(self) -> bool:
@@ -1898,7 +1910,7 @@ class Scheduler(SchedulerInterface):
)
return SchedulerStats(
num_running_reqs=len(self.running),
num_waiting_reqs=len(self.waiting),
num_waiting_reqs=len(self.waiting) + len(self.skipped_waiting),
kv_cache_usage=self.kv_cache_manager.usage,
encoder_cache_usage=self._get_encoder_cache_usage(),
prefix_cache_stats=prefix_cache_stats,
@@ -1981,21 +1993,15 @@ class Scheduler(SchedulerInterface):
return self.connector.request_finished_all_groups(request, block_ids)
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
def _update_waiting_for_remote_kv(self, request: Request) -> None:
"""
KV Connector: check if the request_id is finished_recving.
The finished_recving_kv_req_ids list is populated
on the previous steps()'s update_from_output based
on the worker side connector.
KV Connector: update request state after async recv is finished.
When the kv transfer is ready, we cache the blocks
and the request state will be moved back to WAITING from
WAITING_FOR_REMOTE_KV.
"""
assert self.connector is not None
if request.request_id not in self.finished_recving_kv_req_ids:
return False
if request.request_id in self.failed_recving_kv_req_ids:
# Request had KV load failures; num_computed_tokens was already
@@ -2023,9 +2029,40 @@ class Scheduler(SchedulerInterface):
if request.num_cached_tokens < 0:
request.num_cached_tokens = request.num_computed_tokens
# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
return True
def _try_promote_blocked_waiting_request(self, request: Request) -> bool:
"""
Try to promote a blocked waiting request back to schedulable states.
"""
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
# finished_recving_kv_req_ids is populated during
# update_from_output(), based on worker-side connector signals
# in KVConnectorOutput.finished_recving
if request.request_id not in self.finished_recving_kv_req_ids:
return False
self._update_waiting_for_remote_kv(request)
if request.num_preemptions:
request.status = RequestStatus.PREEMPTED
else:
request.status = RequestStatus.WAITING
return True
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if not (structured_output_req and structured_output_req.grammar):
return False
request.status = RequestStatus.WAITING
return True
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
assert not request.streaming_queue
return False
raise AssertionError(
"Unexpected blocked waiting status in promotion: "
f"{request.status.name} for request {request.request_id}"
)
def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
"""
@@ -2172,7 +2209,7 @@ class Scheduler(SchedulerInterface):
# handle async KV loads (not cached yet, evict_blocks=False)
async_load_reqs = (
req
for req in self.waiting
for req in self.skipped_waiting
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
)
async_failed_req_ids, num_failed_tokens, _ = (