[BugFix][V1] Fix parallel sampling finishing/aborts (#14512)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-03-12 13:29:48 -04:00
committed by GitHub
parent 916836bbfb
commit f5d3acd474
7 changed files with 137 additions and 113 deletions

View File

@@ -298,9 +298,8 @@ class AsyncLLM(EngineClient):
async def abort(self, request_id: str) -> None:
"""Abort RequestId in OutputProcessor and EngineCore."""
request_ids = [request_id]
request_ids = self.output_processor.abort_requests((request_id, ))
await self.engine_core.abort_requests_async(request_ids)
self.output_processor.abort_requests(request_ids)
if self.log_requests:
logger.info("Aborted request %s.", request_id)

View File

@@ -137,8 +137,8 @@ class LLMEngine:
def abort_request(self, request_ids: list[str]) -> None:
"""Remove request_ids from EngineCore and Detokenizer."""
request_ids = self.output_processor.abort_requests(request_ids)
self.engine_core.abort_requests(request_ids)
self.output_processor.abort_requests(request_ids)
def add_request(
self,

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Optional, Union
@@ -102,8 +103,7 @@ class RequestState:
) -> Optional[RequestOutput]:
finished = finish_reason is not None
output_kind = self.output_kind
final_only = output_kind == RequestOutputKind.FINAL_ONLY
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
@@ -111,24 +111,24 @@ class RequestState:
# Only the final output is required in FINAL_ONLY mode.
return None
def new_request_output(request_id: str) -> RequestOutput:
return self._new_request_output(request_id, finished)
completion_output = self._new_completion_output(
new_token_ids, finish_reason, stop_reason)
if self.parent_req is not None:
return self.parent_req.make_request_output(final_only,
completion_output,
new_request_output)
request_id = self.request_id
if self.parent_req is None:
outputs = [completion_output]
else:
request_id, outputs, finished = self.parent_req.get_outputs(
request_id, completion_output)
if not outputs:
return None
request_output = new_request_output(self.request_id)
request_output.outputs.append(completion_output)
return request_output
return self._new_request_output(request_id, outputs, finished)
def _new_request_output(
self,
request_id: str,
outputs: list[CompletionOutput],
finished: bool,
) -> RequestOutput:
@@ -143,7 +143,7 @@ class RequestState:
prompt=self.prompt,
prompt_token_ids=self.prompt_token_ids,
prompt_logprobs=prompt_logprobs,
outputs=[],
outputs=outputs,
finished=finished,
)
@@ -188,6 +188,7 @@ class OutputProcessor:
self.log_stats = log_stats
self.tokenizer = tokenizer
self.request_states: dict[str, RequestState] = {}
self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates()
def get_num_unfinished_requests(self):
@@ -198,14 +199,20 @@ class OutputProcessor:
def abort_requests(
self,
request_ids: list[str],
) -> None:
request_ids: Iterable[str],
) -> list[str]:
request_ids_to_abort = []
for request_id in request_ids:
req_state = self.request_states.pop(request_id, None)
if req_state is not None:
self.lora_states.abort_request(req_state)
if req_state.parent_req is not None:
req_state.parent_req.finish_child_request(request_id)
request_ids_to_abort.append(request_id)
else:
parent = self.parent_requests.pop(request_id, None)
if parent and parent.child_requests:
self.abort_requests(parent.child_requests)
request_ids_to_abort.extend(parent.child_requests)
return request_ids_to_abort
def add_request(
self,
@@ -227,6 +234,8 @@ class OutputProcessor:
log_stats=self.log_stats)
self.request_states[request_id] = req_state
self.lora_states.add_request(req_state)
if parent_req:
self.parent_requests[parent_req.request_id] = parent_req
def process_outputs(
self,
@@ -314,12 +323,14 @@ class OutputProcessor:
# Free completed requests.
if finish_reason is not None:
self.request_states.pop(req_id)
# Remove parent request if applicable.
parent_req = req_state.parent_req
if parent_req and not parent_req.child_requests:
self.parent_requests.pop(parent_req.request_id, None)
if not engine_core_output.finished:
# If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore.
reqs_to_abort.append(req_id)
if req_state.parent_req is not None:
req_state.parent_req.finish_child_request(req_id)
# Track per-request stats
self._update_stats_from_finished(req_state, finish_reason,

View File

@@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
from copy import copy
from typing import Callable, Optional, Union
from typing import Optional, Union
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.outputs import CompletionOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.metrics.stats import IterationStats
@@ -23,7 +23,7 @@ class ParentRequest:
child_requests: set[str]
# To aggregate child completions when not streaming
output_aggregator: Optional[RequestOutput]
output_aggregator: list[CompletionOutput]
# To find the max number of generated tokens across all children
max_num_generation_tokens: int
@@ -37,7 +37,9 @@ class ParentRequest:
self.sampling_params = sampling_params
self.child_requests = set()
self.output_aggregator = None
self.output_aggregator = [None] * sampling_params.n if (
sampling_params.output_kind
== RequestOutputKind.FINAL_ONLY) else []
self.max_num_generation_tokens = 0
self.cached_child_sampling_params = None
@@ -93,43 +95,30 @@ class ParentRequest:
"""
child_req_id = f"{index}_{self.request_id}"
self.child_requests.add(child_req_id)
return (child_req_id, self._get_child_sampling_params(index))
def finish_child_request(self, req_id: str):
self.child_requests.remove(req_id)
return child_req_id, self._get_child_sampling_params(index)
@property
def n(self) -> int:
return self.sampling_params.n
def make_request_output(
def get_outputs(
self,
final_only: bool,
child_request_id: str,
completion_output: CompletionOutput,
new_request_output: Callable[[str], RequestOutput],
) -> Optional[RequestOutput]:
# Use an existing RequestOutput if we're aggregating
request_output = self.output_aggregator
) -> tuple[str, list[CompletionOutput], bool]:
if completion_output.finished():
self.child_requests.remove(child_request_id)
# Make new RequestOutput otherwise
if request_output is None:
request_output = new_request_output(self.request_id)
if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
# If streaming, just return the current output.
outputs = [completion_output]
else:
# If not streaming, aggregate the n final outputs.
self.output_aggregator[completion_output.index] = completion_output
outputs = [] if self.child_requests else self.output_aggregator
# Add a new completion
request_output.outputs.append(completion_output)
# If not streaming, aggregate until all child requests complete
if final_only and len(request_output.outputs) != self.n:
self.output_aggregator = request_output
return None
# We're done aggregating
self.output_aggregator = None
# Parent completion output list must be sorted by index
request_output.outputs = sorted(request_output.outputs,
key=lambda x: x.index)
return request_output
finished = not self.child_requests
return self.request_id, outputs, finished
def observe_num_generation_tokens(self, num_generation_tokens: int):
self.max_num_generation_tokens = max(num_generation_tokens,