Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -7,9 +7,13 @@ from typing import Optional
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL,
|
||||
WEATHER_TOOL, ServerConfig)
|
||||
from .utils import (
|
||||
MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
|
||||
SEARCH_TOOL,
|
||||
WEATHER_TOOL,
|
||||
ServerConfig,
|
||||
)
|
||||
|
||||
|
||||
# test: getting the model to generate parallel tool calls (streaming/not)
|
||||
@@ -17,12 +21,15 @@ from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
# may be added in the future. e.g. llama 3.1 models are not designed to support
|
||||
# parallel tool calls.
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
server_config: ServerConfig):
|
||||
|
||||
async def test_parallel_tool_calls(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
if not server_config.get("supports_parallel", True):
|
||||
pytest.skip("The {} model doesn't support parallel tool calls".format(
|
||||
server_config["model"]))
|
||||
pytest.skip(
|
||||
"The {} model doesn't support parallel tool calls".format(
|
||||
server_config["model"]
|
||||
)
|
||||
)
|
||||
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
@@ -32,7 +39,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False)
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
@@ -69,7 +77,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=200,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True)
|
||||
stream=True,
|
||||
)
|
||||
|
||||
role_name: Optional[str] = None
|
||||
finish_reason_count: int = 0
|
||||
@@ -80,24 +89,22 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
tool_call_id_count: int = 0
|
||||
|
||||
async for chunk in stream:
|
||||
|
||||
# if there's a finish reason make sure it's tools
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == 'tool_calls'
|
||||
assert chunk.choices[0].finish_reason == "tool_calls"
|
||||
|
||||
# if a role is being streamed make sure it wasn't already set to
|
||||
# something else
|
||||
if chunk.choices[0].delta.role:
|
||||
assert not role_name or role_name == 'assistant'
|
||||
role_name = 'assistant'
|
||||
assert not role_name or role_name == "assistant"
|
||||
role_name = "assistant"
|
||||
|
||||
# if a tool call is streamed make sure there's exactly one
|
||||
# (based on the request parameters
|
||||
streamed_tool_calls = chunk.choices[0].delta.tool_calls
|
||||
|
||||
if streamed_tool_calls and len(streamed_tool_calls) > 0:
|
||||
|
||||
# make sure only one diff is present - correct even for parallel
|
||||
assert len(streamed_tool_calls) == 1
|
||||
tool_call = streamed_tool_calls[0]
|
||||
@@ -110,8 +117,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
# if a tool call ID is streamed, make sure one hasn't been already
|
||||
if tool_call.id:
|
||||
tool_call_id_count += 1
|
||||
assert (isinstance(tool_call.id, str)
|
||||
and (len(tool_call.id) >= 9))
|
||||
assert isinstance(tool_call.id, str) and (len(tool_call.id) >= 9)
|
||||
|
||||
# if parts of the function start being streamed
|
||||
if tool_call.function:
|
||||
@@ -125,32 +131,32 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
# make sure they're a string and then add them to the list
|
||||
assert isinstance(tool_call.function.arguments, str)
|
||||
|
||||
tool_call_args[
|
||||
tool_call.index] += tool_call.function.arguments
|
||||
tool_call_args[tool_call.index] += tool_call.function.arguments
|
||||
|
||||
assert finish_reason_count == 1
|
||||
assert role_name == 'assistant'
|
||||
assert role_name == "assistant"
|
||||
|
||||
assert (len(non_streamed_tool_calls) == len(tool_call_names) ==
|
||||
len(tool_call_args))
|
||||
assert len(non_streamed_tool_calls) == len(tool_call_names) == len(tool_call_args)
|
||||
|
||||
for i in range(2):
|
||||
assert non_streamed_tool_calls[i].function.name == tool_call_names[i]
|
||||
streamed_args = json.loads(tool_call_args[i])
|
||||
non_streamed_args = json.loads(
|
||||
non_streamed_tool_calls[i].function.arguments)
|
||||
non_streamed_args = json.loads(non_streamed_tool_calls[i].function.arguments)
|
||||
assert streamed_args == non_streamed_args
|
||||
|
||||
|
||||
# test: providing parallel tool calls back to the model to get a response
|
||||
# (streaming/not)
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
|
||||
server_config: ServerConfig):
|
||||
|
||||
async def test_parallel_tool_calls_with_results(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
if not server_config.get("supports_parallel", True):
|
||||
pytest.skip("The {} model doesn't support parallel tool calls".format(
|
||||
server_config["model"]))
|
||||
pytest.skip(
|
||||
"The {} model doesn't support parallel tool calls".format(
|
||||
server_config["model"]
|
||||
)
|
||||
)
|
||||
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
@@ -160,14 +166,14 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False)
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
|
||||
assert choice.finish_reason != "tool_calls" # "stop" or "length"
|
||||
assert choice.message.role == "assistant"
|
||||
assert choice.message.tool_calls is None \
|
||||
or len(choice.message.tool_calls) == 0
|
||||
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
|
||||
assert choice.message.content is not None
|
||||
assert "98" in choice.message.content # Dallas temp in tool response
|
||||
assert "78" in choice.message.content # Orlando temp in tool response
|
||||
@@ -179,7 +185,8 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True)
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
|
||||
Reference in New Issue
Block a user