[Core] feat: Implement Priority Scheduling in V1 Engine (#19057)
Signed-off-by: amit <amit.man@gmail.com> Co-authored-by: Roger Wang <Rogerw0108@gmail.com>
This commit is contained in:
224
vllm/v1/core/sched/request_queue.py
Normal file
224
vllm/v1/core/sched/request_queue.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import heapq
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from collections.abc import Iterable, Iterator
|
||||
from enum import Enum
|
||||
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class SchedulingPolicy(Enum):
    """Scheduling policies supported by the scheduler's request queues."""

    # First-come-first-served: requests are served in arrival order.
    FCFS = "fcfs"
    # Priority-based: requests with a smaller `priority` value are served
    # first; ties are broken by earlier `arrival_time`.
    PRIORITY = "priority"
|
||||
|
||||
|
||||
class RequestQueue(ABC):
    """Policy-agnostic interface for the scheduler's waiting queue.

    Concrete subclasses decide the order in which requests are stored,
    iterated, and popped (e.g. FCFS or priority-based).
    """

    @abstractmethod
    def add_request(self, request: Request) -> None:
        """Insert a request into the queue per the queue's policy."""

    @abstractmethod
    def pop_request(self) -> Request:
        """Remove and return the next request per the queue's policy."""

    @abstractmethod
    def peek_request(self) -> Request:
        """Return the front request without removing it."""

    @abstractmethod
    def prepend_request(self, request: Request) -> None:
        """Insert a request at the front of the queue."""

    @abstractmethod
    def prepend_requests(self, requests: RequestQueue) -> None:
        """Insert every request from another queue at the front of this
        queue."""

    @abstractmethod
    def remove_request(self, request: Request) -> None:
        """Delete one specific request from the queue."""

    @abstractmethod
    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Delete several specific requests from the queue."""

    @abstractmethod
    def __bool__(self) -> bool:
        """Return True if the queue holds at least one request."""

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of queued requests."""

    @abstractmethod
    def __iter__(self) -> Iterator[Request]:
        """Yield requests in the order dictated by the queue's policy."""

    @abstractmethod
    def __reversed__(self) -> Iterator[Request]:
        """Yield requests in the reverse of the policy order."""
|
||||
|
||||
|
||||
class FCFSRequestQueue(deque[Request], RequestQueue):
    """First-come-first-served queue backed directly by a deque."""

    def add_request(self, request: Request) -> None:
        """Enqueue a request at the tail (FCFS order)."""
        self.append(request)

    def pop_request(self) -> Request:
        """Dequeue and return the request at the head."""
        return self.popleft()

    def peek_request(self) -> Request:
        """Return the head request without removing it.

        Raises IndexError if the queue is empty.
        """
        try:
            return self[0]
        except IndexError:
            # Preserve the original, queue-specific error message.
            raise IndexError("peek from an empty queue") from None

    def prepend_request(self, request: Request) -> None:
        """Enqueue a request at the head of the queue."""
        self.appendleft(request)

    def prepend_requests(self, requests: RequestQueue) -> None:
        """Enqueue all requests from `requests` at the head, keeping their
        relative order."""
        # extendleft reverses its argument, so feed it the reversed view.
        self.extendleft(reversed(requests))

    def remove_request(self, request: Request) -> None:
        """Delete one specific request from the queue."""
        self.remove(request)

    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Delete several specific requests from the queue."""
        doomed = set(requests)
        survivors = [req for req in self if req not in doomed]
        # A deque cannot be filtered in place; rebuild it instead.
        self.clear()
        self.extend(survivors)

    def __bool__(self) -> bool:
        """Return True if any request is queued."""
        return super().__len__() > 0

    def __len__(self) -> int:
        """Return the number of queued requests."""
        return super().__len__()

    def __iter__(self) -> Iterator[Request]:
        """Yield requests head-to-tail (FCFS order)."""
        return super().__iter__()

    def __reversed__(self) -> Iterator[Request]:
        """Yield requests tail-to-head."""
        return super().__reversed__()
|
||||
|
||||
|
||||
class PriorityRequestQueue(RequestQueue):
    """
    A priority queue that supports heap operations.

    Requests with a smaller value of `priority` are processed first.
    If multiple requests have the same priority, the one with the earlier
    `arrival_time` is processed first. Any remaining ties are broken by
    insertion order (FIFO).
    """

    def __init__(self) -> None:
        # Heap entries are (priority, arrival_time, seq, request). The
        # monotonically increasing `seq` tiebreaker guarantees the tuple
        # comparison never falls through to comparing `Request` objects,
        # which are not orderable and would raise TypeError when both
        # priority and arrival_time are equal.
        self._heap: list[tuple[int, float, int, Request]] = []
        # Next sequence number to assign; strictly increasing.
        self._seq: int = 0

    def add_request(self, request: Request) -> None:
        """Add a request to the queue according to priority policy."""
        heapq.heappush(
            self._heap,
            (request.priority, request.arrival_time, self._seq, request))
        self._seq += 1

    def pop_request(self) -> Request:
        """Pop a request from the queue according to priority policy.

        Raises IndexError if the queue is empty.
        """
        if not self._heap:
            raise IndexError("pop from empty heap")
        _, _, _, request = heapq.heappop(self._heap)
        return request

    def peek_request(self) -> Request:
        """Peek at the next request in the queue without removing it.

        Raises IndexError if the queue is empty.
        """
        if not self._heap:
            raise IndexError("peek from empty heap")
        _, _, _, request = self._heap[0]
        return request

    def prepend_request(self, request: Request) -> None:
        """Add a request to the queue according to priority policy.

        Note: In a priority queue, there is no concept of prepending to the
        front. Requests are ordered by (priority, arrival_time)."""
        self.add_request(request)

    def prepend_requests(self, requests: RequestQueue) -> None:
        """Add all requests from another queue according to priority policy.

        Note: In a priority queue, there is no concept of prepending to the
        front. Requests are ordered by (priority, arrival_time)."""
        for request in requests:
            self.add_request(request)

    def remove_request(self, request: Request) -> None:
        """Remove a specific request from the queue."""
        self._heap = [entry for entry in self._heap if entry[3] != request]
        heapq.heapify(self._heap)

    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Remove multiple specific requests from the queue."""
        requests_to_remove = set(requests)
        self._heap = [
            entry for entry in self._heap
            if entry[3] not in requests_to_remove
        ]
        heapq.heapify(self._heap)

    def __bool__(self) -> bool:
        """Check if queue has any requests."""
        return bool(self._heap)

    def __len__(self) -> int:
        """Get number of requests in queue."""
        return len(self._heap)

    def __iter__(self) -> Iterator[Request]:
        """Iterate over the queue according to priority policy."""
        # Pop from a copy so iteration does not mutate the live heap.
        heap_copy = self._heap[:]
        while heap_copy:
            _, _, _, request = heapq.heappop(heap_copy)
            yield request

    def __reversed__(self) -> Iterator[Request]:
        """Iterate over the queue in reverse priority order."""
        return reversed(list(self))
|
||||
|
||||
|
||||
def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
    """Instantiate the request queue matching the scheduling policy.

    Raises ValueError if `policy` is not a recognized policy.
    """
    queue_cls = {
        SchedulingPolicy.PRIORITY: PriorityRequestQueue,
        SchedulingPolicy.FCFS: FCFSRequestQueue,
    }.get(policy)
    if queue_cls is None:
        raise ValueError(f"Unknown scheduling policy: {policy}")
    return queue_cls()
|
||||
@@ -22,6 +22,8 @@ from vllm.v1.core.kv_cache_manager import KVCacheManager
|
||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
|
||||
create_request_queue)
|
||||
from vllm.v1.core.sched.utils import check_stop
|
||||
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
|
||||
EngineCoreOutputs)
|
||||
@@ -94,8 +96,16 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
# req_id -> Request
|
||||
self.requests: dict[str, Request] = {}
|
||||
# Scheduling policy
|
||||
if self.scheduler_config.policy == "priority":
|
||||
self.policy = SchedulingPolicy.PRIORITY
|
||||
elif self.scheduler_config.policy == "fcfs":
|
||||
self.policy = SchedulingPolicy.FCFS
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown scheduling policy: {self.scheduler_config.policy}")
|
||||
# Priority queues for requests.
|
||||
self.waiting: deque[Request] = deque()
|
||||
self.waiting = create_request_queue(self.policy)
|
||||
self.running: list[Request] = []
|
||||
|
||||
# The request IDs that are finished in between the previous and the
|
||||
@@ -247,7 +257,15 @@ class Scheduler(SchedulerInterface):
|
||||
if new_blocks is None:
|
||||
# The request cannot be scheduled.
|
||||
# Preempt the lowest-priority request.
|
||||
preempted_req = self.running.pop()
|
||||
if self.policy == SchedulingPolicy.PRIORITY:
|
||||
preempted_req = max(
|
||||
self.running,
|
||||
key=lambda r: (r.priority, r.arrival_time),
|
||||
)
|
||||
self.running.remove(preempted_req)
|
||||
else:
|
||||
preempted_req = self.running.pop()
|
||||
|
||||
self.kv_cache_manager.free(preempted_req)
|
||||
preempted_req.status = RequestStatus.PREEMPTED
|
||||
preempted_req.num_computed_tokens = 0
|
||||
@@ -255,7 +273,7 @@ class Scheduler(SchedulerInterface):
|
||||
preempted_req.record_event(
|
||||
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
|
||||
|
||||
self.waiting.appendleft(preempted_req)
|
||||
self.waiting.prepend_request(preempted_req)
|
||||
preempted_reqs.append(preempted_req)
|
||||
if preempted_req == request:
|
||||
# No more request to preempt.
|
||||
@@ -311,9 +329,9 @@ class Scheduler(SchedulerInterface):
|
||||
if req.lora_request and req.lora_request.lora_int_id > 0)
|
||||
assert len(scheduled_loras) <= self.lora_config.max_loras
|
||||
|
||||
# Use a temporary deque to collect requests that need to be skipped
|
||||
# and put back at the head of the waiting queue later
|
||||
skipped_waiting_requests: deque[Request] = deque()
|
||||
# Use a temporary RequestQueue to collect requests that need to be
|
||||
# skipped and put back at the head of the waiting queue later
|
||||
skipped_waiting_requests = create_request_queue(self.policy)
|
||||
|
||||
# Next, schedule the WAITING requests.
|
||||
if not preempted_reqs:
|
||||
@@ -321,7 +339,7 @@ class Scheduler(SchedulerInterface):
|
||||
if len(self.running) == self.max_num_running_reqs:
|
||||
break
|
||||
|
||||
request = self.waiting[0]
|
||||
request = self.waiting.peek_request()
|
||||
|
||||
# KVTransfer: skip request if still waiting for remote kvs.
|
||||
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
|
||||
@@ -332,8 +350,8 @@ class Scheduler(SchedulerInterface):
|
||||
logger.debug(
|
||||
"%s is still in WAITING_FOR_REMOTE_KVS state.",
|
||||
request.request_id)
|
||||
self.waiting.popleft()
|
||||
skipped_waiting_requests.appendleft(request)
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
# Skip request if the structured output request is still waiting
|
||||
@@ -343,19 +361,18 @@ class Scheduler(SchedulerInterface):
|
||||
if structured_output_req and structured_output_req.grammar:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
self.waiting.popleft()
|
||||
skipped_waiting_requests.appendleft(request)
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
# Check that adding the request still respects the max_loras
|
||||
# constraint.
|
||||
if self.lora_config and request.lora_request and (
|
||||
len(scheduled_loras) == self.lora_config.max_loras
|
||||
and request.lora_request.lora_int_id
|
||||
not in scheduled_loras):
|
||||
if (self.lora_config and request.lora_request and
|
||||
(len(scheduled_loras) == self.lora_config.max_loras and
|
||||
request.lora_request.lora_int_id not in scheduled_loras)):
|
||||
# Scheduling would exceed max_loras, skip.
|
||||
self.waiting.popleft()
|
||||
skipped_waiting_requests.appendleft(request)
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
num_external_computed_tokens = 0
|
||||
@@ -407,8 +424,8 @@ class Scheduler(SchedulerInterface):
|
||||
# pooling requests to be chunked
|
||||
if not self.scheduler_config.chunked_prefill_enabled and \
|
||||
num_new_tokens > token_budget:
|
||||
self.waiting.popleft()
|
||||
skipped_waiting_requests.appendleft(request)
|
||||
self.waiting.pop_request()
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
continue
|
||||
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
@@ -448,17 +465,19 @@ class Scheduler(SchedulerInterface):
|
||||
num_external_computed_tokens,
|
||||
)
|
||||
|
||||
self.waiting.popleft()
|
||||
# Request was already popped from self.waiting
|
||||
# unless it was re-added above due to new_blocks being None.
|
||||
request = self.waiting.pop_request()
|
||||
if load_kv_async:
|
||||
# If loading async, allocate memory and put request
|
||||
# into the WAITING_FOR_REMOTE_KV state.
|
||||
skipped_waiting_requests.appendleft(request)
|
||||
skipped_waiting_requests.prepend_request(request)
|
||||
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
|
||||
continue
|
||||
|
||||
if request.use_structured_output:
|
||||
structured_output_request_ids[
|
||||
request.request_id] = req_index
|
||||
structured_output_request_ids[request.request_id] = (
|
||||
req_index)
|
||||
req_index += 1
|
||||
self.running.append(request)
|
||||
if self.log_stats:
|
||||
@@ -494,7 +513,7 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
# Put back any skipped requests at the head of the waiting queue
|
||||
if skipped_waiting_requests:
|
||||
self.waiting.extendleft(skipped_waiting_requests)
|
||||
self.waiting.prepend_requests(skipped_waiting_requests)
|
||||
|
||||
# Check if the scheduling constraints are satisfied.
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
||||
@@ -896,7 +915,7 @@ class Scheduler(SchedulerInterface):
|
||||
return len(self.running), len(self.waiting)
|
||||
|
||||
def add_request(self, request: Request) -> None:
|
||||
self.waiting.append(request)
|
||||
self.waiting.add_request(request)
|
||||
self.requests[request.request_id] = request
|
||||
if self.log_stats:
|
||||
request.record_event(EngineCoreEventType.QUEUED)
|
||||
@@ -917,16 +936,31 @@ class Scheduler(SchedulerInterface):
|
||||
else:
|
||||
request_ids = set(request_ids)
|
||||
|
||||
running_requests_to_remove = []
|
||||
waiting_requests_to_remove = []
|
||||
valid_requests = []
|
||||
|
||||
# First pass: collect requests to remove from queues
|
||||
for req_id in request_ids:
|
||||
request = self.requests.get(req_id)
|
||||
if request is None:
|
||||
# Invalid request ID.
|
||||
continue
|
||||
|
||||
valid_requests.append(request)
|
||||
if request.status == RequestStatus.RUNNING:
|
||||
self.running.remove(request)
|
||||
running_requests_to_remove.append(request)
|
||||
else:
|
||||
self.waiting.remove(request)
|
||||
waiting_requests_to_remove.append(request)
|
||||
|
||||
# Remove all requests from queues at once for better efficiency
|
||||
for request in running_requests_to_remove:
|
||||
self.running.remove(request)
|
||||
if waiting_requests_to_remove:
|
||||
self.waiting.remove_requests(waiting_requests_to_remove)
|
||||
|
||||
# Second pass: set status and free requests
|
||||
for request in valid_requests:
|
||||
request.status = finished_status
|
||||
self._free_request(request)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user