[BugFix][V1] Fix parallel sampling finishing/aborts (#14512)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -298,9 +298,8 @@ class AsyncLLM(EngineClient):
|
||||
async def abort(self, request_id: str) -> None:
|
||||
"""Abort RequestId in OutputProcessor and EngineCore."""
|
||||
|
||||
request_ids = [request_id]
|
||||
request_ids = self.output_processor.abort_requests((request_id, ))
|
||||
await self.engine_core.abort_requests_async(request_ids)
|
||||
self.output_processor.abort_requests(request_ids)
|
||||
|
||||
if self.log_requests:
|
||||
logger.info("Aborted request %s.", request_id)
|
||||
|
||||
@@ -137,8 +137,8 @@ class LLMEngine:
|
||||
def abort_request(self, request_ids: list[str]) -> None:
|
||||
"""Remove request_ids from EngineCore and Detokenizer."""
|
||||
|
||||
request_ids = self.output_processor.abort_requests(request_ids)
|
||||
self.engine_core.abort_requests(request_ids)
|
||||
self.output_processor.abort_requests(request_ids)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -102,8 +103,7 @@ class RequestState:
|
||||
) -> Optional[RequestOutput]:
|
||||
|
||||
finished = finish_reason is not None
|
||||
output_kind = self.output_kind
|
||||
final_only = output_kind == RequestOutputKind.FINAL_ONLY
|
||||
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
|
||||
|
||||
# In follow up, we will switch to invariant where EngineCore
|
||||
# does not stream partial prefills.
|
||||
@@ -111,24 +111,24 @@ class RequestState:
|
||||
# Only the final output is required in FINAL_ONLY mode.
|
||||
return None
|
||||
|
||||
def new_request_output(request_id: str) -> RequestOutput:
|
||||
return self._new_request_output(request_id, finished)
|
||||
|
||||
completion_output = self._new_completion_output(
|
||||
new_token_ids, finish_reason, stop_reason)
|
||||
|
||||
if self.parent_req is not None:
|
||||
return self.parent_req.make_request_output(final_only,
|
||||
completion_output,
|
||||
new_request_output)
|
||||
request_id = self.request_id
|
||||
if self.parent_req is None:
|
||||
outputs = [completion_output]
|
||||
else:
|
||||
request_id, outputs, finished = self.parent_req.get_outputs(
|
||||
request_id, completion_output)
|
||||
if not outputs:
|
||||
return None
|
||||
|
||||
request_output = new_request_output(self.request_id)
|
||||
request_output.outputs.append(completion_output)
|
||||
return request_output
|
||||
return self._new_request_output(request_id, outputs, finished)
|
||||
|
||||
def _new_request_output(
|
||||
self,
|
||||
request_id: str,
|
||||
outputs: list[CompletionOutput],
|
||||
finished: bool,
|
||||
) -> RequestOutput:
|
||||
|
||||
@@ -143,7 +143,7 @@ class RequestState:
|
||||
prompt=self.prompt,
|
||||
prompt_token_ids=self.prompt_token_ids,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
outputs=[],
|
||||
outputs=outputs,
|
||||
finished=finished,
|
||||
)
|
||||
|
||||
@@ -188,6 +188,7 @@ class OutputProcessor:
|
||||
self.log_stats = log_stats
|
||||
self.tokenizer = tokenizer
|
||||
self.request_states: dict[str, RequestState] = {}
|
||||
self.parent_requests: dict[str, ParentRequest] = {}
|
||||
self.lora_states = LoRARequestStates()
|
||||
|
||||
def get_num_unfinished_requests(self):
|
||||
@@ -198,14 +199,20 @@ class OutputProcessor:
|
||||
|
||||
def abort_requests(
|
||||
self,
|
||||
request_ids: list[str],
|
||||
) -> None:
|
||||
request_ids: Iterable[str],
|
||||
) -> list[str]:
|
||||
request_ids_to_abort = []
|
||||
for request_id in request_ids:
|
||||
req_state = self.request_states.pop(request_id, None)
|
||||
if req_state is not None:
|
||||
self.lora_states.abort_request(req_state)
|
||||
if req_state.parent_req is not None:
|
||||
req_state.parent_req.finish_child_request(request_id)
|
||||
request_ids_to_abort.append(request_id)
|
||||
else:
|
||||
parent = self.parent_requests.pop(request_id, None)
|
||||
if parent and parent.child_requests:
|
||||
self.abort_requests(parent.child_requests)
|
||||
request_ids_to_abort.extend(parent.child_requests)
|
||||
return request_ids_to_abort
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
@@ -227,6 +234,8 @@ class OutputProcessor:
|
||||
log_stats=self.log_stats)
|
||||
self.request_states[request_id] = req_state
|
||||
self.lora_states.add_request(req_state)
|
||||
if parent_req:
|
||||
self.parent_requests[parent_req.request_id] = parent_req
|
||||
|
||||
def process_outputs(
|
||||
self,
|
||||
@@ -314,12 +323,14 @@ class OutputProcessor:
|
||||
# Free completed requests.
|
||||
if finish_reason is not None:
|
||||
self.request_states.pop(req_id)
|
||||
# Remove parent request if applicable.
|
||||
parent_req = req_state.parent_req
|
||||
if parent_req and not parent_req.child_requests:
|
||||
self.parent_requests.pop(parent_req.request_id, None)
|
||||
if not engine_core_output.finished:
|
||||
# If req not finished in EngineCore, but Detokenizer
|
||||
# detected stop string, abort needed in EngineCore.
|
||||
reqs_to_abort.append(req_id)
|
||||
if req_state.parent_req is not None:
|
||||
req_state.parent_req.finish_child_request(req_id)
|
||||
|
||||
# Track per-request stats
|
||||
self._update_stats_from_finished(req_state, finish_reason,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from copy import copy
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class ParentRequest:
|
||||
child_requests: set[str]
|
||||
|
||||
# To aggregate child completions when not streaming
|
||||
output_aggregator: Optional[RequestOutput]
|
||||
output_aggregator: list[CompletionOutput]
|
||||
|
||||
# To find the max number of generated tokens across all children
|
||||
max_num_generation_tokens: int
|
||||
@@ -37,7 +37,9 @@ class ParentRequest:
|
||||
self.sampling_params = sampling_params
|
||||
|
||||
self.child_requests = set()
|
||||
self.output_aggregator = None
|
||||
self.output_aggregator = [None] * sampling_params.n if (
|
||||
sampling_params.output_kind
|
||||
== RequestOutputKind.FINAL_ONLY) else []
|
||||
self.max_num_generation_tokens = 0
|
||||
self.cached_child_sampling_params = None
|
||||
|
||||
@@ -93,43 +95,30 @@ class ParentRequest:
|
||||
"""
|
||||
child_req_id = f"{index}_{self.request_id}"
|
||||
self.child_requests.add(child_req_id)
|
||||
return (child_req_id, self._get_child_sampling_params(index))
|
||||
|
||||
def finish_child_request(self, req_id: str):
|
||||
self.child_requests.remove(req_id)
|
||||
return child_req_id, self._get_child_sampling_params(index)
|
||||
|
||||
@property
|
||||
def n(self) -> int:
|
||||
return self.sampling_params.n
|
||||
|
||||
def make_request_output(
|
||||
def get_outputs(
|
||||
self,
|
||||
final_only: bool,
|
||||
child_request_id: str,
|
||||
completion_output: CompletionOutput,
|
||||
new_request_output: Callable[[str], RequestOutput],
|
||||
) -> Optional[RequestOutput]:
|
||||
# Use an existing RequestOutput if we're aggregating
|
||||
request_output = self.output_aggregator
|
||||
) -> tuple[str, list[CompletionOutput], bool]:
|
||||
if completion_output.finished():
|
||||
self.child_requests.remove(child_request_id)
|
||||
|
||||
# Make new RequestOutput otherwise
|
||||
if request_output is None:
|
||||
request_output = new_request_output(self.request_id)
|
||||
if self.sampling_params.output_kind != RequestOutputKind.FINAL_ONLY:
|
||||
# If streaming, just return the current output.
|
||||
outputs = [completion_output]
|
||||
else:
|
||||
# If not streaming, aggregate the n final outputs.
|
||||
self.output_aggregator[completion_output.index] = completion_output
|
||||
outputs = [] if self.child_requests else self.output_aggregator
|
||||
|
||||
# Add a new completion
|
||||
request_output.outputs.append(completion_output)
|
||||
|
||||
# If not streaming, aggregate until all child requests complete
|
||||
if final_only and len(request_output.outputs) != self.n:
|
||||
self.output_aggregator = request_output
|
||||
return None
|
||||
|
||||
# We're done aggregating
|
||||
self.output_aggregator = None
|
||||
|
||||
# Parent completion output list must be sorted by index
|
||||
request_output.outputs = sorted(request_output.outputs,
|
||||
key=lambda x: x.index)
|
||||
return request_output
|
||||
finished = not self.child_requests
|
||||
return self.request_id, outputs, finished
|
||||
|
||||
def observe_num_generation_tokens(self, num_generation_tokens: int):
|
||||
self.max_num_generation_tokens = max(num_generation_tokens,
|
||||
|
||||
Reference in New Issue
Block a user