Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -7,8 +7,12 @@ from typing import Optional
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE,
|
||||
SEARCH_TOOL, WEATHER_TOOL)
|
||||
from .utils import (
|
||||
MESSAGES_ASKING_FOR_TOOLS,
|
||||
MESSAGES_WITH_TOOL_RESPONSE,
|
||||
SEARCH_TOOL,
|
||||
WEATHER_TOOL,
|
||||
)
|
||||
|
||||
|
||||
# test: request a chat completion that should return tool calls, so we know they
|
||||
@@ -23,17 +27,18 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
max_completion_tokens=100,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False)
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
tool_calls = chat_completion.choices[0].message.tool_calls
|
||||
|
||||
# make sure a tool call is present
|
||||
assert choice.message.role == 'assistant'
|
||||
assert choice.message.role == "assistant"
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].type == 'function'
|
||||
assert tool_calls[0].type == "function"
|
||||
assert tool_calls[0].function is not None
|
||||
assert isinstance(tool_calls[0].id, str)
|
||||
assert len(tool_calls[0].id) >= 9
|
||||
@@ -54,7 +59,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
assert stop_reason == "tool_calls"
|
||||
|
||||
function_name: Optional[str] = None
|
||||
function_args_str: str = ''
|
||||
function_args_str: str = ""
|
||||
tool_call_id: Optional[str] = None
|
||||
role_name: Optional[str] = None
|
||||
finish_reason_count: int = 0
|
||||
@@ -67,20 +72,21 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
max_completion_tokens=100,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True)
|
||||
stream=True,
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
assert chunk.choices[0].index == 0
|
||||
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == 'tool_calls'
|
||||
assert chunk.choices[0].finish_reason == "tool_calls"
|
||||
|
||||
# if a role is being streamed make sure it wasn't already set to
|
||||
# something else
|
||||
if chunk.choices[0].delta.role:
|
||||
assert not role_name or role_name == 'assistant'
|
||||
role_name = 'assistant'
|
||||
assert not role_name or role_name == "assistant"
|
||||
role_name = "assistant"
|
||||
|
||||
# if a tool call is streamed make sure there's exactly one
|
||||
# (based on the request parameters
|
||||
@@ -108,7 +114,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
function_args_str += tool_call.function.arguments
|
||||
|
||||
assert finish_reason_count == 1
|
||||
assert role_name == 'assistant'
|
||||
assert role_name == "assistant"
|
||||
assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9)
|
||||
|
||||
# validate the name and arguments
|
||||
@@ -148,14 +154,14 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
|
||||
max_completion_tokens=100,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False)
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
|
||||
assert choice.finish_reason != "tool_calls" # "stop" or "length"
|
||||
assert choice.message.role == "assistant"
|
||||
assert choice.message.tool_calls is None \
|
||||
or len(choice.message.tool_calls) == 0
|
||||
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
|
||||
assert choice.message.content is not None
|
||||
assert "98" in choice.message.content # the temperature from the response
|
||||
|
||||
@@ -166,7 +172,8 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True)
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
|
||||
Reference in New Issue
Block a user