[Bugfix] fix streaming final output for non harmony (#30237)
Signed-off-by: penfree <qiupengfei@baidu.com> Co-authored-by: penfree <qiupengfei@baidu.com>
This commit is contained in:
@@ -2,11 +2,13 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import contextlib
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from openai.types.responses.response_function_tool_call_output_item import (
|
||||
@@ -164,6 +166,12 @@ class SimpleContext(ConversationContext):
|
||||
|
||||
def __init__(self):
|
||||
self.last_output = None
|
||||
|
||||
# Accumulated final output for streaming mode
|
||||
self._accumulated_text: str = ""
|
||||
self._accumulated_token_ids: list[int] = []
|
||||
self._accumulated_logprobs: list = []
|
||||
|
||||
self.num_prompt_tokens = 0
|
||||
self.num_output_tokens = 0
|
||||
self.num_cached_tokens = 0
|
||||
@@ -183,6 +191,13 @@ class SimpleContext(ConversationContext):
|
||||
self.num_cached_tokens = output.num_cached_tokens or 0
|
||||
self.num_output_tokens += len(output.outputs[0].token_ids or [])
|
||||
|
||||
# Accumulate text, token_ids, and logprobs for streaming mode
|
||||
delta_output = output.outputs[0]
|
||||
self._accumulated_text += delta_output.text
|
||||
self._accumulated_token_ids.extend(delta_output.token_ids)
|
||||
if delta_output.logprobs is not None:
|
||||
self._accumulated_logprobs.extend(delta_output.logprobs)
|
||||
|
||||
if len(self.input_messages) == 0:
|
||||
output_prompt = output.prompt or ""
|
||||
output_prompt_token_ids = output.prompt_token_ids or []
|
||||
@@ -194,11 +209,26 @@ class SimpleContext(ConversationContext):
|
||||
)
|
||||
self.output_messages.append(
|
||||
ResponseRawMessageAndToken(
|
||||
message=output.outputs[0].text,
|
||||
tokens=output.outputs[0].token_ids,
|
||||
message=delta_output.text,
|
||||
tokens=delta_output.token_ids,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def final_output(self) -> RequestOutput | None:
|
||||
"""Return the final output, with complete text/token_ids/logprobs."""
|
||||
if self.last_output is not None and self.last_output.outputs:
|
||||
assert isinstance(self.last_output, RequestOutput)
|
||||
final_output = copy.copy(self.last_output)
|
||||
# copy inner item to avoid modify last_output
|
||||
final_output.outputs = [replace(item) for item in self.last_output.outputs]
|
||||
final_output.outputs[0].text = self._accumulated_text
|
||||
final_output.outputs[0].token_ids = tuple(self._accumulated_token_ids)
|
||||
if self._accumulated_logprobs:
|
||||
final_output.outputs[0].logprobs = self._accumulated_logprobs
|
||||
return final_output
|
||||
return self.last_output
|
||||
|
||||
def append_tool_output(self, output) -> None:
|
||||
raise NotImplementedError("Should not be called.")
|
||||
|
||||
|
||||
@@ -675,7 +675,8 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
num_tool_output_tokens = 0
|
||||
else:
|
||||
assert isinstance(context, SimpleContext)
|
||||
final_res = context.last_output
|
||||
# Use final_output which has accumulated text/token_ids/logprobs
|
||||
final_res = context.final_output
|
||||
assert final_res is not None
|
||||
assert len(final_res.outputs) == 1
|
||||
final_output = final_res.outputs[0]
|
||||
|
||||
Reference in New Issue
Block a user