[OpenAI] Fix tool_choice=required streaming when output has trailing extra data (#31610)
Signed-off-by: maylikenoother <ogedengbemary19@gmail.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -311,6 +311,7 @@ def test_streaming_output_valid(output, empty_params, delta_len):
|
||||
previous_text = current_text
|
||||
|
||||
assert len(messages) > 0
|
||||
|
||||
combined_messages = "["
|
||||
for message in messages:
|
||||
if message.tool_calls[0].function.name:
|
||||
@@ -328,3 +329,35 @@ def test_streaming_output_valid(output, empty_params, delta_len):
|
||||
combined_messages += "}]"
|
||||
assert json.loads(combined_messages) == output
|
||||
assert json.dumps(json.loads(combined_messages)) == output_json
|
||||
|
||||
|
||||
def test_streaming_output_valid_with_trailing_extra_data():
|
||||
self = MagicMock()
|
||||
|
||||
output = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
|
||||
output_json = json.dumps(output) + "\nDONE"
|
||||
|
||||
previous_text = ""
|
||||
function_name_returned = False
|
||||
messages = []
|
||||
delta_len = 3
|
||||
for i in range(0, len(output_json), delta_len):
|
||||
delta_text = output_json[i : i + delta_len]
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message, function_name_returned = (
|
||||
OpenAIServingChat.extract_tool_call_required_streaming(
|
||||
self,
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=function_name_returned,
|
||||
)
|
||||
)
|
||||
|
||||
if delta_message:
|
||||
messages.append(delta_message)
|
||||
|
||||
previous_text = current_text
|
||||
|
||||
assert len(messages) > 0
|
||||
|
||||
@@ -13,6 +13,7 @@ import partial_json_parser
|
||||
import regex as re
|
||||
from fastapi import Request
|
||||
from openai_harmony import Message as OpenAIMessage
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
@@ -76,6 +77,7 @@ from vllm.tokenizers.mistral import (
|
||||
)
|
||||
from vllm.tool_parsers import ToolParser
|
||||
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
|
||||
from vllm.tool_parsers.utils import partial_json_loads
|
||||
from vllm.utils.collection_utils import as_list
|
||||
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
|
||||
|
||||
@@ -511,8 +513,12 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# if the current text is empty, we cannot parse it
|
||||
return None, function_name_returned
|
||||
try:
|
||||
obj = partial_json_parser.loads(current_text)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
flags = Allow.ALL
|
||||
obj, _ = partial_json_loads(current_text, flags)
|
||||
except (
|
||||
partial_json_parser.core.exceptions.MalformedJSON,
|
||||
json.JSONDecodeError,
|
||||
):
|
||||
logger.debug("not enough tokens to parse into JSON yet")
|
||||
obj = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user