[OpenAI] Fix tool_choice=required streaming when output has trailing extra data (#31610)

Signed-off-by: maylikenoother <ogedengbemary19@gmail.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
Mary
2026-01-08 13:01:42 +00:00
committed by GitHub
parent 1123a87892
commit 7645bc524b
2 changed files with 41 additions and 2 deletions

View File

@@ -311,6 +311,7 @@ def test_streaming_output_valid(output, empty_params, delta_len):
previous_text = current_text
assert len(messages) > 0
combined_messages = "["
for message in messages:
if message.tool_calls[0].function.name:
@@ -328,3 +329,35 @@ def test_streaming_output_valid(output, empty_params, delta_len):
combined_messages += "}]"
assert json.loads(combined_messages) == output
assert json.dumps(json.loads(combined_messages)) == output_json
def test_streaming_output_valid_with_trailing_extra_data():
    """Stream a tool-call JSON payload followed by trailing non-JSON text.

    The serialized tool call list has "\nDONE" appended and is fed to
    ``OpenAIServingChat.extract_tool_call_required_streaming`` in
    three-character chunks. The test passes when at least one delta message
    is produced over the course of the stream.
    """
    serving_chat = MagicMock()

    expected_calls = [
        {"name": "get_current_weather", "parameters": {"city": "Vienna"}}
    ]
    # Valid JSON followed by extra data that is not part of the tool call.
    streamed_text = json.dumps(expected_calls) + "\nDONE"

    chunk_size = 3
    seen_so_far = ""
    name_already_sent = False
    collected = []

    for start in range(0, len(streamed_text), chunk_size):
        chunk = streamed_text[start : start + chunk_size]
        extended = seen_so_far + chunk

        # The extractor is called unbound with a mock standing in for the
        # serving instance.
        result = OpenAIServingChat.extract_tool_call_required_streaming(
            serving_chat,
            previous_text=seen_so_far,
            current_text=extended,
            delta_text=chunk,
            function_name_returned=name_already_sent,
        )
        delta_message, name_already_sent = result

        if delta_message:
            collected.append(delta_message)

        seen_so_far = extended

    assert len(collected) > 0

View File

@@ -13,6 +13,7 @@ import partial_json_parser
import regex as re
from fastapi import Request
from openai_harmony import Message as OpenAIMessage
from partial_json_parser.core.options import Allow
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
@@ -76,6 +77,7 @@ from vllm.tokenizers.mistral import (
)
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.tool_parsers.utils import partial_json_loads
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
@@ -511,8 +513,12 @@ class OpenAIServingChat(OpenAIServing):
# if the current text is empty, we cannot parse it
return None, function_name_returned
try:
obj = partial_json_parser.loads(current_text)
except partial_json_parser.core.exceptions.MalformedJSON:
flags = Allow.ALL
obj, _ = partial_json_loads(current_text, flags)
except (
partial_json_parser.core.exceptions.MalformedJSON,
json.JSONDecodeError,
):
logger.debug("not enough tokens to parse into JSON yet")
obj = None