[Feature][Responses API] Stream Function Call - harmony (#24317)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -16,6 +16,22 @@ from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "openai/gpt-oss-20b"
|
||||
|
||||
# OpenAI Responses-API tool definition for a `get_weather` function call.
# Declared once at module scope so the function-calling tests below share
# a single schema instead of repeating the literal inline.
GET_WEATHER_SCHEMA = {
    "type": "function",
    "name": "get_weather",
    "description": "Get current temperature for provided coordinates in celsius.",  # noqa
    "parameters": {
        "type": "object",
        "properties": {
            # Coordinates the model must supply when invoking the tool.
            "latitude": {"type": "number"},
            "longitude": {"type": "number"},
        },
        "required": ["latitude", "longitude"],
        # Forbid argument keys beyond latitude/longitude.
        "additionalProperties": False,
    },
    # Strict mode: generated arguments must conform to the schema exactly.
    "strict": True,
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
@@ -305,6 +321,54 @@ async def test_streaming_types(client: OpenAI, model_name: str):
|
||||
assert len(stack_of_event_types) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_with_streaming_types(client: OpenAI, model_name: str):
    """Verify streamed function-calling events are properly nested.

    Streams a request that offers the `get_weather` tool and checks, via a
    stack, that every "done"/"completed" event closes the matching
    "added"/"delta"/"created" event that opened it, and that nothing is
    left open when the stream ends.
    """
    # this links the "done" type with the "start" type
    # so every "done" type should have a corresponding "start" type
    # and every open block should be closed by the end of the stream
    pairs_of_event_types = {
        "response.completed": "response.created",
        "response.output_item.done": "response.output_item.added",
        "response.output_text.done": "response.output_text.delta",
        "response.reasoning_text.done": "response.reasoning_text.delta",
        "response.reasoning_part.done": "response.reasoning_part.added",
        "response.function_call_arguments.done": "response.function_call_arguments.delta",  # noqa
    }

    tools = [GET_WEATHER_SCHEMA]
    input_list = [
        {
            "role": "user",
            "content": "What's the weather like in Paris today?",
        }
    ]
    stream_response = await client.responses.create(
        model=model_name,
        input=input_list,
        tools=tools,
        stream=True,
    )

    stack_of_event_types = []
    async for event in stream_response:
        # "created"/"completed" don't share the added/delta/done suffixes,
        # so they are handled explicitly; the suffix checks below are
        # no-ops for these two event types.
        if event.type == "response.created":
            stack_of_event_types.append(event.type)
        elif event.type == "response.completed":
            assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
            stack_of_event_types.pop()
        if event.type.endswith("added"):
            stack_of_event_types.append(event.type)
        elif event.type.endswith("delta"):
            # Consecutive deltas of the same type collapse into a single
            # stack entry; only the first one opens the block.
            if stack_of_event_types[-1] == event.type:
                continue
            stack_of_event_types.append(event.type)
        elif event.type.endswith("done"):
            # A "done" must close exactly the block on top of the stack.
            assert stack_of_event_types[-1] == pairs_of_event_types[event.type]
            stack_of_event_types.pop()
    # Every opened block must have been closed by end of stream.
    assert len(stack_of_event_types) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("background", [True, False])
|
||||
@@ -483,23 +547,7 @@ def call_function(name, args):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_function_calling(client: OpenAI, model_name: str):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for provided coordinates in celsius.", # noqa
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {"type": "number"},
|
||||
"longitude": {"type": "number"},
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
}
|
||||
]
|
||||
tools = [GET_WEATHER_SCHEMA]
|
||||
|
||||
response = await client.responses.create(
|
||||
model=model_name,
|
||||
@@ -565,21 +613,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for provided coordinates in celsius.", # noqa
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {"type": "number"},
|
||||
"longitude": {"type": "number"},
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
GET_WEATHER_SCHEMA,
|
||||
]
|
||||
|
||||
response = await client.responses.create(
|
||||
@@ -643,23 +677,7 @@ async def test_function_calling_multi_turn(client: OpenAI, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_function_calling_required(client: OpenAI, model_name: str):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for provided coordinates in celsius.", # noqa
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {"type": "number"},
|
||||
"longitude": {"type": "number"},
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
}
|
||||
]
|
||||
tools = [GET_WEATHER_SCHEMA]
|
||||
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.responses.create(
|
||||
@@ -689,23 +707,7 @@ async def test_system_message_with_tools(client: OpenAI, model_name: str):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_function_calling_full_history(client: OpenAI, model_name: str):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": "get_weather",
|
||||
"description": "Get current temperature for provided coordinates in celsius.", # noqa
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"latitude": {"type": "number"},
|
||||
"longitude": {"type": "number"},
|
||||
},
|
||||
"required": ["latitude", "longitude"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
}
|
||||
]
|
||||
tools = [GET_WEATHER_SCHEMA]
|
||||
|
||||
input_messages = [
|
||||
{"role": "user", "content": "What's the weather like in Paris today?"}
|
||||
@@ -745,6 +747,74 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str):
|
||||
assert response_2.output_text is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_calling_with_stream(client: OpenAI, model_name: str):
    """Exercise a full streamed function-calling round trip.

    First request: stream until the model emits a `get_weather` function
    call, accumulating the argument deltas. The tool is then executed
    locally and its output is fed back in a second streamed request, which
    must produce plain output text and no further function-call events.
    """
    tools = [GET_WEATHER_SCHEMA]
    input_list = [
        {
            "role": "user",
            "content": "What's the weather like in Paris today?",
        }
    ]
    stream_response = await client.responses.create(
        model=model_name,
        input=input_list,
        tools=tools,
        stream=True,
    )
    assert stream_response is not None
    # Streamed tool calls keyed by output index and by function name.
    final_tool_calls = {}
    final_tool_calls_named = {}
    async for event in stream_response:
        if event.type == "response.output_item.added":
            if event.item.type != "function_call":
                continue
            final_tool_calls[event.output_index] = event.item
            final_tool_calls_named[event.item.name] = event.item
        elif event.type == "response.function_call_arguments.delta":
            index = event.output_index
            tool_call = final_tool_calls[index]
            if tool_call:
                # Arguments arrive as incremental JSON text fragments.
                tool_call.arguments += event.delta
                final_tool_calls_named[tool_call.name] = tool_call
        elif event.type == "response.function_call_arguments.done":
            # The terminal event must carry the fully accumulated arguments.
            assert event.arguments == final_tool_calls_named[event.name].arguments

    # Execute the requested tool locally. Initialize to None so that a
    # stream with no get_weather call fails the assertion below with a
    # clear message instead of raising NameError on an unbound local.
    result = None
    tool_call = None
    for candidate in final_tool_calls.values():
        if (
            candidate
            and candidate.type == "function_call"
            and candidate.name == "get_weather"
        ):
            tool_call = candidate
            args = json.loads(candidate.arguments)
            result = call_function(candidate.name, args)
            input_list += [candidate]
            break
    assert result is not None
    # Second turn: feed the tool output back and stream the final answer.
    response = await client.responses.create(
        model=model_name,
        input=input_list
        + [
            {
                "type": "function_call_output",
                "call_id": tool_call.call_id,
                "output": str(result),
            }
        ],
        tools=tools,
        stream=True,
    )
    assert response is not None
    async for event in response:
        # check that no function call events in the stream
        assert event.type != "response.function_call_arguments.delta"
        assert event.type != "response.function_call_arguments.done"
        # check that the response contains output text
        if event.type == "response.completed":
            assert len(event.response.output) > 0
            assert event.response.output_text is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_output_messages_enabled(client: OpenAI, model_name: str, server):
|
||||
|
||||
Reference in New Issue
Block a user