[Perf] Optimize scheduler overhead for PD disaggregation, around 5% E2E perf improvement (#35781)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -45,7 +45,11 @@ from vllm.v1.core.sched.output import (
|
||||
NewRequestData,
|
||||
SchedulerOutput,
|
||||
)
|
||||
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
|
||||
from vllm.v1.core.sched.request_queue import (
|
||||
RequestQueue,
|
||||
SchedulingPolicy,
|
||||
create_request_queue,
|
||||
)
|
||||
from vllm.v1.core.sched.utils import check_stop, remove_all
|
||||
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
@@ -160,6 +164,8 @@ class Scheduler(SchedulerInterface):
|
||||
) from e
|
||||
# Priority queues for requests.
|
||||
self.waiting = create_request_queue(self.policy)
|
||||
# Requests skipped in the waiting flow due to async dependencies or constraints.
|
||||
self.skipped_waiting = create_request_queue(self.policy)
|
||||
self.running: list[Request] = []
|
||||
|
||||
# The request IDs that are finished in between the previous and the
|
||||
@@ -531,52 +537,29 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
# Next, schedule the WAITING requests.
|
||||
if not preempted_reqs and self._pause_state == PauseState.UNPAUSED:
|
||||
# Use a temporary RequestQueue to collect requests that need to be
|
||||
# skipped and put back at the head of the waiting queue later
|
||||
skipped_waiting_requests = create_request_queue(self.policy)
|
||||
step_skipped_waiting = create_request_queue(self.policy)
|
||||
|
||||
while self.waiting and token_budget > 0:
|
||||
while (self.waiting or self.skipped_waiting) and token_budget > 0:
|
||||
if len(self.running) == self.max_num_running_reqs:
|
||||
break
|
||||
|
||||
request = self.waiting.peek_request()
|
||||
request_queue = self._select_waiting_queue_for_scheduling()
|
||||
assert request_queue is not None
|
||||
|
||||
request = request_queue.peek_request()
|
||||
request_id = request.request_id
|
||||
|
||||
# KVTransfer: skip request if still waiting for remote kvs.
|
||||
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||
is_ready = self._update_waiting_for_remote_kv(request)
|
||||
if is_ready:
|
||||
if request.num_preemptions:
|
||||
# We must be loading for a resumed preemption
|
||||
# rather than a new request.
|
||||
request.status = RequestStatus.PREEMPTED
|
||||
else:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
# try to promote blocked statuses while traversing skipped queue.
|
||||
if self._is_blocked_waiting_status(
|
||||
request.status
|
||||
) and not self._try_promote_blocked_waiting_request(request):
|
||||
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||
logger.debug(
|
||||
"%s is still in WAITING_FOR_REMOTE_KVS state.",
|
||||
request_id,
|
||||
)
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
# Skip request if the structured output request is still waiting
|
||||
# for FSM compilation.
|
||||
if request.status == RequestStatus.WAITING_FOR_FSM:
|
||||
structured_output_req = request.structured_output_request
|
||||
if structured_output_req and structured_output_req.grammar:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
# Streaming: skip request if still waiting for next streaming req.
|
||||
if request.status == RequestStatus.WAITING_FOR_STREAMING_REQ:
|
||||
assert not request.streaming_queue
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
request_queue.pop_request()
|
||||
step_skipped_waiting.prepend_request(request)
|
||||
continue
|
||||
|
||||
# Check that adding the request still respects the max_loras
|
||||
@@ -590,8 +573,8 @@ class Scheduler(SchedulerInterface):
|
||||
)
|
||||
):
|
||||
# Scheduling would exceed max_loras, skip.
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
request_queue.pop_request()
|
||||
step_skipped_waiting.prepend_request(request)
|
||||
continue
|
||||
|
||||
num_external_computed_tokens = 0
|
||||
@@ -617,8 +600,8 @@ class Scheduler(SchedulerInterface):
|
||||
# The request cannot be scheduled because
|
||||
# the KVConnector couldn't determine
|
||||
# the number of matched tokens.
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
request_queue.pop_request()
|
||||
step_skipped_waiting.prepend_request(request)
|
||||
continue
|
||||
|
||||
request.num_external_computed_tokens = ext_tokens
|
||||
@@ -761,14 +744,12 @@ class Scheduler(SchedulerInterface):
|
||||
preempted=request.num_preemptions > 0,
|
||||
)
|
||||
|
||||
# Request was already popped from self.waiting
|
||||
# unless it was re-added above due to new_blocks being None.
|
||||
request = self.waiting.pop_request()
|
||||
request = request_queue.pop_request()
|
||||
if load_kv_async:
|
||||
# If loading async, allocate memory and put request
|
||||
# into the WAITING_FOR_REMOTE_KVS state.
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
step_skipped_waiting.prepend_request(request)
|
||||
# Set num_computed_tokens even though KVs are not yet loaded.
|
||||
# request.num_computed_tokens will not be used anywhere until
|
||||
# the request finished the KV transfer.
|
||||
@@ -825,9 +806,9 @@ class Scheduler(SchedulerInterface):
|
||||
if self.ec_connector is not None:
|
||||
self.ec_connector.update_state_after_alloc(request, i)
|
||||
|
||||
# Put back any skipped requests at the head of the waiting queue
|
||||
if skipped_waiting_requests:
|
||||
self.waiting.prepend_requests(skipped_waiting_requests)
|
||||
# re-queue requests skipped in this pass ahead of older skipped items.
|
||||
if step_skipped_waiting:
|
||||
self.skipped_waiting.prepend_requests(step_skipped_waiting)
|
||||
|
||||
# Check if the scheduling constraints are satisfied.
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
||||
@@ -1531,6 +1512,32 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
return engine_core_outputs
|
||||
|
||||
@staticmethod
def _is_blocked_waiting_status(status: RequestStatus) -> bool:
    """Tell whether ``status`` is one of the blocked waiting statuses.

    The blocked statuses are ``WAITING_FOR_FSM``,
    ``WAITING_FOR_REMOTE_KVS`` and ``WAITING_FOR_STREAMING_REQ``.

    Args:
        status: The request status to classify.

    Returns:
        True if ``status`` is one of the three statuses above,
        False otherwise.
    """
    blocked_statuses = (
        RequestStatus.WAITING_FOR_FSM,
        RequestStatus.WAITING_FOR_REMOTE_KVS,
        RequestStatus.WAITING_FOR_STREAMING_REQ,
    )
    return status in blocked_statuses
|
||||
|
||||
def _enqueue_waiting_request(self, request: Request) -> None:
    """Add ``request`` to the queue that matches its current status.

    A request whose status is a blocked waiting status goes to
    ``self.skipped_waiting``; any other request goes to ``self.waiting``.

    Args:
        request: The request to enqueue.
    """
    target_queue = (
        self.skipped_waiting
        if self._is_blocked_waiting_status(request.status)
        else self.waiting
    )
    target_queue.add_request(request)
|
||||
|
||||
def _select_waiting_queue_for_scheduling(self) -> RequestQueue | None:
    """Pick the queue whose head request should be considered next.

    Under the FCFS policy ``self.skipped_waiting`` is preferred whenever
    it is non-empty, then ``self.waiting``. Under any other policy
    (PRIORITY mode) the heads of both queues are compared when both are
    non-empty, and ``self.waiting`` wins only if its head compares
    strictly less than the head of ``self.skipped_waiting``. If only one
    queue is non-empty, that queue is returned.

    Returns:
        The selected queue, or None when both queues are empty.
    """
    if self.policy == SchedulingPolicy.FCFS:
        if self.skipped_waiting:
            return self.skipped_waiting
        if self.waiting:
            return self.waiting
        return None

    # PRIORITY mode: compare queue heads when both queues are non-empty.
    if self.waiting and self.skipped_waiting:
        head_of_waiting = self.waiting.peek_request()
        head_of_skipped = self.skipped_waiting.peek_request()
        if head_of_waiting < head_of_skipped:
            return self.waiting
        return self.skipped_waiting

    # At most one queue has entries at this point.
    if self.waiting:
        return self.waiting
    if self.skipped_waiting:
        return self.skipped_waiting
    return None
|
||||
|
||||
def _handle_stopped_request(self, request: Request) -> bool:
|
||||
"""Return True if finished (can be False for resumable requests)."""
|
||||
if not request.resumable:
|
||||
@@ -1546,7 +1553,7 @@ class Scheduler(SchedulerInterface):
|
||||
request.status = RequestStatus.WAITING_FOR_STREAMING_REQ
|
||||
self.num_waiting_for_streaming_input += 1
|
||||
|
||||
self.waiting.add_request(request)
|
||||
self._enqueue_waiting_request(request)
|
||||
return False
|
||||
|
||||
def _get_routed_experts(self, request: Request) -> np.ndarray | None:
|
||||
@@ -1677,7 +1684,7 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
def get_request_counts(self) -> tuple[int, int]:
|
||||
"""Returns (num_running_reqs, num_waiting_reqs)."""
|
||||
return len(self.running), len(self.waiting)
|
||||
return len(self.running), len(self.waiting) + len(self.skipped_waiting)
|
||||
|
||||
def add_request(self, request: Request) -> None:
|
||||
existing = self.requests.get(request.request_id)
|
||||
@@ -1696,7 +1703,7 @@ class Scheduler(SchedulerInterface):
|
||||
else:
|
||||
if request.resumable:
|
||||
request.streaming_queue = deque()
|
||||
self.waiting.add_request(request)
|
||||
self._enqueue_waiting_request(request)
|
||||
self.requests[request.request_id] = request
|
||||
if self.log_stats:
|
||||
request.record_event(EngineCoreEventType.QUEUED)
|
||||
@@ -1747,6 +1754,7 @@ class Scheduler(SchedulerInterface):
|
||||
self.running = remove_all(self.running, running_requests_to_remove)
|
||||
if waiting_requests_to_remove:
|
||||
self.waiting.remove_requests(waiting_requests_to_remove)
|
||||
self.skipped_waiting.remove_requests(waiting_requests_to_remove)
|
||||
|
||||
# Second pass: set status and free requests
|
||||
for request in valid_requests:
|
||||
@@ -1798,7 +1806,11 @@ class Scheduler(SchedulerInterface):
|
||||
return 0
|
||||
if self._pause_state == PauseState.PAUSED_NEW:
|
||||
return len(self.running)
|
||||
num_waiting = len(self.waiting) - self.num_waiting_for_streaming_input
|
||||
num_waiting = (
|
||||
len(self.waiting)
|
||||
+ len(self.skipped_waiting)
|
||||
- self.num_waiting_for_streaming_input
|
||||
)
|
||||
return num_waiting + len(self.running)
|
||||
|
||||
def has_finished_requests(self) -> bool:
|
||||
@@ -1898,7 +1910,7 @@ class Scheduler(SchedulerInterface):
|
||||
)
|
||||
return SchedulerStats(
|
||||
num_running_reqs=len(self.running),
|
||||
num_waiting_reqs=len(self.waiting),
|
||||
num_waiting_reqs=len(self.waiting) + len(self.skipped_waiting),
|
||||
kv_cache_usage=self.kv_cache_manager.usage,
|
||||
encoder_cache_usage=self._get_encoder_cache_usage(),
|
||||
prefix_cache_stats=prefix_cache_stats,
|
||||
@@ -1981,21 +1993,15 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
return self.connector.request_finished_all_groups(request, block_ids)
|
||||
|
||||
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
|
||||
def _update_waiting_for_remote_kv(self, request: Request) -> None:
|
||||
"""
|
||||
KV Connector: check if the request_id is finished_recving.
|
||||
|
||||
The finished_recving_kv_req_ids list is populated
|
||||
on the previous steps()'s update_from_output based
|
||||
on the worker side connector.
|
||||
KV Connector: update request state after async recv is finished.
|
||||
|
||||
When the kv transfer is ready, we cache the blocks
|
||||
and the request state will be moved back to WAITING from
|
||||
WAITING_FOR_REMOTE_KVS.
|
||||
"""
|
||||
assert self.connector is not None
|
||||
if request.request_id not in self.finished_recving_kv_req_ids:
|
||||
return False
|
||||
|
||||
if request.request_id in self.failed_recving_kv_req_ids:
|
||||
# Request had KV load failures; num_computed_tokens was already
|
||||
@@ -2023,9 +2029,40 @@ class Scheduler(SchedulerInterface):
|
||||
if request.num_cached_tokens < 0:
|
||||
request.num_cached_tokens = request.num_computed_tokens
|
||||
|
||||
# Return that we are ready.
|
||||
self.finished_recving_kv_req_ids.remove(request.request_id)
|
||||
return True
|
||||
|
||||
def _try_promote_blocked_waiting_request(self, request: Request) -> bool:
    """Attempt to move a blocked waiting request to a schedulable status.

    Handling per current status:

    * ``WAITING_FOR_REMOTE_KVS``: promoted only if the request id is in
      ``self.finished_recving_kv_req_ids``. The request state is then
      updated via ``_update_waiting_for_remote_kv`` and the status
      becomes ``PREEMPTED`` when ``request.num_preemptions`` is truthy,
      otherwise ``WAITING``.
    * ``WAITING_FOR_FSM``: promoted to ``WAITING`` once the structured
      output request exists and has a grammar.
    * ``WAITING_FOR_STREAMING_REQ``: never promoted here.

    Args:
        request: The blocked request to examine.

    Returns:
        True if the request status was changed to a schedulable one,
        False if the request is still blocked.

    Raises:
        AssertionError: If the request is not in one of the three
            blocked statuses listed above.
    """
    current_status = request.status

    if current_status == RequestStatus.WAITING_FOR_REMOTE_KVS:
        # finished_recving_kv_req_ids is populated during
        # update_from_output(), based on worker-side connector signals
        # in KVConnectorOutput.finished_recving
        if request.request_id not in self.finished_recving_kv_req_ids:
            return False
        self._update_waiting_for_remote_kv(request)
        request.status = (
            RequestStatus.PREEMPTED
            if request.num_preemptions
            else RequestStatus.WAITING
        )
        return True
    elif current_status == RequestStatus.WAITING_FOR_FSM:
        so_request = request.structured_output_request
        grammar_ready = so_request and so_request.grammar
        if not grammar_ready:
            return False
        request.status = RequestStatus.WAITING
        return True
    elif current_status == RequestStatus.WAITING_FOR_STREAMING_REQ:
        assert not request.streaming_queue
        return False

    raise AssertionError(
        "Unexpected blocked waiting status in promotion: "
        f"{request.status.name} for request {request.request_id}"
    )
|
||||
|
||||
def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
|
||||
"""
|
||||
@@ -2172,7 +2209,7 @@ class Scheduler(SchedulerInterface):
|
||||
# handle async KV loads (not cached yet, evict_blocks=False)
|
||||
async_load_reqs = (
|
||||
req
|
||||
for req in self.waiting
|
||||
for req in self.skipped_waiting
|
||||
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
)
|
||||
async_failed_req_ids, num_failed_tokens, _ = (
|
||||
|
||||
Reference in New Issue
Block a user