[OpenAI] Fix tool_choice=required streaming when output has trailing extra data (#31610)
Signed-off-by: maylikenoother <ogedengbemary19@gmail.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -311,6 +311,7 @@ def test_streaming_output_valid(output, empty_params, delta_len):
|
||||
previous_text = current_text
|
||||
|
||||
assert len(messages) > 0
|
||||
|
||||
combined_messages = "["
|
||||
for message in messages:
|
||||
if message.tool_calls[0].function.name:
|
||||
@@ -328,3 +329,35 @@ def test_streaming_output_valid(output, empty_params, delta_len):
|
||||
combined_messages += "}]"
|
||||
assert json.loads(combined_messages) == output
|
||||
assert json.dumps(json.loads(combined_messages)) == output_json
|
||||
|
||||
|
||||
def test_streaming_output_valid_with_trailing_extra_data():
|
||||
self = MagicMock()
|
||||
|
||||
output = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
|
||||
output_json = json.dumps(output) + "\nDONE"
|
||||
|
||||
previous_text = ""
|
||||
function_name_returned = False
|
||||
messages = []
|
||||
delta_len = 3
|
||||
for i in range(0, len(output_json), delta_len):
|
||||
delta_text = output_json[i : i + delta_len]
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message, function_name_returned = (
|
||||
OpenAIServingChat.extract_tool_call_required_streaming(
|
||||
self,
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=function_name_returned,
|
||||
)
|
||||
)
|
||||
|
||||
if delta_message:
|
||||
messages.append(delta_message)
|
||||
|
||||
previous_text = current_text
|
||||
|
||||
assert len(messages) > 0
|
||||
|
||||
Reference in New Issue
Block a user