Implement Async Scheduling (#19970)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Date: 2025-07-14 23:01:46 -07:00 (committed by GitHub)
parent 85bd6599e4
commit d4d309409f
11 changed files with 508 additions and 148 deletions

vllm/v1/core/sched/async_scheduler.py (new file)

@@ -0,0 +1,47 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
+
+from vllm.logger import init_logger
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.request import Request, RequestStatus
+
+logger = init_logger(__name__)
+
+
+class AsyncScheduler(Scheduler):
+
+    def _update_after_schedule(
+        self,
+        scheduler_output: SchedulerOutput,
+    ) -> None:
+        super()._update_after_schedule(scheduler_output)
+        for req_id in scheduler_output.num_scheduled_tokens:
+            request = self.requests[req_id]
+            if (request.num_computed_tokens == request.num_tokens +
+                    request.num_output_placeholders):
+                # The request will generate a new token in this scheduling step.
+                # TODO(woosuk): Support speculative decoding.
+                request.num_output_placeholders += 1
+
+    def _update_request_with_output(
+        self,
+        request: Request,
+        new_token_ids: list[int],
+    ) -> tuple[list[int], bool]:
+        status_before_update = request.status
+        new_token_ids, stopped = super()._update_request_with_output(
+            request, new_token_ids)
+
+        # Update the number of output placeholders.
+        request.num_output_placeholders -= len(new_token_ids)
+        assert request.num_output_placeholders >= 0
+
+        # Cache the new tokens. Preempted requests should be skipped.
+        if status_before_update == RequestStatus.RUNNING:
+            self.kv_cache_manager.cache_blocks(
+                request,
+                request.num_computed_tokens - request.num_output_placeholders)
+
+        return new_token_ids, stopped
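
The two overrides above are the core of async scheduling: the scheduler plans step N+1 before the model runner has returned step N's token, and each in-flight token is tracked as an output placeholder. Note that `cache_blocks` is then called only up to `num_computed_tokens - num_output_placeholders`, i.e. positions whose token ids are actually known. A self-contained toy of this accounting (illustrative names, not the vLLM API):

    # Standalone toy of the placeholder accounting -- not vLLM code.
    from dataclasses import dataclass

    @dataclass
    class ToyRequest:
        num_tokens: int                   # prompt + output tokens received so far
        num_computed_tokens: int = 0      # positions already sent through the model
        num_output_placeholders: int = 0  # scheduled, but token not yet returned

    def schedule_step(req: ToyRequest) -> None:
        # Mirrors the budget formula: placeholders count as known tokens.
        num_new = (req.num_tokens + req.num_output_placeholders
                   - req.num_computed_tokens)
        req.num_computed_tokens += num_new
        # Mirrors AsyncScheduler._update_after_schedule: this step's forward
        # pass will emit one token we have not yet received.
        if req.num_computed_tokens == req.num_tokens + req.num_output_placeholders:
            req.num_output_placeholders += 1

    def receive_output(req: ToyRequest, new_token_ids: list[int]) -> None:
        # Mirrors AsyncScheduler._update_request_with_output: real tokens
        # arrive and retire their placeholders.
        req.num_tokens += len(new_token_ids)
        req.num_output_placeholders -= len(new_token_ids)
        assert req.num_output_placeholders >= 0

    req = ToyRequest(num_tokens=8)  # 8-token prompt
    schedule_step(req)              # step N: prompt + 1 token in flight
    schedule_step(req)              # step N+1 planned before N returns
    receive_output(req, [42])       # step N's token finally arrives
    print(req.num_tokens, req.num_output_placeholders)  # 9 1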

vllm/v1/core/sched/scheduler.py

@@ -204,7 +204,8 @@ class Scheduler(SchedulerInterface):
         while req_index < len(self.running) and token_budget > 0:
             request = self.running[req_index]
-            num_new_tokens = (request.num_tokens_with_spec -
+            num_new_tokens = (request.num_tokens_with_spec +
+                              request.num_output_placeholders -
                               request.num_computed_tokens)
             if (0 < self.scheduler_config.long_prefill_token_threshold <
                     num_new_tokens):
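
A worked example of the new budget arithmetic (numbers illustrative): with an 8-token prompt fully computed and the first output token still in flight, the placeholder stands in for the token that has not yet been returned:

    num_tokens_with_spec = 8      # prompt only; no speculative tokens
    num_output_placeholders = 1   # last step's token not yet returned
    num_computed_tokens = 8       # prompt forward pass already done

    # Synchronous flow: by reschedule time the token is appended, so
    # num_tokens_with_spec would read 9 and 9 - 8 == 1.
    # Async flow: the token has not arrived, so the placeholder stands
    # in for it: 8 + 1 - 8 == 1 decode slot to schedule now.
    num_new_tokens = (num_tokens_with_spec + num_output_placeholders
                      - num_computed_tokens)
    assert num_new_tokens == 1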
@@ -230,9 +231,11 @@ class Scheduler(SchedulerInterface):
             if num_new_tokens == 0:
                 # The request cannot be scheduled because one of the following
                 # reasons:
-                # 1. No new tokens to schedule. This may happen when PP>1 and
-                #    we have already scheduled all prompt tokens but they are
-                #    not finished yet.
+                # 1. No new tokens to schedule. This may happen when
+                #    (1) PP>1 and we have already scheduled all prompt tokens
+                #        but they are not finished yet, or
+                #    (2) async scheduling is used and the request has reached
+                #        either its max_total_tokens or max_model_len.
                 # 2. The encoder budget is exhausted.
                 # 3. The encoder cache is exhausted.
                 # NOTE(woosuk): Here, by doing `continue` instead of `break`,
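
Case (2) above falls out of the same arithmetic: once every position the request is allowed to fill is either computed or covered by an in-flight placeholder, the formula yields zero, and `continue` lets the scheduler move on to the next running request rather than stalling the whole pass. Illustrative numbers for a request pinned at max_model_len:

    max_model_len = 10
    num_tokens_with_spec = 8      # tokens actually received so far
    num_output_placeholders = 2   # two steps in flight -> at the length cap
    num_computed_tokens = 10      # everything known is already scheduled

    num_new_tokens = (num_tokens_with_spec + num_output_placeholders
                      - num_computed_tokens)
    assert num_new_tokens == 0    # -> `continue`, request skipped this pass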
@@ -598,6 +601,14 @@ class Scheduler(SchedulerInterface):
             request = self.requests[req_id]
             request.num_computed_tokens += num_scheduled_token
+            # NOTE: _free_encoder_inputs relies on num_computed_tokens, which
+            # may be updated again in _update_from_output for speculative
+            # decoding. However, it is safe to call the method here because
+            # encoder inputs are always part of the prompt, not the output,
+            # and thus are unaffected by speculative decoding.
+            if request.has_encoder_inputs:
+                self._free_encoder_inputs(request)
+
         # Clear the finished request IDs.
         # NOTE: We shouldn't do self.finished_req_ids.clear() here because
         # it will also affect the scheduler output.
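
The safety argument in the NOTE can be made concrete: an encoder input (say, an image's embeddings) occupies a fixed token range inside the prompt, and it may be freed once `num_computed_tokens` has passed the end of that range; speculative decoding can only roll back output positions, never prompt positions. A hedged sketch of that check (function and parameter names are illustrative, not the vLLM API):

    def can_free_encoder_input(start_pos: int, num_encoder_tokens: int,
                               num_computed_tokens: int) -> bool:
        # The encoder output is needed until the decoder's KV cache covers
        # the whole range [start_pos, start_pos + num_encoder_tokens).
        # Prompt positions are never revisited by speculative decoding,
        # so a True answer here cannot be invalidated later.
        return start_pos + num_encoder_tokens <= num_computed_tokens

    print(can_free_encoder_input(4, 16, 12))  # False: still mid-image
    print(can_free_encoder_input(4, 16, 20))  # True: range fully computed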
@@ -785,29 +796,16 @@ class Scheduler(SchedulerInterface):
                     num_draft_tokens=len(scheduled_spec_token_ids),
                     num_accepted_tokens=len(generated_token_ids) - 1)
 
-            # NOTE(woosuk): This has to be executed after updating
-            # `request.num_computed_tokens`.
-            if request.has_encoder_inputs:
-                self._free_encoder_inputs(request)
-
             stopped = False
             new_logprobs = None
             new_token_ids = generated_token_ids
             kv_transfer_params = None
             status_before_stop = request.status
 
-            # Append generated tokens and check for stop. Note that if
-            # a request is still being prefilled, we expect the model runner
-            # to return empty token ids for the request.
-            for num_new, output_token_id in enumerate(new_token_ids, 1):
-                request.append_output_token_ids(output_token_id)
-
-                # Check for stop and update request state.
-                # This must be called before we make the EngineCoreOutput.
-                stopped = check_stop(request, self.max_model_len)
-                if stopped:
-                    del new_token_ids[num_new:]  # Trim new tokens if needed.
-                    break
+            # Check for stop and update request status.
+            if new_token_ids:
+                new_token_ids, stopped = self._update_request_with_output(
+                    request, new_token_ids)
 
             # Stop checking for pooler models.
             pooler_output = None
@@ -915,6 +913,26 @@ class Scheduler(SchedulerInterface):
 
         return engine_core_outputs
 
+    def _update_request_with_output(
+        self,
+        request: Request,
+        new_token_ids: list[int],
+    ) -> tuple[list[int], bool]:
+        # Append generated tokens and check for stop. Note that if
+        # a request is still being prefilled, we expect the model runner
+        # to return empty token ids for the request.
+        stopped = False
+        for num_new, output_token_id in enumerate(new_token_ids, 1):
+            request.append_output_token_ids(output_token_id)
+
+            # Check for stop and update request state.
+            # This must be called before we make the EngineCoreOutput.
+            stopped = check_stop(request, self.max_model_len)
+            if stopped:
+                del new_token_ids[num_new:]  # Trim new tokens if needed.
+                break
+        return new_token_ids, stopped
+
     def _free_encoder_inputs(self, request: Request) -> None:
         cached_encoder_input_ids = (
             self.encoder_cache_manager.get_cached_input_ids(request))
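
One detail of the extracted helper worth a standalone illustration is the trim: when a stop condition fires partway through a batch of sampled tokens, `del new_token_ids[num_new:]` drops everything past the stopping token so it never reaches the output. A minimal mock, with a toy EOS check standing in for the scheduler's `check_stop`:

    def update_with_output(output_ids: list[int], new_token_ids: list[int],
                           eos_token_id: int) -> tuple[list[int], bool]:
        # Same control flow as Scheduler._update_request_with_output, with a
        # toy stop condition instead of check_stop(request, max_model_len).
        stopped = False
        for num_new, token_id in enumerate(new_token_ids, 1):
            output_ids.append(token_id)
            stopped = (token_id == eos_token_id)
            if stopped:
                del new_token_ids[num_new:]  # trim tokens past the stop point
                break
        return new_token_ids, stopped

    out: list[int] = []
    trimmed, stopped = update_with_output(out, [5, 7, 2, 9, 9], eos_token_id=2)
    print(trimmed, stopped)  # [5, 7, 2] True -- the trailing 9s never surface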