[Sched] Enhance the logic to remove stopped requests from queues (#20739)

This commit is contained in:
Woosuk Kwon
2025-07-12 15:33:13 -07:00
committed by GitHub
parent 6e2c176e1f
commit f45a332886
3 changed files with 92 additions and 17 deletions

View File

@@ -747,19 +747,21 @@ class Scheduler(SchedulerInterface):
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
new_running: list[Request] = []
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: Optional[SpecDecodingStats] = None
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop.
for request in self.running:
req_id = request.request_id
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0:
# The request was not scheduled in this step.
new_running.append(request)
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
# to avoid expensive operations inside the loop.
stopped_running_reqs: set[Request] = set()
stopped_preempted_reqs: set[Request] = set()
for req_id, num_tokens_scheduled in num_scheduled_tokens.items():
assert num_tokens_scheduled > 0
request = self.requests.get(req_id)
if request is None:
# The request is already finished. This can happen if the
# request is aborted while the model is executing it (e.g.,
# in pipeline parallelism).
continue
req_index = model_runner_output.req_id_to_index[req_id]
@@ -792,6 +794,7 @@ class Scheduler(SchedulerInterface):
new_logprobs = None
new_token_ids = generated_token_ids
kv_transfer_params = None
status_before_stop = request.status
# Append generated tokens and check for stop. Note that if
# a request is still being prefilled, we expect the model runner
@@ -803,17 +806,22 @@ class Scheduler(SchedulerInterface):
# This must be called before we make the EngineCoreOutput.
stopped = check_stop(request, self.max_model_len)
if stopped:
kv_transfer_params = self._free_request(request)
del new_token_ids[num_new:] # Trim new tokens if needed.
break
# Stop checking for pooler models.
pooler_output = None
if pooler_outputs:
pooler_output = pooler_outputs[req_index]
stopped = check_stop(request, self.max_model_len,
pooler_output)
if stopped:
kv_transfer_params = self._free_request(request)
if stopped:
kv_transfer_params = self._free_request(request)
if status_before_stop == RequestStatus.RUNNING:
stopped_running_reqs.add(request)
else:
stopped_preempted_reqs.add(request)
# Extract sample logprobs if needed.
if request.sampling_params is not None \
@@ -868,9 +876,14 @@ class Scheduler(SchedulerInterface):
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
if not stopped:
new_running.append(request)
self.running = new_running
# Remove the stopped requests from the running and waiting queues.
if stopped_running_reqs:
self.running = [
req for req in self.running if req not in stopped_running_reqs
]
if stopped_preempted_reqs:
# This is a rare case and unlikely to impact performance.
self.waiting.remove_requests(stopped_preempted_reqs)
# KV Connector: update state for finished KV Transfers.
self._update_from_kv_xfer_finished(model_runner_output)