Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -13,13 +13,13 @@ from .utils import ARGS, CONFIGS, ServerConfig
|
||||
|
||||
# select models to test based on command line arguments
|
||||
def pytest_addoption(parser):
|
||||
parser.addoption("--models",
|
||||
nargs="+",
|
||||
help="Specify one or more models to test")
|
||||
parser.addoption("--extended",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="invoke extended tests requiring large GPUs")
|
||||
parser.addoption("--models", nargs="+", help="Specify one or more models to test")
|
||||
parser.addoption(
|
||||
"--extended",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="invoke extended tests requiring large GPUs",
|
||||
)
|
||||
|
||||
|
||||
# for each server config, download the model and return the config
|
||||
@@ -29,8 +29,10 @@ def server_config(request):
|
||||
models = request.config.getoption("--models")
|
||||
|
||||
config_keys_to_test = [
|
||||
key for key in CONFIGS if (models is None or key in models) and (
|
||||
extended or not CONFIGS[key].get("extended", False))
|
||||
key
|
||||
for key in CONFIGS
|
||||
if (models is None or key in models)
|
||||
and (extended or not CONFIGS[key].get("extended", False))
|
||||
]
|
||||
|
||||
config_key = request.param
|
||||
@@ -40,8 +42,9 @@ def server_config(request):
|
||||
config = CONFIGS[config_key]
|
||||
|
||||
if current_platform.is_rocm() and not config.get("supports_rocm", True):
|
||||
pytest.skip("The {} model can't be tested on the ROCm platform".format(
|
||||
config["model"]))
|
||||
pytest.skip(
|
||||
"The {} model can't be tested on the ROCm platform".format(config["model"])
|
||||
)
|
||||
|
||||
# download model and tokenizer using transformers
|
||||
snapshot_download(config["model"])
|
||||
@@ -53,8 +56,9 @@ def server_config(request):
|
||||
def server(request, server_config: ServerConfig):
|
||||
model = server_config["model"]
|
||||
args_for_model = server_config["arguments"]
|
||||
with RemoteOpenAIServer(model, ARGS + args_for_model,
|
||||
max_wait_seconds=480) as server:
|
||||
with RemoteOpenAIServer(
|
||||
model, ARGS + args_for_model, max_wait_seconds=480
|
||||
) as server:
|
||||
yield server
|
||||
|
||||
|
||||
|
||||
@@ -17,8 +17,9 @@ def server_config(request):
|
||||
config = CONFIGS[request.param]
|
||||
|
||||
if current_platform.is_rocm() and not config.get("supports_rocm", True):
|
||||
pytest.skip("The {} model can't be tested on the ROCm platform".format(
|
||||
config["model"]))
|
||||
pytest.skip(
|
||||
"The {} model can't be tested on the ROCm platform".format(config["model"])
|
||||
)
|
||||
|
||||
# download model and tokenizer using transformers
|
||||
snapshot_download(config["model"])
|
||||
@@ -30,8 +31,9 @@ def server_config(request):
|
||||
def server(request, server_config: ServerConfig):
|
||||
model = server_config["model"]
|
||||
args_for_model = server_config["arguments"]
|
||||
with RemoteOpenAIServer(model, ARGS + args_for_model,
|
||||
max_wait_seconds=480) as server:
|
||||
with RemoteOpenAIServer(
|
||||
model, ARGS + args_for_model, max_wait_seconds=480
|
||||
) as server:
|
||||
yield server
|
||||
|
||||
|
||||
|
||||
@@ -19,12 +19,12 @@ async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL],
|
||||
tool_choice=WEATHER_TOOL,
|
||||
logprobs=False)
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
|
||||
assert choice.finish_reason != "tool_calls" # "stop" or "length"
|
||||
assert choice.message.role == "assistant"
|
||||
assert choice.message.tool_calls is None \
|
||||
or len(choice.message.tool_calls) == 1
|
||||
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 1
|
||||
assert len(choice.message.tool_calls[0].id) == 9 # length of 9 for mistral
|
||||
|
||||
@@ -18,17 +18,16 @@ ARGS: list[str] = ["--max-model-len", "1024"]
|
||||
|
||||
CONFIGS: dict[str, ServerConfig] = {
|
||||
"mistral": {
|
||||
"model":
|
||||
"mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"model": "mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"arguments": [
|
||||
"--tokenizer-mode", "mistral",
|
||||
"--ignore-patterns=\"consolidated.safetensors\""
|
||||
"--tokenizer-mode",
|
||||
"mistral",
|
||||
'--ignore-patterns="consolidated.safetensors"',
|
||||
],
|
||||
"system_prompt":
|
||||
"You are a helpful assistant with access to tools. If a tool"
|
||||
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
|
||||
" that you have would be helpful to answer a user query, "
|
||||
"call the tool. Otherwise, answer the user's query directly "
|
||||
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
|
||||
"to the user's question - just respond to it normally."
|
||||
"to the user's question - just respond to it normally.",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -8,68 +8,56 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
|
||||
def test_chat_completion_request_with_no_tools():
|
||||
# tools key is not present
|
||||
request = ChatCompletionRequest.model_validate({
|
||||
'messages': [{
|
||||
'role': 'user',
|
||||
'content': 'Hello'
|
||||
}],
|
||||
'model':
|
||||
'facebook/opt-125m',
|
||||
})
|
||||
assert request.tool_choice == 'none'
|
||||
request = ChatCompletionRequest.model_validate(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "facebook/opt-125m",
|
||||
}
|
||||
)
|
||||
assert request.tool_choice == "none"
|
||||
|
||||
# tools key is None
|
||||
request = ChatCompletionRequest.model_validate({
|
||||
'messages': [{
|
||||
'role': 'user',
|
||||
'content': 'Hello'
|
||||
}],
|
||||
'model':
|
||||
'facebook/opt-125m',
|
||||
'tools':
|
||||
None
|
||||
})
|
||||
assert request.tool_choice == 'none'
|
||||
request = ChatCompletionRequest.model_validate(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "facebook/opt-125m",
|
||||
"tools": None,
|
||||
}
|
||||
)
|
||||
assert request.tool_choice == "none"
|
||||
|
||||
# tools key present but empty
|
||||
request = ChatCompletionRequest.model_validate({
|
||||
'messages': [{
|
||||
'role': 'user',
|
||||
'content': 'Hello'
|
||||
}],
|
||||
'model':
|
||||
'facebook/opt-125m',
|
||||
'tools': []
|
||||
})
|
||||
assert request.tool_choice == 'none'
|
||||
request = ChatCompletionRequest.model_validate(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "facebook/opt-125m",
|
||||
"tools": [],
|
||||
}
|
||||
)
|
||||
assert request.tool_choice == "none"
|
||||
|
||||
|
||||
@pytest.mark.parametrize('tool_choice', ['auto', 'required'])
|
||||
@pytest.mark.parametrize("tool_choice", ["auto", "required"])
|
||||
def test_chat_completion_request_with_tool_choice_but_no_tools(tool_choice):
|
||||
with pytest.raises(ValueError,
|
||||
match="When using `tool_choice`, `tools` must be set."):
|
||||
ChatCompletionRequest.model_validate({
|
||||
'messages': [{
|
||||
'role': 'user',
|
||||
'content': 'Hello'
|
||||
}],
|
||||
'model':
|
||||
'facebook/opt-125m',
|
||||
'tool_choice':
|
||||
tool_choice
|
||||
})
|
||||
with pytest.raises(
|
||||
ValueError, match="When using `tool_choice`, `tools` must be set."
|
||||
):
|
||||
ChatCompletionRequest.model_validate(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "facebook/opt-125m",
|
||||
"tool_choice": tool_choice,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match="When using `tool_choice`, `tools` must be set."):
|
||||
ChatCompletionRequest.model_validate({
|
||||
'messages': [{
|
||||
'role': 'user',
|
||||
'content': 'Hello'
|
||||
}],
|
||||
'model':
|
||||
'facebook/opt-125m',
|
||||
'tool_choice':
|
||||
tool_choice,
|
||||
'tools':
|
||||
None
|
||||
})
|
||||
with pytest.raises(
|
||||
ValueError, match="When using `tool_choice`, `tools` must be set."
|
||||
):
|
||||
ChatCompletionRequest.model_validate(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"model": "facebook/opt-125m",
|
||||
"tool_choice": tool_choice,
|
||||
"tools": None,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -4,16 +4,21 @@
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from .utils import (MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL, ServerConfig,
|
||||
ensure_system_prompt)
|
||||
from .utils import (
|
||||
MESSAGES_WITHOUT_TOOLS,
|
||||
WEATHER_TOOL,
|
||||
ServerConfig,
|
||||
ensure_system_prompt,
|
||||
)
|
||||
|
||||
|
||||
# test: make sure chat completions without tools provided work even when tools
|
||||
# are enabled. This makes sure tool call chat templates work, AND that the tool
|
||||
# parser stream processing doesn't change the output of the model.
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
|
||||
server_config: ServerConfig):
|
||||
async def test_chat_completion_without_tools(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@@ -21,7 +26,8 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
|
||||
temperature=0,
|
||||
max_completion_tokens=150,
|
||||
model=model_name,
|
||||
logprobs=False)
|
||||
logprobs=False,
|
||||
)
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
output_text = chat_completion.choices[0].message.content
|
||||
@@ -32,8 +38,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
|
||||
assert stop_reason != "tool_calls"
|
||||
|
||||
# check to make sure no tool calls were returned
|
||||
assert (choice.message.tool_calls is None
|
||||
or len(choice.message.tool_calls) == 0)
|
||||
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
|
||||
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
@@ -55,7 +60,7 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
|
||||
# make sure the role is assistant
|
||||
if delta.role:
|
||||
assert not role_sent
|
||||
assert delta.role == 'assistant'
|
||||
assert delta.role == "assistant"
|
||||
role_sent = True
|
||||
|
||||
if delta.content:
|
||||
@@ -80,8 +85,9 @@ async def test_chat_completion_without_tools(client: openai.AsyncOpenAI,
|
||||
# tools, to make sure we can still get normal chat completion responses
|
||||
# and that they won't be parsed as tools
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
|
||||
server_config: ServerConfig):
|
||||
async def test_chat_completion_with_tools(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
chat_completion = await client.chat.completions.create(
|
||||
@@ -90,19 +96,19 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=150,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL],
|
||||
logprobs=False)
|
||||
logprobs=False,
|
||||
)
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
output_text = chat_completion.choices[0].message.content
|
||||
|
||||
# check to make sure we got text
|
||||
assert output_text is not None
|
||||
assert stop_reason != 'tool_calls'
|
||||
assert stop_reason != "tool_calls"
|
||||
assert len(output_text) > 0
|
||||
|
||||
# check to make sure no tool calls were returned
|
||||
assert (choice.message.tool_calls is None
|
||||
or len(choice.message.tool_calls) == 0)
|
||||
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
|
||||
|
||||
# make the same request, streaming
|
||||
stream = await client.chat.completions.create(
|
||||
@@ -125,7 +131,7 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
|
||||
|
||||
# make sure the role is assistant
|
||||
if delta.role:
|
||||
assert delta.role == 'assistant'
|
||||
assert delta.role == "assistant"
|
||||
role_sent = True
|
||||
|
||||
if delta.content:
|
||||
@@ -142,6 +148,6 @@ async def test_chat_completion_with_tools(client: openai.AsyncOpenAI,
|
||||
assert role_sent
|
||||
assert finish_reason_count == 1
|
||||
assert chunk.choices[0].finish_reason == stop_reason
|
||||
assert chunk.choices[0].finish_reason != 'tool_calls'
|
||||
assert chunk.choices[0].finish_reason != "tool_calls"
|
||||
assert len(chunks)
|
||||
assert "".join(chunks) == output_text
|
||||
|
||||
@@ -21,23 +21,28 @@ def parser(deepseekv31_tokenizer):
|
||||
|
||||
def test_extract_tool_calls_with_tool(parser):
|
||||
model_output = (
|
||||
"normal text" + "<|tool▁calls▁begin|>" +
|
||||
"<|tool▁call▁begin|>foo<|tool▁sep|>{\"x\":1}<|tool▁call▁end|>" +
|
||||
"<|tool▁calls▁end|>")
|
||||
"normal text"
|
||||
+ "<|tool▁calls▁begin|>"
|
||||
+ '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>'
|
||||
+ "<|tool▁calls▁end|>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
assert result.tools_called
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "foo"
|
||||
assert result.tool_calls[0].function.arguments == "{\"x\":1}"
|
||||
assert result.tool_calls[0].function.arguments == '{"x":1}'
|
||||
assert result.content == "normal text"
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_multiple_tools(parser):
|
||||
model_output = (
|
||||
"some prefix text" + "<|tool▁calls▁begin|>" +
|
||||
"<|tool▁call▁begin|>foo<|tool▁sep|>{\"x\":1}<|tool▁call▁end|>" +
|
||||
"<|tool▁call▁begin|>bar<|tool▁sep|>{\"y\":2}<|tool▁call▁end|>" +
|
||||
"<|tool▁calls▁end|>" + " some suffix text")
|
||||
"some prefix text"
|
||||
+ "<|tool▁calls▁begin|>"
|
||||
+ '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>'
|
||||
+ '<|tool▁call▁begin|>bar<|tool▁sep|>{"y":2}<|tool▁call▁end|>'
|
||||
+ "<|tool▁calls▁end|>"
|
||||
+ " some suffix text"
|
||||
)
|
||||
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
@@ -45,10 +50,10 @@ def test_extract_tool_calls_with_multiple_tools(parser):
|
||||
assert len(result.tool_calls) == 2
|
||||
|
||||
assert result.tool_calls[0].function.name == "foo"
|
||||
assert result.tool_calls[0].function.arguments == "{\"x\":1}"
|
||||
assert result.tool_calls[0].function.arguments == '{"x":1}'
|
||||
|
||||
assert result.tool_calls[1].function.name == "bar"
|
||||
assert result.tool_calls[1].function.arguments == "{\"y\":2}"
|
||||
assert result.tool_calls[1].function.arguments == '{"y":2}'
|
||||
|
||||
# prefix is content
|
||||
assert result.content == "some prefix text"
|
||||
|
||||
@@ -27,12 +27,14 @@ def glm4_moe_tool_parser(glm4_moe_tokenizer):
|
||||
return Glm4MoeModelToolParser(glm4_moe_tokenizer)
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall]):
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
assert isinstance(actual_tool_call.id, str)
|
||||
assert len(actual_tool_call.id) > 0
|
||||
|
||||
@@ -47,7 +49,8 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
@@ -73,14 +76,18 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
|
||||
<arg_value>fahrenheit</arg_value>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
@@ -102,22 +109,30 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
|
||||
<arg_value>fahrenheit</arg_value>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
@@ -131,14 +146,18 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
|
||||
<arg_value>celsius</arg_value>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Seattle",
|
||||
"state": "WA",
|
||||
"unit": "celsius",
|
||||
}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Seattle",
|
||||
"state": "WA",
|
||||
"unit": "celsius",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"I'll help you check the weather.",
|
||||
),
|
||||
@@ -152,37 +171,51 @@ def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
|
||||
<arg_value>celsius</arg_value>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "New York",
|
||||
"state": "NY",
|
||||
"unit": "celsius",
|
||||
}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "New York",
|
||||
"state": "NY",
|
||||
"unit": "celsius",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
("""I will help you get the weather.<tool_call>get_weather
|
||||
(
|
||||
"""I will help you get the weather.<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2025-08-01</arg_value>
|
||||
</tool_call>""", [
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Beijing",
|
||||
"date": "2025-08-01",
|
||||
}),
|
||||
))
|
||||
], "I will help you get the weather."),
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Beijing",
|
||||
"date": "2025-08-01",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"I will help you get the weather.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(glm4_moe_tool_parser, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
def test_extract_tool_calls(
|
||||
glm4_moe_tool_parser, model_output, expected_tool_calls, expected_content
|
||||
):
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
@@ -202,7 +235,8 @@ I will help you get the weather.
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
@@ -224,7 +258,8 @@ def test_extract_tool_calls_malformed_xml(glm4_moe_tool_parser):
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
# Should handle malformed XML gracefully
|
||||
# The parser should either extract what it can or return no tool calls
|
||||
@@ -239,12 +274,12 @@ def test_extract_tool_calls_empty_arguments(glm4_moe_tool_parser):
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
assert extracted_tool_calls.tool_calls[
|
||||
0].function.name == "get_current_time"
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_time"
|
||||
# Empty arguments should result in empty JSON object
|
||||
assert extracted_tool_calls.tool_calls[0].function.arguments == "{}"
|
||||
|
||||
@@ -270,7 +305,8 @@ meaningwhile, I will also check the weather in Shanghai.
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 2
|
||||
@@ -321,8 +357,7 @@ def test_streaming_basic_functionality(glm4_moe_tool_parser):
|
||||
|
||||
# The result behavior depends on the streaming state
|
||||
# This test mainly ensures no exceptions are thrown
|
||||
assert result is None or hasattr(result, 'tool_calls') or hasattr(
|
||||
result, 'content')
|
||||
assert result is None or hasattr(result, "tool_calls") or hasattr(result, "content")
|
||||
|
||||
|
||||
def test_streaming_no_tool_calls(glm4_moe_tool_parser):
|
||||
@@ -341,7 +376,7 @@ def test_streaming_no_tool_calls(glm4_moe_tool_parser):
|
||||
|
||||
# Should return the delta text as content
|
||||
assert result is not None
|
||||
assert hasattr(result, 'content')
|
||||
assert hasattr(result, "content")
|
||||
assert result.content == " without any tool calls."
|
||||
|
||||
|
||||
@@ -367,7 +402,7 @@ def test_streaming_with_content_before_tool_calls(glm4_moe_tool_parser):
|
||||
|
||||
# Should return content when no tool call tokens are detected
|
||||
assert result is not None
|
||||
assert hasattr(result, 'content')
|
||||
assert hasattr(result, "content")
|
||||
assert result.content == "get the weather.<tool_call>"
|
||||
|
||||
|
||||
@@ -383,7 +418,8 @@ def test_extract_tool_calls_special_characters(glm4_moe_tool_parser):
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
@@ -404,7 +440,8 @@ def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser):
|
||||
<arg_value>2025-08-01</arg_value>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
# Incomplete tool calls should not be extracted
|
||||
assert not extracted_tool_calls.tools_called
|
||||
|
||||
@@ -9,8 +9,7 @@ import partial_json_parser
|
||||
import pytest
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall,
|
||||
ToolCall)
|
||||
from vllm.entrypoints.openai.protocol import DeltaMessage, FunctionCall, ToolCall
|
||||
from vllm.entrypoints.openai.tool_parsers import JambaToolParser
|
||||
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
@@ -30,12 +29,14 @@ def jamba_tool_parser(jamba_tokenizer):
|
||||
return JambaToolParser(jamba_tokenizer)
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall]):
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
assert isinstance(actual_tool_call.id, str)
|
||||
assert len(actual_tool_call.id) > 16
|
||||
|
||||
@@ -44,10 +45,9 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer,
|
||||
model_output: str) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = jamba_tokenizer.encode(model_output,
|
||||
add_special_tokens=False)
|
||||
jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer, model_output: str
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = jamba_tokenizer.encode(model_output, add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
@@ -56,18 +56,19 @@ def stream_delta_message_generator(
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[:i + 1]
|
||||
current_token_ids = all_token_ids[: i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset,
|
||||
new_read_offset) = detokenize_incrementally(
|
||||
tokenizer=jamba_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
|
||||
detokenize_incrementally(
|
||||
tokenizer=jamba_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
)
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
@@ -84,8 +85,9 @@ def stream_delta_message_generator(
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = previous_tokens + new_tokens if previous_tokens\
|
||||
else new_tokens
|
||||
previous_tokens = (
|
||||
previous_tokens + new_tokens if previous_tokens else new_tokens
|
||||
)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
@@ -93,7 +95,8 @@ def stream_delta_message_generator(
|
||||
def test_extract_tool_calls_no_tools(jamba_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
@@ -108,54 +111,63 @@ def test_extract_tool_calls_no_tools(jamba_tool_parser):
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
|
||||
""" <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None),
|
||||
None,
|
||||
),
|
||||
(
|
||||
''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
|
||||
""" Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
" Sure! let me call the tool for you."),
|
||||
" Sure! let me call the tool for you.",
|
||||
),
|
||||
(
|
||||
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
|
||||
""" <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
}))),
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Orlando", "state": "FL", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
None)
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(jamba_tool_parser, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
def test_extract_tool_calls(
|
||||
jamba_tool_parser, model_output, expected_tool_calls, expected_content
|
||||
):
|
||||
extracted_tool_calls = jamba_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
@@ -172,63 +184,75 @@ def test_extract_tool_calls(jamba_tool_parser, model_output,
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
('''This is a test''', [], '''This is a test'''),
|
||||
("""This is a test""", [], """This is a test"""),
|
||||
(
|
||||
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
|
||||
""" <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
" "),
|
||||
" ",
|
||||
),
|
||||
(
|
||||
''' Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
|
||||
""" Sure! let me call the tool for you.<tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
" Sure! let me call the tool for you."),
|
||||
" Sure! let me call the tool for you.",
|
||||
),
|
||||
(
|
||||
''' <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>''', # noqa: E501
|
||||
""" <tool_calls>[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]</tool_calls>""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
}))),
|
||||
ToolCall(function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Orlando", "state": "FL", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
" ")
|
||||
" ",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer,
|
||||
model_output, expected_tool_calls,
|
||||
expected_content):
|
||||
other_content: str = ''
|
||||
def test_extract_tool_calls_streaming(
|
||||
jamba_tool_parser,
|
||||
jamba_tokenizer,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
other_content: str = ""
|
||||
function_names: list[str] = []
|
||||
function_args_strs: list[str] = []
|
||||
tool_call_idx: int = -1
|
||||
tool_call_ids: list[Optional[str]] = []
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
jamba_tool_parser, jamba_tokenizer, model_output):
|
||||
jamba_tool_parser, jamba_tokenizer, model_output
|
||||
):
|
||||
# role should never be streamed from tool parser
|
||||
assert not delta_message.role
|
||||
|
||||
@@ -264,18 +288,22 @@ def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer,
|
||||
# make sure they're a string and then add them to the list
|
||||
assert isinstance(tool_call.function.arguments, str)
|
||||
|
||||
function_args_strs[
|
||||
tool_call.index] += tool_call.function.arguments
|
||||
function_args_strs[tool_call.index] += tool_call.function.arguments
|
||||
|
||||
assert other_content == expected_content
|
||||
|
||||
actual_tool_calls = [
|
||||
ToolCall(id=tool_call_id,
|
||||
function=FunctionCall(
|
||||
name=function_name,
|
||||
arguments=partial_json_parser.ensure_json(
|
||||
function_args_str, Allow.OBJ | Allow.STR)))
|
||||
ToolCall(
|
||||
id=tool_call_id,
|
||||
function=FunctionCall(
|
||||
name=function_name,
|
||||
arguments=partial_json_parser.ensure_json(
|
||||
function_args_str, Allow.OBJ | Allow.STR
|
||||
),
|
||||
),
|
||||
)
|
||||
for tool_call_id, function_name, function_args_str in zip(
|
||||
tool_call_ids, function_names, function_args_strs)
|
||||
tool_call_ids, function_names, function_args_strs
|
||||
)
|
||||
]
|
||||
assert_tool_calls(actual_tool_calls, expected_tool_calls)
|
||||
|
||||
@@ -26,27 +26,31 @@ def kimi_k2_tool_parser(kimi_k2_tokenizer):
|
||||
return KimiK2ToolParser(kimi_k2_tokenizer)
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall]):
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function == expected_tool_call.function
|
||||
|
||||
# assert tool call id format
|
||||
assert actual_tool_call.id.startswith("functions.")
|
||||
assert actual_tool_call.id.split(':')[-1].isdigit()
|
||||
assert actual_tool_call.id.split('.')[1].split(
|
||||
':')[0] == expected_tool_call.function.name
|
||||
assert actual_tool_call.id.split(":")[-1].isdigit()
|
||||
assert (
|
||||
actual_tool_call.id.split(".")[1].split(":")[0]
|
||||
== expected_tool_call.function.name
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
@@ -63,14 +67,18 @@ def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
|
||||
"""I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
|
||||
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""",
|
||||
[
|
||||
ToolCall(id='functions.get_weather:0',
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Beijing",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
ToolCall(
|
||||
id="functions.get_weather:0",
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Beijing",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
"I'll help you check the weather. ",
|
||||
),
|
||||
@@ -79,31 +87,41 @@ functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_
|
||||
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
|
||||
functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""",
|
||||
[
|
||||
ToolCall(id='functions.get_weather:0',
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Beijing",
|
||||
}, ),
|
||||
),
|
||||
type='function'),
|
||||
ToolCall(id='functions.get_weather:1',
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Shanghai",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
ToolCall(
|
||||
id="functions.get_weather:0",
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Beijing",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
),
|
||||
ToolCall(
|
||||
id="functions.get_weather:1",
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Shanghai",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
),
|
||||
],
|
||||
"I'll help you check the weather. ",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(kimi_k2_tool_parser, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
def test_extract_tool_calls(
|
||||
kimi_k2_tool_parser, model_output, expected_tool_calls, expected_content
|
||||
):
|
||||
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
@@ -118,15 +136,14 @@ functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"
|
||||
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""
|
||||
|
||||
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
# Should extract only the valid JSON tool calls
|
||||
assert len(extracted_tool_calls.tool_calls) == 2
|
||||
assert extracted_tool_calls.tool_calls[
|
||||
0].function.name == "invalid_get_weather"
|
||||
assert extracted_tool_calls.tool_calls[
|
||||
1].function.name == "valid_get_weather"
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "invalid_get_weather"
|
||||
assert extracted_tool_calls.tool_calls[1].function.name == "valid_get_weather"
|
||||
|
||||
|
||||
def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser):
|
||||
@@ -136,13 +153,13 @@ functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"}
|
||||
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""
|
||||
|
||||
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
# Should extract only the valid JSON tool calls
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
assert extracted_tool_calls.tool_calls[
|
||||
0].function.name == "valid_get_weather"
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "valid_get_weather"
|
||||
|
||||
|
||||
def test_streaming_basic_functionality(kimi_k2_tool_parser):
|
||||
@@ -170,8 +187,7 @@ functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_
|
||||
|
||||
# The result might be None or contain tool call information
|
||||
# This depends on the internal state management
|
||||
if result is not None and hasattr(result,
|
||||
'tool_calls') and result.tool_calls:
|
||||
if result is not None and hasattr(result, "tool_calls") and result.tool_calls:
|
||||
assert len(result.tool_calls) >= 0
|
||||
|
||||
|
||||
@@ -191,5 +207,5 @@ def test_streaming_no_tool_calls(kimi_k2_tool_parser):
|
||||
|
||||
# Should return the delta text as content
|
||||
assert result is not None
|
||||
assert hasattr(result, 'content')
|
||||
assert hasattr(result, "content")
|
||||
assert result.content == " without any tool calls."
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,9 +4,15 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from openai_harmony import (Conversation, DeveloperContent,
|
||||
HarmonyEncodingName, Message, Role, SystemContent,
|
||||
load_harmony_encoding)
|
||||
from openai_harmony import (
|
||||
Conversation,
|
||||
DeveloperContent,
|
||||
HarmonyEncodingName,
|
||||
Message,
|
||||
Role,
|
||||
SystemContent,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||
from vllm.entrypoints.openai.tool_parsers import OpenAIToolParser
|
||||
@@ -37,8 +43,9 @@ def assert_tool_calls(
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
assert isinstance(actual_tool_call.id, str)
|
||||
assert len(actual_tool_call.id) > 16 # Default from protocol.py
|
||||
assert actual_tool_call.type == "function"
|
||||
@@ -46,20 +53,25 @@ def assert_tool_calls(
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
|
||||
convo = Conversation.from_messages([
|
||||
Message.from_role_and_content(
|
||||
Role.SYSTEM,
|
||||
SystemContent.new(),
|
||||
),
|
||||
Message.from_role_and_content(
|
||||
Role.DEVELOPER,
|
||||
DeveloperContent.new().with_instructions("Talk like a pirate!")),
|
||||
Message.from_role_and_content(Role.USER, "Arrr, how be you?"),
|
||||
Message.from_role_and_content(Role.ASSISTANT,
|
||||
"This is a test").with_channel("final")
|
||||
])
|
||||
convo = Conversation.from_messages(
|
||||
[
|
||||
Message.from_role_and_content(
|
||||
Role.SYSTEM,
|
||||
SystemContent.new(),
|
||||
),
|
||||
Message.from_role_and_content(
|
||||
Role.DEVELOPER,
|
||||
DeveloperContent.new().with_instructions("Talk like a pirate!"),
|
||||
),
|
||||
Message.from_role_and_content(Role.USER, "Arrr, how be you?"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, "This is a test"
|
||||
).with_channel("final"),
|
||||
]
|
||||
)
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
convo, Role.ASSISTANT)
|
||||
convo, Role.ASSISTANT
|
||||
)
|
||||
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||
"",
|
||||
request=None,
|
||||
@@ -70,26 +82,32 @@ def test_extract_tool_calls_no_tools(openai_tool_parser, harmony_encoding):
|
||||
assert extracted_info.content == "This is a test"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tool_args", [
|
||||
'{"location": "Tokyo"}',
|
||||
'{\n"location": "Tokyo"\n}',
|
||||
])
|
||||
def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding,
|
||||
tool_args):
|
||||
convo = Conversation.from_messages([
|
||||
Message.from_role_and_content(Role.USER,
|
||||
"What is the weather in Tokyo?"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.', # noqa: E501
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
tool_args).with_channel("commentary").with_recipient(
|
||||
"functions.get_current_weather").with_content_type("json"),
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"tool_args",
|
||||
[
|
||||
'{"location": "Tokyo"}',
|
||||
'{\n"location": "Tokyo"\n}',
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_single_tool(
|
||||
openai_tool_parser, harmony_encoding, tool_args
|
||||
):
|
||||
convo = Conversation.from_messages(
|
||||
[
|
||||
Message.from_role_and_content(Role.USER, "What is the weather in Tokyo?"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'User asks: "What is the weather in Tokyo?" We need to use get_current_weather tool.', # noqa: E501
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, tool_args)
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.get_current_weather")
|
||||
.with_content_type("json"),
|
||||
]
|
||||
)
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
convo, Role.ASSISTANT)
|
||||
convo, Role.ASSISTANT
|
||||
)
|
||||
|
||||
extracted_info = openai_tool_parser.extract_tool_calls(
|
||||
"",
|
||||
@@ -98,10 +116,12 @@ def test_extract_tool_calls_single_tool(openai_tool_parser, harmony_encoding,
|
||||
)
|
||||
assert extracted_info.tools_called
|
||||
expected_tool_calls = [
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)
|
||||
)
|
||||
]
|
||||
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||
assert extracted_info.content is None
|
||||
@@ -111,33 +131,39 @@ def test_extract_tool_calls_multiple_tools(
|
||||
openai_tool_parser,
|
||||
harmony_encoding,
|
||||
):
|
||||
convo = Conversation.from_messages([
|
||||
Message.from_role_and_content(
|
||||
Role.USER, "What is the weather in Tokyo based on where I'm at?"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||
"functions.get_current_weather").with_content_type("json"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||
"functions.get_user_location").with_content_type("json"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, '{"location": "Tokyo"}').with_channel(
|
||||
"commentary").with_recipient("functions.no_content_type"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, "foo").with_channel(
|
||||
"commentary").with_recipient("functions.not_json_no_content_type"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, '{}').with_channel("commentary").with_recipient(
|
||||
"functions.empty_args").with_content_type("json"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT, '').with_channel("commentary").with_recipient(
|
||||
"functions.no_args").with_content_type("json"),
|
||||
])
|
||||
convo = Conversation.from_messages(
|
||||
[
|
||||
Message.from_role_and_content(
|
||||
Role.USER, "What is the weather in Tokyo based on where I'm at?"
|
||||
),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.get_current_weather")
|
||||
.with_content_type("json"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.get_user_location")
|
||||
.with_content_type("json"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.no_content_type"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, "foo")
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.not_json_no_content_type"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, "{}")
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.empty_args")
|
||||
.with_content_type("json"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, "")
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.no_args")
|
||||
.with_content_type("json"),
|
||||
]
|
||||
)
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
convo,
|
||||
Role.ASSISTANT,
|
||||
@@ -150,30 +176,42 @@ def test_extract_tool_calls_multiple_tools(
|
||||
)
|
||||
assert extracted_info.tools_called
|
||||
expected_tool_calls = [
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_user_location",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="no_content_type",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="not_json_no_content_type",
|
||||
arguments="foo",
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="empty_args",
|
||||
arguments=json.dumps({}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="no_args",
|
||||
arguments="",
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_user_location",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="no_content_type",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="not_json_no_content_type",
|
||||
arguments="foo",
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="empty_args",
|
||||
arguments=json.dumps({}),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="no_args",
|
||||
arguments="",
|
||||
)
|
||||
),
|
||||
]
|
||||
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||
assert extracted_info.content is None
|
||||
@@ -184,20 +222,24 @@ def test_extract_tool_calls_with_content(
|
||||
harmony_encoding,
|
||||
):
|
||||
final_content = "This tool call will get the weather."
|
||||
convo = Conversation.from_messages([
|
||||
Message.from_role_and_content(
|
||||
Role.USER, "What is the weather in Tokyo based on where I'm at?"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'{"location": "Tokyo"}').with_channel("commentary").with_recipient(
|
||||
"functions.get_current_weather").with_content_type("json"),
|
||||
Message.from_role_and_content(Role.ASSISTANT,
|
||||
final_content).with_channel("final"),
|
||||
])
|
||||
convo = Conversation.from_messages(
|
||||
[
|
||||
Message.from_role_and_content(
|
||||
Role.USER, "What is the weather in Tokyo based on where I'm at?"
|
||||
),
|
||||
Message.from_role_and_content(
|
||||
Role.ASSISTANT,
|
||||
'User asks: "What is the weather in Tokyo?" based on their location. We need to use get_current_weather tool and get_user_location tool.', # noqa: E501
|
||||
).with_channel("analysis"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
|
||||
.with_channel("commentary")
|
||||
.with_recipient("functions.get_current_weather")
|
||||
.with_content_type("json"),
|
||||
Message.from_role_and_content(Role.ASSISTANT, final_content).with_channel(
|
||||
"final"
|
||||
),
|
||||
]
|
||||
)
|
||||
token_ids = harmony_encoding.render_conversation_for_completion(
|
||||
convo,
|
||||
Role.ASSISTANT,
|
||||
@@ -210,10 +252,12 @@ def test_extract_tool_calls_with_content(
|
||||
)
|
||||
assert extracted_info.tools_called
|
||||
expected_tool_calls = [
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({"location": "Tokyo"}),
|
||||
)
|
||||
),
|
||||
]
|
||||
assert_tool_calls(extracted_info.tool_calls, expected_tool_calls)
|
||||
assert extracted_info.content == final_content
|
||||
|
||||
@@ -7,9 +7,13 @@ from typing import Optional
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL,
|
||||
WEATHER_TOOL, ServerConfig)
|
||||
from .utils import (
|
||||
MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
|
||||
SEARCH_TOOL,
|
||||
WEATHER_TOOL,
|
||||
ServerConfig,
|
||||
)
|
||||
|
||||
|
||||
# test: getting the model to generate parallel tool calls (streaming/not)
|
||||
@@ -17,12 +21,15 @@ from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
|
||||
# may be added in the future. e.g. llama 3.1 models are not designed to support
|
||||
# parallel tool calls.
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
server_config: ServerConfig):
|
||||
|
||||
async def test_parallel_tool_calls(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
if not server_config.get("supports_parallel", True):
|
||||
pytest.skip("The {} model doesn't support parallel tool calls".format(
|
||||
server_config["model"]))
|
||||
pytest.skip(
|
||||
"The {} model doesn't support parallel tool calls".format(
|
||||
server_config["model"]
|
||||
)
|
||||
)
|
||||
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
@@ -32,7 +39,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False)
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
@@ -69,7 +77,8 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=200,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True)
|
||||
stream=True,
|
||||
)
|
||||
|
||||
role_name: Optional[str] = None
|
||||
finish_reason_count: int = 0
|
||||
@@ -80,24 +89,22 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
tool_call_id_count: int = 0
|
||||
|
||||
async for chunk in stream:
|
||||
|
||||
# if there's a finish reason make sure it's tools
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == 'tool_calls'
|
||||
assert chunk.choices[0].finish_reason == "tool_calls"
|
||||
|
||||
# if a role is being streamed make sure it wasn't already set to
|
||||
# something else
|
||||
if chunk.choices[0].delta.role:
|
||||
assert not role_name or role_name == 'assistant'
|
||||
role_name = 'assistant'
|
||||
assert not role_name or role_name == "assistant"
|
||||
role_name = "assistant"
|
||||
|
||||
# if a tool call is streamed make sure there's exactly one
|
||||
# (based on the request parameters
|
||||
streamed_tool_calls = chunk.choices[0].delta.tool_calls
|
||||
|
||||
if streamed_tool_calls and len(streamed_tool_calls) > 0:
|
||||
|
||||
# make sure only one diff is present - correct even for parallel
|
||||
assert len(streamed_tool_calls) == 1
|
||||
tool_call = streamed_tool_calls[0]
|
||||
@@ -110,8 +117,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
# if a tool call ID is streamed, make sure one hasn't been already
|
||||
if tool_call.id:
|
||||
tool_call_id_count += 1
|
||||
assert (isinstance(tool_call.id, str)
|
||||
and (len(tool_call.id) >= 9))
|
||||
assert isinstance(tool_call.id, str) and (len(tool_call.id) >= 9)
|
||||
|
||||
# if parts of the function start being streamed
|
||||
if tool_call.function:
|
||||
@@ -125,32 +131,32 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
# make sure they're a string and then add them to the list
|
||||
assert isinstance(tool_call.function.arguments, str)
|
||||
|
||||
tool_call_args[
|
||||
tool_call.index] += tool_call.function.arguments
|
||||
tool_call_args[tool_call.index] += tool_call.function.arguments
|
||||
|
||||
assert finish_reason_count == 1
|
||||
assert role_name == 'assistant'
|
||||
assert role_name == "assistant"
|
||||
|
||||
assert (len(non_streamed_tool_calls) == len(tool_call_names) ==
|
||||
len(tool_call_args))
|
||||
assert len(non_streamed_tool_calls) == len(tool_call_names) == len(tool_call_args)
|
||||
|
||||
for i in range(2):
|
||||
assert non_streamed_tool_calls[i].function.name == tool_call_names[i]
|
||||
streamed_args = json.loads(tool_call_args[i])
|
||||
non_streamed_args = json.loads(
|
||||
non_streamed_tool_calls[i].function.arguments)
|
||||
non_streamed_args = json.loads(non_streamed_tool_calls[i].function.arguments)
|
||||
assert streamed_args == non_streamed_args
|
||||
|
||||
|
||||
# test: providing parallel tool calls back to the model to get a response
|
||||
# (streaming/not)
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
|
||||
server_config: ServerConfig):
|
||||
|
||||
async def test_parallel_tool_calls_with_results(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
if not server_config.get("supports_parallel", True):
|
||||
pytest.skip("The {} model doesn't support parallel tool calls".format(
|
||||
server_config["model"]))
|
||||
pytest.skip(
|
||||
"The {} model doesn't support parallel tool calls".format(
|
||||
server_config["model"]
|
||||
)
|
||||
)
|
||||
|
||||
models = await client.models.list()
|
||||
model_name: str = models.data[0].id
|
||||
@@ -160,14 +166,14 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
|
||||
max_completion_tokens=200,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False)
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
|
||||
assert choice.finish_reason != "tool_calls" # "stop" or "length"
|
||||
assert choice.message.role == "assistant"
|
||||
assert choice.message.tool_calls is None \
|
||||
or len(choice.message.tool_calls) == 0
|
||||
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
|
||||
assert choice.message.content is not None
|
||||
assert "98" in choice.message.content # Dallas temp in tool response
|
||||
assert "78" in choice.message.content # Orlando temp in tool response
|
||||
@@ -179,7 +185,8 @@ async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True)
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
|
||||
@@ -7,14 +7,17 @@ from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaMessage, FunctionCall,
|
||||
ToolCall)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaMessage,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import (
|
||||
Qwen3CoderToolParser)
|
||||
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import (
|
||||
Qwen3XMLToolParser)
|
||||
Qwen3CoderToolParser,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.qwen3xml_tool_parser import Qwen3XMLToolParser
|
||||
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
|
||||
@@ -39,8 +42,7 @@ def qwen3_xml_tool_parser(qwen3_tokenizer):
|
||||
|
||||
|
||||
@pytest.fixture(params=["original", "xml"])
|
||||
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser,
|
||||
request):
|
||||
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request):
|
||||
"""Parameterized fixture that provides both parser types for testing"""
|
||||
if request.param == "original":
|
||||
return qwen3_tool_parser
|
||||
@@ -51,76 +53,63 @@ def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser,
|
||||
@pytest.fixture
|
||||
def sample_tools():
|
||||
return [
|
||||
ChatCompletionToolsParam(type="function",
|
||||
function={
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city name"
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"The state code"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum":
|
||||
["fahrenheit", "celsius"]
|
||||
}
|
||||
},
|
||||
"required": ["city", "state"]
|
||||
}
|
||||
}),
|
||||
ChatCompletionToolsParam(type="function",
|
||||
function={
|
||||
"name": "calculate_area",
|
||||
"description":
|
||||
"Calculate area of a shape",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"shape": {
|
||||
"type": "string"
|
||||
},
|
||||
"dimensions": {
|
||||
"type": "object"
|
||||
},
|
||||
"precision": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string", "description": "The city name"},
|
||||
"state": {"type": "string", "description": "The state code"},
|
||||
"unit": {"type": "string", "enum": ["fahrenheit", "celsius"]},
|
||||
},
|
||||
"required": ["city", "state"],
|
||||
},
|
||||
},
|
||||
),
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "calculate_area",
|
||||
"description": "Calculate area of a shape",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"shape": {"type": "string"},
|
||||
"dimensions": {"type": "object"},
|
||||
"precision": {"type": "integer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall]):
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
# Qwen3 parser doesn't generate IDs during extraction
|
||||
assert actual_tool_call.type == "function"
|
||||
assert (
|
||||
actual_tool_call.function.name == expected_tool_call.function.name)
|
||||
assert (json.loads(actual_tool_call.function.arguments) == json.loads(
|
||||
expected_tool_call.function.arguments))
|
||||
assert actual_tool_call.function.name == expected_tool_call.function.name
|
||||
assert json.loads(actual_tool_call.function.arguments) == json.loads(
|
||||
expected_tool_call.function.arguments
|
||||
)
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
qwen3_tool_parser,
|
||||
qwen3_tokenizer: AnyTokenizer,
|
||||
model_output: str,
|
||||
request: Optional[ChatCompletionRequest] = None
|
||||
request: Optional[ChatCompletionRequest] = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = qwen3_tokenizer.encode(model_output,
|
||||
add_special_tokens=False)
|
||||
all_token_ids = qwen3_tokenizer.encode(model_output, add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
@@ -129,18 +118,19 @@ def stream_delta_message_generator(
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[:i + 1]
|
||||
current_token_ids = all_token_ids[: i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset,
|
||||
new_read_offset) = detokenize_incrementally(
|
||||
tokenizer=qwen3_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
|
||||
detokenize_incrementally(
|
||||
tokenizer=qwen3_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
)
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
@@ -157,8 +147,9 @@ def stream_delta_message_generator(
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = (previous_tokens +
|
||||
new_tokens if previous_tokens else new_tokens)
|
||||
previous_tokens = (
|
||||
previous_tokens + new_tokens if previous_tokens else new_tokens
|
||||
)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
@@ -166,7 +157,8 @@ def stream_delta_message_generator(
|
||||
def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized):
|
||||
model_output = "This is a test response without any tool calls"
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
@@ -182,7 +174,8 @@ def test_extract_tool_calls_no_tools(qwen3_tool_parser_parametrized):
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
('''<tool_call>
|
||||
(
|
||||
"""<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
@@ -194,16 +187,21 @@ TX
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>''', [
|
||||
ToolCall(
|
||||
function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
], None),
|
||||
('''Sure! Let me check the weather for you.<tool_call>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""Sure! Let me check the weather for you.<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
@@ -215,16 +213,21 @@ TX
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>''', [
|
||||
ToolCall(
|
||||
function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
], "Sure! Let me check the weather for you."),
|
||||
('''<tool_call>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"Sure! Let me check the weather for you.",
|
||||
),
|
||||
(
|
||||
"""<tool_call>
|
||||
<function=calculate_area>
|
||||
<parameter=shape>
|
||||
rectangle
|
||||
@@ -237,18 +240,25 @@ rectangle
|
||||
2
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>''', [
|
||||
ToolCall(function=FunctionCall(name="calculate_area",
|
||||
arguments=json.dumps({
|
||||
"shape": "rectangle",
|
||||
"dimensions": {
|
||||
"width": 10,
|
||||
"height": 20
|
||||
},
|
||||
"precision": 2
|
||||
})))
|
||||
], None),
|
||||
('''<tool_call>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="calculate_area",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"shape": "rectangle",
|
||||
"dimensions": {"width": 10, "height": 20},
|
||||
"precision": 2,
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
@@ -273,23 +283,29 @@ FL
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>''', [
|
||||
ToolCall(
|
||||
function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
}))),
|
||||
ToolCall(
|
||||
function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
], None),
|
||||
('''Let me calculate that area for you.<tool_call>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Orlando", "state": "FL", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""Let me calculate that area for you.<tool_call>
|
||||
<function=calculate_area>
|
||||
<parameter=shape>
|
||||
circle
|
||||
@@ -301,26 +317,36 @@ circle
|
||||
3
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>''', [
|
||||
ToolCall(function=FunctionCall(name="calculate_area",
|
||||
arguments=json.dumps({
|
||||
"shape": "circle",
|
||||
"dimensions": {
|
||||
"radius": 15.5
|
||||
},
|
||||
"precision": 3
|
||||
})))
|
||||
], "Let me calculate that area for you."),
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="calculate_area",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"shape": "circle",
|
||||
"dimensions": {"radius": 15.5},
|
||||
"precision": 3,
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"Let me calculate that area for you.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(qwen3_tool_parser_parametrized, sample_tools,
|
||||
model_output, expected_tool_calls,
|
||||
expected_content):
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
def test_extract_tool_calls(
|
||||
qwen3_tool_parser_parametrized,
|
||||
sample_tools,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request)
|
||||
model_output, request=request
|
||||
)
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
@@ -328,60 +354,51 @@ def test_extract_tool_calls(qwen3_tool_parser_parametrized, sample_tools,
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser_parametrized,
|
||||
sample_tools):
|
||||
def test_extract_tool_calls_fallback_no_tags(
|
||||
qwen3_tool_parser_parametrized, sample_tools
|
||||
):
|
||||
"""Test fallback parsing when XML tags are missing"""
|
||||
model_output = '''<function=get_current_weather>
|
||||
model_output = """<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
</function>'''
|
||||
</function>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request)
|
||||
model_output, request=request
|
||||
)
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
assert (extracted_tool_calls.tool_calls[0].function.name ==
|
||||
"get_current_weather")
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather"
|
||||
|
||||
|
||||
def test_extract_tool_calls_type_conversion(qwen3_tool_parser_parametrized):
|
||||
"""Test parameter type conversion based on tool schema"""
|
||||
tools = [
|
||||
ChatCompletionToolsParam(type="function",
|
||||
function={
|
||||
"name": "test_types",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"int_param": {
|
||||
"type": "integer"
|
||||
},
|
||||
"float_param": {
|
||||
"type": "float"
|
||||
},
|
||||
"bool_param": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"str_param": {
|
||||
"type": "string"
|
||||
},
|
||||
"obj_param": {
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "test_types",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"int_param": {"type": "integer"},
|
||||
"float_param": {"type": "float"},
|
||||
"bool_param": {"type": "boolean"},
|
||||
"str_param": {"type": "string"},
|
||||
"obj_param": {"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
model_output = '''<tool_call>
|
||||
model_output = """<tool_call>
|
||||
<function=test_types>
|
||||
<parameter=int_param>
|
||||
42
|
||||
@@ -399,11 +416,12 @@ hello world
|
||||
{"key": "value"}
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>'''
|
||||
</tool_call>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request)
|
||||
model_output, request=request
|
||||
)
|
||||
|
||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
assert args["int_param"] == 42
|
||||
@@ -425,7 +443,8 @@ hello world
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
("This is a test without tools", [], "This is a test without tools"),
|
||||
('''<tool_call>
|
||||
(
|
||||
"""<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
@@ -437,16 +456,21 @@ TX
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>''', [
|
||||
ToolCall(
|
||||
function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
], None),
|
||||
('''Sure! Let me check the weather for you.<tool_call>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""Sure! Let me check the weather for you.<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
@@ -458,16 +482,21 @@ TX
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>''', [
|
||||
ToolCall(
|
||||
function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
})))
|
||||
], "Sure! Let me check the weather for you."),
|
||||
('''<tool_call>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"Sure! Let me check the weather for you.",
|
||||
),
|
||||
(
|
||||
"""<tool_call>
|
||||
<function=calculate_area>
|
||||
<parameter=shape>
|
||||
rectangle
|
||||
@@ -480,18 +509,25 @@ rectangle
|
||||
2
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>''', [
|
||||
ToolCall(function=FunctionCall(name="calculate_area",
|
||||
arguments=json.dumps({
|
||||
"shape": "rectangle",
|
||||
"dimensions": {
|
||||
"width": 10,
|
||||
"height": 20
|
||||
},
|
||||
"precision": 2
|
||||
})))
|
||||
], None),
|
||||
('''<tool_call>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="calculate_area",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"shape": "rectangle",
|
||||
"dimensions": {"width": 10, "height": 20},
|
||||
"precision": 2,
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
@@ -516,24 +552,30 @@ FL
|
||||
celsius
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>''', [
|
||||
ToolCall(
|
||||
function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit"
|
||||
}))),
|
||||
ToolCall(
|
||||
function=FunctionCall(name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "celsius"
|
||||
})))
|
||||
], None),
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{"city": "Orlando", "state": "FL", "unit": "celsius"}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
# Added tool_with_typed_params test case
|
||||
('''Let me calculate that area for you.<tool_call>
|
||||
(
|
||||
"""Let me calculate that area for you.<tool_call>
|
||||
<function=calculate_area>
|
||||
<parameter=shape>
|
||||
circle
|
||||
@@ -545,33 +587,42 @@ circle
|
||||
3
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>''', [
|
||||
ToolCall(function=FunctionCall(name="calculate_area",
|
||||
arguments=json.dumps({
|
||||
"shape": "circle",
|
||||
"dimensions": {
|
||||
"radius": 15.5
|
||||
},
|
||||
"precision": 3
|
||||
})))
|
||||
], "Let me calculate that area for you."),
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="calculate_area",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"shape": "circle",
|
||||
"dimensions": {"radius": 15.5},
|
||||
"precision": 3,
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"Let me calculate that area for you.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized,
|
||||
qwen3_tokenizer, sample_tools,
|
||||
model_output, expected_tool_calls,
|
||||
expected_content):
|
||||
def test_extract_tool_calls_streaming(
|
||||
qwen3_tool_parser_parametrized,
|
||||
qwen3_tokenizer,
|
||||
sample_tools,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
"""Test incremental streaming behavior including typed parameters"""
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
|
||||
other_content = ''
|
||||
other_content = ""
|
||||
tool_states = {} # Track state per tool index
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
|
||||
request):
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
|
||||
):
|
||||
# role should never be streamed from tool parser
|
||||
assert not delta_message.role
|
||||
|
||||
@@ -588,7 +639,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized,
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
"type": None
|
||||
"type": None,
|
||||
}
|
||||
|
||||
# First chunk should have id, name, and type
|
||||
@@ -607,8 +658,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized,
|
||||
|
||||
if tool_call.function.arguments is not None:
|
||||
# Accumulate arguments incrementally
|
||||
tool_states[idx][
|
||||
"arguments"] += tool_call.function.arguments
|
||||
tool_states[idx]["arguments"] += tool_call.function.arguments
|
||||
|
||||
# Verify final content
|
||||
assert other_content == (expected_content or "") # Handle None case
|
||||
@@ -632,10 +682,11 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser_parametrized,
|
||||
|
||||
|
||||
def test_extract_tool_calls_missing_closing_parameter_tag(
|
||||
qwen3_tool_parser_parametrized, sample_tools):
|
||||
qwen3_tool_parser_parametrized, sample_tools
|
||||
):
|
||||
"""Test handling of missing closing </parameter> tag"""
|
||||
# Using get_current_weather from sample_tools but with malformed XML
|
||||
model_output = '''Let me check the weather for you:
|
||||
model_output = """Let me check the weather for you:
|
||||
<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
@@ -647,21 +698,19 @@ TX
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>'''
|
||||
</tool_call>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request)
|
||||
model_output, request=request
|
||||
)
|
||||
|
||||
# The parser should handle the malformed XML gracefully
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
|
||||
# Verify the function name is correct
|
||||
assert extracted_tool_calls.tool_calls[
|
||||
0].function.name == "get_current_weather"
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "get_current_weather"
|
||||
|
||||
# Verify the arguments are parsed despite the missing closing tag
|
||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
@@ -675,10 +724,11 @@ fahrenheit
|
||||
|
||||
|
||||
def test_extract_tool_calls_streaming_missing_closing_tag(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools):
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
|
||||
):
|
||||
"""Test streaming with missing closing </parameter> tag"""
|
||||
# Using get_current_weather from sample_tools but with malformed XML
|
||||
model_output = '''Let me check the weather for you:
|
||||
model_output = """Let me check the weather for you:
|
||||
<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
@@ -690,19 +740,16 @@ TX
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>'''
|
||||
</tool_call>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
|
||||
other_content = ''
|
||||
other_content = ""
|
||||
tool_states = {}
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
|
||||
request):
|
||||
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
|
||||
):
|
||||
if delta_message.content:
|
||||
other_content += delta_message.content
|
||||
|
||||
@@ -715,7 +762,7 @@ fahrenheit
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
"type": None
|
||||
"type": None,
|
||||
}
|
||||
|
||||
if tool_call.id:
|
||||
@@ -730,8 +777,7 @@ fahrenheit
|
||||
tool_states[idx]["name"] = tool_call.function.name
|
||||
|
||||
if tool_call.function.arguments is not None:
|
||||
tool_states[idx][
|
||||
"arguments"] += tool_call.function.arguments
|
||||
tool_states[idx]["arguments"] += tool_call.function.arguments
|
||||
|
||||
# Verify content was streamed
|
||||
assert "Let me check the weather for you:" in other_content
|
||||
@@ -752,9 +798,10 @@ fahrenheit
|
||||
|
||||
|
||||
def test_extract_tool_calls_streaming_incremental(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools):
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
|
||||
):
|
||||
"""Test that streaming is truly incremental"""
|
||||
model_output = '''I'll check the weather.<tool_call>
|
||||
model_output = """I'll check the weather.<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
@@ -763,16 +810,14 @@ Dallas
|
||||
TX
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>'''
|
||||
</tool_call>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
|
||||
chunks = []
|
||||
for delta_message in stream_delta_message_generator(
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output,
|
||||
request):
|
||||
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
|
||||
):
|
||||
chunks.append(delta_message)
|
||||
|
||||
# Should have multiple chunks
|
||||
@@ -787,7 +832,7 @@ TX
|
||||
for chunk in chunks:
|
||||
if chunk.tool_calls and chunk.tool_calls[0].id:
|
||||
header_found = True
|
||||
assert (chunk.tool_calls[0].function.name == "get_current_weather")
|
||||
assert chunk.tool_calls[0].function.name == "get_current_weather"
|
||||
assert chunk.tool_calls[0].type == "function"
|
||||
# Empty initially
|
||||
assert chunk.tool_calls[0].function.arguments == ""
|
||||
@@ -811,46 +856,40 @@ TX
|
||||
|
||||
|
||||
def test_extract_tool_calls_complex_type_with_single_quote(
|
||||
qwen3_tool_parser_parametrized):
|
||||
qwen3_tool_parser_parametrized,
|
||||
):
|
||||
"""Test parameter type conversion based on tool schema"""
|
||||
tools = [
|
||||
ChatCompletionToolsParam(type="function",
|
||||
function={
|
||||
"name": "test_types",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"int_param": {
|
||||
"type": "integer"
|
||||
},
|
||||
"float_param": {
|
||||
"type": "float"
|
||||
},
|
||||
"bool_param": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"str_param": {
|
||||
"type": "string"
|
||||
},
|
||||
"obj_param": {
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "test_types",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"int_param": {"type": "integer"},
|
||||
"float_param": {"type": "float"},
|
||||
"bool_param": {"type": "boolean"},
|
||||
"str_param": {"type": "string"},
|
||||
"obj_param": {"type": "object"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
model_output = '''<tool_call>
|
||||
model_output = """<tool_call>
|
||||
<function=test_types>
|
||||
<parameter=obj_param>
|
||||
{'key': 'value'}
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>'''
|
||||
</tool_call>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
|
||||
extracted_tool_calls = qwen3_tool_parser_parametrized.extract_tool_calls(
|
||||
model_output, request=request)
|
||||
model_output, request=request
|
||||
)
|
||||
|
||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
assert args["obj_param"] == {"key": "value"}
|
||||
|
||||
@@ -8,10 +8,13 @@ from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaMessage, FunctionCall,
|
||||
ToolCall)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
DeltaMessage,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
|
||||
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
@@ -45,51 +48,56 @@ def sample_tools():
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description":
|
||||
"City and country e.g. Bogotá, Colombia"
|
||||
"description": "City and country e.g. Bogotá, Colombia",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "this is the unit of temperature"
|
||||
}
|
||||
"description": "this is the unit of temperature",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"returns": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"description": "temperature in celsius"
|
||||
"description": "temperature in celsius",
|
||||
}
|
||||
},
|
||||
"required": ["temperature"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True
|
||||
}),
|
||||
"strict": True,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall]):
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
# Seed-OSS tool call will not generate id
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function == expected_tool_call.function
|
||||
|
||||
assert actual_tool_call.function.name == expected_tool_call.function.name
|
||||
assert actual_tool_call.function.arguments == expected_tool_call.function.arguments
|
||||
assert (
|
||||
actual_tool_call.function.arguments == expected_tool_call.function.arguments
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
model_output = "This is a test response without any tool calls"
|
||||
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
@@ -104,17 +112,24 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
("""<seed:tool_call>\n<function=get_weather>\n"""
|
||||
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"location": "Barcelona, Spain",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
], None),
|
||||
(
|
||||
"""<seed:tool_call>\n<function=get_weather>\n"""
|
||||
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
@@ -131,13 +146,17 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
|
||||
"""\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"location": "Barcelona, Spain",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
@@ -169,15 +188,18 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
|
||||
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
"unit": "celsius",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
"unit": "celsius",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
@@ -196,13 +218,17 @@ def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
def test_extract_tool_calls(
|
||||
seed_oss_tool_parser,
|
||||
sample_tools,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
|
||||
model_output, request=request) # type: ignore[arg-type]
|
||||
model_output, request=request
|
||||
) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
@@ -225,7 +251,7 @@ def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
|
||||
|
||||
# Should return the delta text as content
|
||||
assert result is not None
|
||||
assert hasattr(result, 'content')
|
||||
assert hasattr(result, "content")
|
||||
assert result.content == " without any tool calls."
|
||||
|
||||
|
||||
@@ -233,10 +259,9 @@ def stream_delta_message_generator(
|
||||
seed_oss_tool_parser: SeedOssToolParser,
|
||||
seed_oss_tokenizer: AnyTokenizer,
|
||||
model_output: str,
|
||||
request: Optional[ChatCompletionRequest] = None
|
||||
request: Optional[ChatCompletionRequest] = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = seed_oss_tokenizer.encode(model_output,
|
||||
add_special_tokens=False)
|
||||
all_token_ids = seed_oss_tokenizer.encode(model_output, add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
@@ -245,18 +270,19 @@ def stream_delta_message_generator(
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[:i + 1]
|
||||
current_token_ids = all_token_ids[: i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset,
|
||||
new_read_offset) = detokenize_incrementally(
|
||||
tokenizer=seed_oss_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
|
||||
detokenize_incrementally(
|
||||
tokenizer=seed_oss_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
)
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
@@ -273,8 +299,9 @@ def stream_delta_message_generator(
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = (previous_tokens +
|
||||
new_tokens if previous_tokens else new_tokens)
|
||||
previous_tokens = (
|
||||
previous_tokens + new_tokens if previous_tokens else new_tokens
|
||||
)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
@@ -287,22 +314,27 @@ def stream_delta_message_generator(
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n"""
|
||||
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"location": "Barcelona, Spain",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
],
|
||||
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||
),
|
||||
(
|
||||
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
|
||||
"""<seed:tool_call>\n<function=get_weather>\n"""
|
||||
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
|
||||
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n""",
|
||||
),
|
||||
(
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
@@ -319,13 +351,17 @@ def stream_delta_message_generator(
|
||||
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
|
||||
"""\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"location": "Barcelona, Spain",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
|
||||
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
|
||||
@@ -357,15 +393,18 @@ def stream_delta_message_generator(
|
||||
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
|
||||
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
"unit": "celsius",
|
||||
}, ),
|
||||
),
|
||||
type='function')
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"location": "Barcelona, Spain",
|
||||
"unit": "celsius",
|
||||
},
|
||||
),
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
|
||||
"""First, I need to remember the function I can use: get_weather. The function requires a """
|
||||
@@ -384,19 +423,23 @@ def stream_delta_message_generator(
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer,
|
||||
sample_tools, model_output, expected_tool_calls,
|
||||
expected_content):
|
||||
def test_streaming_tool_calls(
|
||||
seed_oss_tool_parser,
|
||||
seed_oss_tokenizer,
|
||||
sample_tools,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
"""Test incremental streaming behavior"""
|
||||
request = ChatCompletionRequest(model=MODEL,
|
||||
messages=[],
|
||||
tools=sample_tools)
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
|
||||
other_content = ''
|
||||
other_content = ""
|
||||
tool_states = {} # Track state per tool index
|
||||
|
||||
for delta_message in stream_delta_message_generator(
|
||||
seed_oss_tool_parser, seed_oss_tokenizer, model_output, request):
|
||||
seed_oss_tool_parser, seed_oss_tokenizer, model_output, request
|
||||
):
|
||||
# role should never be streamed from tool parser
|
||||
assert not delta_message.role
|
||||
|
||||
@@ -413,7 +456,7 @@ def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer,
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
"type": None
|
||||
"type": None,
|
||||
}
|
||||
|
||||
# First chunk should have id, name, and type
|
||||
@@ -432,8 +475,7 @@ def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer,
|
||||
|
||||
if tool_call.function.arguments is not None:
|
||||
# Accumulate arguments incrementally
|
||||
tool_states[idx][
|
||||
"arguments"] += tool_call.function.arguments
|
||||
tool_states[idx]["arguments"] += tool_call.function.arguments
|
||||
|
||||
# Verify final content
|
||||
assert other_content == expected_content
|
||||
|
||||
@@ -7,8 +7,12 @@ from typing import Optional
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE,
|
||||
SEARCH_TOOL, WEATHER_TOOL)
|
||||
from .utils import (
|
||||
MESSAGES_ASKING_FOR_TOOLS,
|
||||
MESSAGES_WITH_TOOL_RESPONSE,
|
||||
SEARCH_TOOL,
|
||||
WEATHER_TOOL,
|
||||
)
|
||||
|
||||
|
||||
# test: request a chat completion that should return tool calls, so we know they
|
||||
@@ -23,17 +27,18 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
max_completion_tokens=100,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False)
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
stop_reason = chat_completion.choices[0].finish_reason
|
||||
tool_calls = chat_completion.choices[0].message.tool_calls
|
||||
|
||||
# make sure a tool call is present
|
||||
assert choice.message.role == 'assistant'
|
||||
assert choice.message.role == "assistant"
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].type == 'function'
|
||||
assert tool_calls[0].type == "function"
|
||||
assert tool_calls[0].function is not None
|
||||
assert isinstance(tool_calls[0].id, str)
|
||||
assert len(tool_calls[0].id) >= 9
|
||||
@@ -54,7 +59,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
assert stop_reason == "tool_calls"
|
||||
|
||||
function_name: Optional[str] = None
|
||||
function_args_str: str = ''
|
||||
function_args_str: str = ""
|
||||
tool_call_id: Optional[str] = None
|
||||
role_name: Optional[str] = None
|
||||
finish_reason_count: int = 0
|
||||
@@ -67,20 +72,21 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
max_completion_tokens=100,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True)
|
||||
stream=True,
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
assert chunk.choices[0].index == 0
|
||||
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason_count += 1
|
||||
assert chunk.choices[0].finish_reason == 'tool_calls'
|
||||
assert chunk.choices[0].finish_reason == "tool_calls"
|
||||
|
||||
# if a role is being streamed make sure it wasn't already set to
|
||||
# something else
|
||||
if chunk.choices[0].delta.role:
|
||||
assert not role_name or role_name == 'assistant'
|
||||
role_name = 'assistant'
|
||||
assert not role_name or role_name == "assistant"
|
||||
role_name = "assistant"
|
||||
|
||||
# if a tool call is streamed make sure there's exactly one
|
||||
# (based on the request parameters
|
||||
@@ -108,7 +114,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
function_args_str += tool_call.function.arguments
|
||||
|
||||
assert finish_reason_count == 1
|
||||
assert role_name == 'assistant'
|
||||
assert role_name == "assistant"
|
||||
assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9)
|
||||
|
||||
# validate the name and arguments
|
||||
@@ -148,14 +154,14 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
|
||||
max_completion_tokens=100,
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False)
|
||||
logprobs=False,
|
||||
)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
|
||||
assert choice.finish_reason != "tool_calls" # "stop" or "length"
|
||||
assert choice.message.role == "assistant"
|
||||
assert choice.message.tool_calls is None \
|
||||
or len(choice.message.tool_calls) == 0
|
||||
assert choice.message.tool_calls is None or len(choice.message.tool_calls) == 0
|
||||
assert choice.message.content is not None
|
||||
assert "98" in choice.message.content # the temperature from the response
|
||||
|
||||
@@ -166,7 +172,8 @@ async def test_tool_call_with_results(client: openai.AsyncOpenAI):
|
||||
model=model_name,
|
||||
tools=[WEATHER_TOOL, SEARCH_TOOL],
|
||||
logprobs=False,
|
||||
stream=True)
|
||||
stream=True,
|
||||
)
|
||||
|
||||
chunks: list[str] = []
|
||||
finish_reason_count = 0
|
||||
|
||||
@@ -8,8 +8,10 @@ import pytest
|
||||
import regex as re
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionToolsParam)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
@@ -24,18 +26,16 @@ EXAMPLE_TOOLS = [
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to find the weather for"
|
||||
"type": "string",
|
||||
"description": "The city to find the weather for"
|
||||
", e.g. 'San Francisco'",
|
||||
},
|
||||
},
|
||||
"required": ["city"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
"strict": True
|
||||
"strict": True,
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
@@ -46,35 +46,33 @@ EXAMPLE_TOOLS = [
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to get the forecast for, e.g. 'New York'",
|
||||
"type": "string",
|
||||
"description": "The city to get the forecast for, e.g. 'New York'",
|
||||
},
|
||||
"days": {
|
||||
"type":
|
||||
"integer",
|
||||
"description":
|
||||
"Number of days to get the forecast for (1-7)",
|
||||
"type": "integer",
|
||||
"description": "Number of days to get the forecast for (1-7)",
|
||||
},
|
||||
},
|
||||
"required": ["city", "days"],
|
||||
"additionalProperties": False
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
"strict": True
|
||||
"strict": True,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
|
||||
should_match: bool):
|
||||
def _compile_and_check(
|
||||
tools: list[ChatCompletionToolsParam], sample_output, should_match: bool
|
||||
):
|
||||
self = MagicMock(tool_choice="required", tools=tools)
|
||||
schema = ChatCompletionRequest._get_json_schema_from_tool(self)
|
||||
assert isinstance(schema, dict)
|
||||
|
||||
# use build_regex_from_schema used in JSONLogitsProcessor to create Guide
|
||||
from outlines_core.json_schema import build_regex_from_schema
|
||||
|
||||
regex = build_regex_from_schema(json.dumps(schema))
|
||||
compiled = re.compile(regex)
|
||||
matches = compiled.fullmatch(json.dumps(sample_output)) is not None
|
||||
@@ -83,65 +81,31 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
|
||||
|
||||
|
||||
VALID_TOOL_OUTPUTS = [
|
||||
([{
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}], True),
|
||||
([{
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}, {
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Berlin"
|
||||
}
|
||||
}], True),
|
||||
([{
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"days": 7
|
||||
}
|
||||
}], True),
|
||||
([{
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"days": 7
|
||||
}
|
||||
}, {
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}], True),
|
||||
([{
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"days": 7
|
||||
}
|
||||
}, {
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}, {
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Berlin",
|
||||
"days": 7
|
||||
}
|
||||
}, {
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Berlin"
|
||||
}
|
||||
}], True),
|
||||
([{"name": "get_current_weather", "parameters": {"city": "Vienna"}}], True),
|
||||
(
|
||||
[
|
||||
{"name": "get_current_weather", "parameters": {"city": "Vienna"}},
|
||||
{"name": "get_current_weather", "parameters": {"city": "Berlin"}},
|
||||
],
|
||||
True,
|
||||
),
|
||||
([{"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}}], True),
|
||||
(
|
||||
[
|
||||
{"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
|
||||
{"name": "get_current_weather", "parameters": {"city": "Vienna"}},
|
||||
],
|
||||
True,
|
||||
),
|
||||
(
|
||||
[
|
||||
{"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
|
||||
{"name": "get_current_weather", "parameters": {"city": "Vienna"}},
|
||||
{"name": "get_forecast", "parameters": {"city": "Berlin", "days": 7}},
|
||||
{"name": "get_current_weather", "parameters": {"city": "Berlin"}},
|
||||
],
|
||||
True,
|
||||
),
|
||||
]
|
||||
|
||||
VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
|
||||
@@ -149,92 +113,100 @@ VALID_TOOLS = [t[0] for t in VALID_TOOL_OUTPUTS]
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample_output, should_match",
|
||||
VALID_TOOL_OUTPUTS + [
|
||||
VALID_TOOL_OUTPUTS
|
||||
+ [
|
||||
(None, False),
|
||||
([], False), # empty list cannot be generated
|
||||
({}, False), # empty object cannot be generated
|
||||
([{}], False), # list with empty object cannot be generated
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather"
|
||||
}],
|
||||
False),
|
||||
[
|
||||
{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather"
|
||||
}
|
||||
],
|
||||
False,
|
||||
),
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {}
|
||||
}],
|
||||
False),
|
||||
[
|
||||
{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {},
|
||||
}
|
||||
],
|
||||
False,
|
||||
),
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": None
|
||||
}],
|
||||
False),
|
||||
[
|
||||
{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": None,
|
||||
}
|
||||
],
|
||||
False,
|
||||
),
|
||||
(
|
||||
{ # tool call without lists cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
"parameters": {"city": "Vienna"},
|
||||
},
|
||||
False),
|
||||
False,
|
||||
),
|
||||
(
|
||||
[{ # tool call with extra parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"extra": "value"
|
||||
[
|
||||
{ # tool call with extra parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {"city": "Vienna", "extra": "value"},
|
||||
}
|
||||
}],
|
||||
False),
|
||||
],
|
||||
False,
|
||||
),
|
||||
(
|
||||
[{ # tool call where parameters are first cannot be generated
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
},
|
||||
"name": "get_current_weather"
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # tool call without all required parameters cannot be generated
|
||||
"name": "get_forecast",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
[
|
||||
{ # tool call where parameters are first cannot be generated
|
||||
"parameters": {"city": "Vienna"},
|
||||
"name": "get_current_weather",
|
||||
}
|
||||
}],
|
||||
False),
|
||||
],
|
||||
False,
|
||||
),
|
||||
(
|
||||
[
|
||||
{ # tool call without all required parameters cannot be generated
|
||||
"name": "get_forecast",
|
||||
"parameters": {"city": "Vienna"},
|
||||
}
|
||||
],
|
||||
False,
|
||||
),
|
||||
( # tool call with incorrect name/parameters cannot be generated
|
||||
[{
|
||||
"name": "get_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna",
|
||||
"days": 7
|
||||
}
|
||||
}], False),
|
||||
[{"name": "get_weather", "parameters": {"city": "Vienna", "days": 7}}],
|
||||
False,
|
||||
),
|
||||
( # tool call with both valid and empty function cannot be generated
|
||||
[{
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"city": "Vienna"
|
||||
}
|
||||
}, {}], False),
|
||||
])
|
||||
[{"name": "get_current_weather", "parameters": {"city": "Vienna"}}, {}],
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_structured_outputs_json(sample_output, should_match):
|
||||
_compile_and_check(tools=TypeAdapter(
|
||||
list[ChatCompletionToolsParam]).validate_python(EXAMPLE_TOOLS),
|
||||
sample_output=sample_output,
|
||||
should_match=should_match)
|
||||
_compile_and_check(
|
||||
tools=TypeAdapter(list[ChatCompletionToolsParam]).validate_python(
|
||||
EXAMPLE_TOOLS
|
||||
),
|
||||
sample_output=sample_output,
|
||||
should_match=should_match,
|
||||
)
|
||||
|
||||
|
||||
def update_parameters_none(
|
||||
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
|
||||
def update_parameters_none(tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
|
||||
tool.function.parameters = None
|
||||
return tool
|
||||
|
||||
|
||||
def update_parameters_empty_dict(
|
||||
tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
|
||||
tool: ChatCompletionToolsParam,
|
||||
) -> ChatCompletionToolsParam:
|
||||
tool.function.parameters = {}
|
||||
return tool
|
||||
|
||||
@@ -247,48 +219,60 @@ def update_parameters_empty_dict(
|
||||
({}, False), # empty object cannot be generated
|
||||
([{}], False), # list with empty object cannot be generated
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather"
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": None
|
||||
}],
|
||||
False),
|
||||
(
|
||||
[{ # function with extra parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {
|
||||
"extra": "value"
|
||||
[
|
||||
{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather"
|
||||
}
|
||||
}],
|
||||
False),
|
||||
],
|
||||
False,
|
||||
),
|
||||
(
|
||||
[{ # only function with empty parameters object is valid
|
||||
"name": "get_current_weather",
|
||||
"parameters": {}
|
||||
}],
|
||||
True),
|
||||
])
|
||||
[
|
||||
{ # function without required parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": None,
|
||||
}
|
||||
],
|
||||
False,
|
||||
),
|
||||
(
|
||||
[
|
||||
{ # function with extra parameters cannot be generated
|
||||
"name": "get_current_weather",
|
||||
"parameters": {"extra": "value"},
|
||||
}
|
||||
],
|
||||
False,
|
||||
),
|
||||
(
|
||||
[
|
||||
{ # only function with empty parameters object is valid
|
||||
"name": "get_current_weather",
|
||||
"parameters": {},
|
||||
}
|
||||
],
|
||||
True,
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"update_parameters",
|
||||
[update_parameters_none, update_parameters_empty_dict])
|
||||
def test_structured_outputs_json_without_parameters(sample_output,
|
||||
should_match,
|
||||
update_parameters):
|
||||
"update_parameters", [update_parameters_none, update_parameters_empty_dict]
|
||||
)
|
||||
def test_structured_outputs_json_without_parameters(
|
||||
sample_output, should_match, update_parameters
|
||||
):
|
||||
updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
|
||||
tools = TypeAdapter(
|
||||
list[ChatCompletionToolsParam]).validate_python(updated_tools)
|
||||
tools = TypeAdapter(list[ChatCompletionToolsParam]).validate_python(updated_tools)
|
||||
tools = list(map(update_parameters, tools))
|
||||
assert all([
|
||||
tool.function.parameters is None or tool.function.parameters == {}
|
||||
for tool in tools
|
||||
])
|
||||
_compile_and_check(tools=tools,
|
||||
sample_output=sample_output,
|
||||
should_match=should_match)
|
||||
assert all(
|
||||
[
|
||||
tool.function.parameters is None or tool.function.parameters == {}
|
||||
for tool in tools
|
||||
]
|
||||
)
|
||||
_compile_and_check(
|
||||
tools=tools, sample_output=sample_output, should_match=should_match
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("output", VALID_TOOLS)
|
||||
@@ -306,7 +290,7 @@ def test_streaming_output_valid(output, empty_params, delta_len):
|
||||
function_name_returned = False
|
||||
messages = []
|
||||
for i in range(0, len(output_json), delta_len):
|
||||
delta_text = output_json[i:i + delta_len]
|
||||
delta_text = output_json[i : i + delta_len]
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message, function_name_returned = (
|
||||
@@ -315,7 +299,9 @@ def test_streaming_output_valid(output, empty_params, delta_len):
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=function_name_returned))
|
||||
function_name_returned=function_name_returned,
|
||||
)
|
||||
)
|
||||
|
||||
if delta_message:
|
||||
messages.append(delta_message)
|
||||
@@ -329,10 +315,12 @@ def test_streaming_output_valid(output, empty_params, delta_len):
|
||||
if len(combined_messages) > 1:
|
||||
combined_messages += "},"
|
||||
|
||||
combined_messages += '{"name": "' + \
|
||||
message.tool_calls[0].function.name + \
|
||||
'", "parameters": ' + \
|
||||
message.tool_calls[0].function.arguments
|
||||
combined_messages += (
|
||||
'{"name": "'
|
||||
+ message.tool_calls[0].function.name
|
||||
+ '", "parameters": '
|
||||
+ message.tool_calls[0].function.arguments
|
||||
)
|
||||
else:
|
||||
combined_messages += message.tool_calls[0].function.arguments
|
||||
combined_messages += "}]"
|
||||
|
||||
@@ -7,9 +7,12 @@ from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage, FunctionCall,
|
||||
ToolCall)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
|
||||
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
@@ -30,12 +33,14 @@ def xlam_tool_parser(xlam_tokenizer):
|
||||
return xLAMToolParser(xlam_tokenizer)
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall]):
|
||||
def assert_tool_calls(
|
||||
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
|
||||
):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
for actual_tool_call, expected_tool_call in zip(
|
||||
actual_tool_calls, expected_tool_calls
|
||||
):
|
||||
assert isinstance(actual_tool_call.id, str)
|
||||
assert len(actual_tool_call.id) > 16
|
||||
|
||||
@@ -49,8 +54,7 @@ def stream_delta_message_generator(
|
||||
model_output: str,
|
||||
request: Optional[ChatCompletionRequest] = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = xlam_tokenizer.encode(model_output,
|
||||
add_special_tokens=False)
|
||||
all_token_ids = xlam_tokenizer.encode(model_output, add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
@@ -59,18 +63,19 @@ def stream_delta_message_generator(
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[:i + 1]
|
||||
current_token_ids = all_token_ids[: i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset,
|
||||
new_read_offset) = (detokenize_incrementally(
|
||||
tokenizer=xlam_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
))
|
||||
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
|
||||
detokenize_incrementally(
|
||||
tokenizer=xlam_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
)
|
||||
)
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
@@ -87,8 +92,9 @@ def stream_delta_message_generator(
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = (previous_tokens +
|
||||
new_tokens if previous_tokens else new_tokens)
|
||||
previous_tokens = (
|
||||
previous_tokens + new_tokens if previous_tokens else new_tokens
|
||||
)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
@@ -96,7 +102,8 @@ def stream_delta_message_generator(
|
||||
def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
@@ -115,87 +122,113 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
||||
(
|
||||
"""[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"<think>I'll help you with that.</think>",
|
||||
),
|
||||
(
|
||||
"""I'll help you with that.\n```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"I'll help you with that.",
|
||||
),
|
||||
(
|
||||
"""I'll check the weather for you.[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"I'll check the weather for you.",
|
||||
),
|
||||
(
|
||||
"""I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"I'll help you check the weather.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(xlam_tool_parser, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
def test_extract_tool_calls(
|
||||
xlam_tool_parser, model_output, expected_tool_calls, expected_content
|
||||
):
|
||||
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
@@ -210,25 +243,30 @@ def test_extract_tool_calls(xlam_tool_parser, model_output,
|
||||
(
|
||||
"""[{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Seattle",
|
||||
"state": "WA",
|
||||
"unit": "celsius",
|
||||
}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Seattle",
|
||||
"state": "WA",
|
||||
"unit": "celsius",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_list_structure(xlam_tool_parser, model_output,
|
||||
expected_tool_calls,
|
||||
expected_content):
|
||||
def test_extract_tool_calls_list_structure(
|
||||
xlam_tool_parser, model_output, expected_tool_calls, expected_content
|
||||
):
|
||||
"""Test extraction of tool calls when the model outputs a list-structured tool call.""" # noqa: E501
|
||||
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
model_output, request=None
|
||||
) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
@@ -239,20 +277,25 @@ def test_extract_tool_calls_list_structure(xlam_tool_parser, model_output,
|
||||
# Test for preprocess_model_output method
|
||||
def test_preprocess_model_output(xlam_tool_parser):
|
||||
# Test with list structure
|
||||
model_output = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
|
||||
model_output = (
|
||||
"""[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
|
||||
)
|
||||
content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
|
||||
model_output)
|
||||
model_output
|
||||
)
|
||||
assert content is None
|
||||
assert potential_tool_calls == model_output
|
||||
|
||||
# Test with thinking tag
|
||||
model_output = """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
|
||||
content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
|
||||
model_output)
|
||||
model_output
|
||||
)
|
||||
assert content == "<think>I'll help you with that.</think>"
|
||||
assert (
|
||||
potential_tool_calls ==
|
||||
'[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]')
|
||||
potential_tool_calls
|
||||
== '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]'
|
||||
)
|
||||
|
||||
# Test with JSON code block
|
||||
model_output = """I'll help you with that.
|
||||
@@ -260,14 +303,16 @@ def test_preprocess_model_output(xlam_tool_parser):
|
||||
[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]
|
||||
```"""
|
||||
content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
|
||||
model_output)
|
||||
model_output
|
||||
)
|
||||
assert content == "I'll help you with that."
|
||||
assert "get_current_weather" in potential_tool_calls
|
||||
|
||||
# Test with no tool calls
|
||||
model_output = """I'll help you with that."""
|
||||
content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
|
||||
model_output)
|
||||
model_output
|
||||
)
|
||||
assert content == model_output
|
||||
assert potential_tool_calls is None
|
||||
|
||||
@@ -281,7 +326,9 @@ def test_streaming_with_list_structure(xlam_tool_parser):
|
||||
xlam_tool_parser.current_tool_id = -1
|
||||
|
||||
# Simulate receiving a message with list structure
|
||||
current_text = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
|
||||
current_text = (
|
||||
"""[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]""" # noqa: E501
|
||||
)
|
||||
|
||||
# First call to set up the tool
|
||||
xlam_tool_parser.extract_tool_calls_streaming(
|
||||
@@ -295,8 +342,7 @@ def test_streaming_with_list_structure(xlam_tool_parser):
|
||||
)
|
||||
|
||||
# Make sure the tool is set up correctly
|
||||
assert (xlam_tool_parser.current_tool_id
|
||||
>= 0), "Tool index should be initialized"
|
||||
assert xlam_tool_parser.current_tool_id >= 0, "Tool index should be initialized"
|
||||
|
||||
# Manually set up the state for sending the tool name
|
||||
xlam_tool_parser.current_tools_sent = [False]
|
||||
@@ -332,78 +378,102 @@ def test_streaming_with_list_structure(xlam_tool_parser):
|
||||
(
|
||||
"""[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
),
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"<think>I'll help you with that.</think>",
|
||||
),
|
||||
(
|
||||
"""```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps(
|
||||
{
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"I can help with that.",
|
||||
),
|
||||
@@ -421,7 +491,8 @@ def test_extract_tool_calls_streaming_incremental(
|
||||
|
||||
chunks = []
|
||||
for delta_message in stream_delta_message_generator(
|
||||
xlam_tool_parser, xlam_tokenizer, model_output, request):
|
||||
xlam_tool_parser, xlam_tokenizer, model_output, request
|
||||
):
|
||||
chunks.append(delta_message)
|
||||
|
||||
# Should have multiple chunks
|
||||
@@ -433,8 +504,9 @@ def test_extract_tool_calls_streaming_incremental(
|
||||
for chunk in chunks:
|
||||
if chunk.tool_calls and chunk.tool_calls[0].id:
|
||||
header_found = True
|
||||
assert (chunk.tool_calls[0].function.name ==
|
||||
expected_first_tool.function.name)
|
||||
assert (
|
||||
chunk.tool_calls[0].function.name == expected_first_tool.function.name
|
||||
)
|
||||
assert chunk.tool_calls[0].type == "function"
|
||||
# Arguments may be empty initially or None
|
||||
if chunk.tool_calls[0].function.arguments is not None:
|
||||
@@ -446,11 +518,13 @@ def test_extract_tool_calls_streaming_incremental(
|
||||
# Should have chunks with incremental arguments
|
||||
arg_chunks = []
|
||||
for chunk in chunks:
|
||||
if (chunk.tool_calls and chunk.tool_calls[0].function.arguments
|
||||
and chunk.tool_calls[0].function.arguments != ""
|
||||
and chunk.tool_calls[0].index ==
|
||||
0 # Only collect arguments from the first tool call
|
||||
):
|
||||
if (
|
||||
chunk.tool_calls
|
||||
and chunk.tool_calls[0].function.arguments
|
||||
and chunk.tool_calls[0].function.arguments != ""
|
||||
and chunk.tool_calls[0].index
|
||||
== 0 # Only collect arguments from the first tool call
|
||||
):
|
||||
arg_chunks.append(chunk.tool_calls[0].function.arguments)
|
||||
|
||||
# Arguments should be streamed incrementally
|
||||
|
||||
@@ -4,8 +4,7 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any, Optional
|
||||
|
||||
from openai.types.chat import (ChatCompletionMessageParam,
|
||||
ChatCompletionToolParam)
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from tests.utils import VLLM_PATH
|
||||
@@ -20,8 +19,9 @@ class ServerConfig(TypedDict, total=False):
|
||||
extended: Optional[bool] # tests do not run in CI automatically
|
||||
|
||||
|
||||
def patch_system_prompt(messages: list[dict[str, Any]],
|
||||
system_prompt: str) -> list[dict[str, Any]]:
|
||||
def patch_system_prompt(
|
||||
messages: list[dict[str, Any]], system_prompt: str
|
||||
) -> list[dict[str, Any]]:
|
||||
new_messages = deepcopy(messages)
|
||||
if new_messages[0]["role"] == "system":
|
||||
new_messages[0]["content"] = system_prompt
|
||||
@@ -30,8 +30,9 @@ def patch_system_prompt(messages: list[dict[str, Any]],
|
||||
return new_messages
|
||||
|
||||
|
||||
def ensure_system_prompt(messages: list[dict[str, Any]],
|
||||
config: ServerConfig) -> list[dict[str, Any]]:
|
||||
def ensure_system_prompt(
|
||||
messages: list[dict[str, Any]], config: ServerConfig
|
||||
) -> list[dict[str, Any]]:
|
||||
prompt = config.get("system_prompt")
|
||||
if prompt:
|
||||
return patch_system_prompt(messages, prompt)
|
||||
@@ -42,92 +43,102 @@ def ensure_system_prompt(messages: list[dict[str, Any]],
|
||||
# universal args for all models go here. also good if you need to test locally
|
||||
# and change type or KV cache quantization or something.
|
||||
ARGS: list[str] = [
|
||||
"--enable-auto-tool-choice", "--max-model-len", "1024", "--max-num-seqs",
|
||||
"256"
|
||||
"--enable-auto-tool-choice",
|
||||
"--max-model-len",
|
||||
"1024",
|
||||
"--max-num-seqs",
|
||||
"256",
|
||||
]
|
||||
|
||||
CONFIGS: dict[str, ServerConfig] = {
|
||||
"hermes": {
|
||||
"model":
|
||||
"NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
"model": "NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
"arguments": [
|
||||
"--enforce-eager", "--no-enable-prefix-caching",
|
||||
"--tool-call-parser", "hermes", "--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja")
|
||||
"--enforce-eager",
|
||||
"--no-enable-prefix-caching",
|
||||
"--tool-call-parser",
|
||||
"hermes",
|
||||
"--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja"),
|
||||
],
|
||||
"system_prompt":
|
||||
"You are a helpful assistant with access to tools. If a tool"
|
||||
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
|
||||
" that you have would be helpful to answer a user query, "
|
||||
"call the tool. Otherwise, answer the user's query directly "
|
||||
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
|
||||
"to the user's question - just respond to it normally."
|
||||
"to the user's question - just respond to it normally.",
|
||||
},
|
||||
"llama": {
|
||||
"model":
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"arguments": [
|
||||
"--enforce-eager", "--no-enable-prefix-caching",
|
||||
"--tool-call-parser", "llama3_json", "--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_llama3.1_json.jinja")
|
||||
"--enforce-eager",
|
||||
"--no-enable-prefix-caching",
|
||||
"--tool-call-parser",
|
||||
"llama3_json",
|
||||
"--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_llama3.1_json.jinja"),
|
||||
],
|
||||
"supports_parallel":
|
||||
False,
|
||||
"supports_parallel": False,
|
||||
},
|
||||
"llama3.2": {
|
||||
"model":
|
||||
"meta-llama/Llama-3.2-3B-Instruct",
|
||||
"model": "meta-llama/Llama-3.2-3B-Instruct",
|
||||
"arguments": [
|
||||
"--enforce-eager", "--no-enable-prefix-caching",
|
||||
"--tool-call-parser", "llama3_json", "--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_llama3.2_json.jinja")
|
||||
"--enforce-eager",
|
||||
"--no-enable-prefix-caching",
|
||||
"--tool-call-parser",
|
||||
"llama3_json",
|
||||
"--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_llama3.2_json.jinja"),
|
||||
],
|
||||
"supports_parallel":
|
||||
False,
|
||||
"supports_parallel": False,
|
||||
},
|
||||
"llama4": {
|
||||
"model":
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"model": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"arguments": [
|
||||
"--enforce-eager", "--no-enable-prefix-caching",
|
||||
"--tool-call-parser", "llama4_pythonic", "--chat-template",
|
||||
str(VLLM_PATH /
|
||||
"examples/tool_chat_template_llama4_pythonic.jinja"), "-tp",
|
||||
"4"
|
||||
"--enforce-eager",
|
||||
"--no-enable-prefix-caching",
|
||||
"--tool-call-parser",
|
||||
"llama4_pythonic",
|
||||
"--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_llama4_pythonic.jinja"),
|
||||
"-tp",
|
||||
"4",
|
||||
],
|
||||
"supports_parallel":
|
||||
False,
|
||||
"extended":
|
||||
True
|
||||
"supports_parallel": False,
|
||||
"extended": True,
|
||||
},
|
||||
"llama4_json": {
|
||||
"model":
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"model": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
"arguments": [
|
||||
"--enforce-eager", "--no-enable-prefix-caching", "-tp", "4",
|
||||
"--distributed-executor-backend", "mp", "--tool-call-parser",
|
||||
"llama4_json", "--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_llama4_json.jinja")
|
||||
"--enforce-eager",
|
||||
"--no-enable-prefix-caching",
|
||||
"-tp",
|
||||
"4",
|
||||
"--distributed-executor-backend",
|
||||
"mp",
|
||||
"--tool-call-parser",
|
||||
"llama4_json",
|
||||
"--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_llama4_json.jinja"),
|
||||
],
|
||||
"supports_parallel":
|
||||
True,
|
||||
"extended":
|
||||
True
|
||||
"supports_parallel": True,
|
||||
"extended": True,
|
||||
},
|
||||
"mistral": {
|
||||
"model":
|
||||
"mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"model": "mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"arguments": [
|
||||
"--enforce-eager", "--no-enable-prefix-caching",
|
||||
"--tool-call-parser", "mistral", "--chat-template",
|
||||
"--enforce-eager",
|
||||
"--no-enable-prefix-caching",
|
||||
"--tool-call-parser",
|
||||
"mistral",
|
||||
"--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"),
|
||||
"--ignore-patterns=\"consolidated.safetensors\""
|
||||
'--ignore-patterns="consolidated.safetensors"',
|
||||
],
|
||||
"system_prompt":
|
||||
"You are a helpful assistant with access to tools. If a tool"
|
||||
"system_prompt": "You are a helpful assistant with access to tools. If a tool"
|
||||
" that you have would be helpful to answer a user query, "
|
||||
"call the tool. Otherwise, answer the user's query directly "
|
||||
"without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
|
||||
"to the user's question - just respond to it normally."
|
||||
"to the user's question - just respond to it normally.",
|
||||
},
|
||||
# V1 Test: Passing locally but failing in CI. This runs the
|
||||
# V0 Engine because of CPU offloading. Need to debug why.
|
||||
@@ -146,49 +157,50 @@ CONFIGS: dict[str, ServerConfig] = {
|
||||
# False,
|
||||
# },
|
||||
"granite-3.0-8b": {
|
||||
"model":
|
||||
"ibm-granite/granite-3.0-8b-instruct",
|
||||
"model": "ibm-granite/granite-3.0-8b-instruct",
|
||||
"arguments": [
|
||||
"--enforce-eager", "--no-enable-prefix-caching",
|
||||
"--tool-call-parser", "granite", "--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_granite.jinja")
|
||||
"--enforce-eager",
|
||||
"--no-enable-prefix-caching",
|
||||
"--tool-call-parser",
|
||||
"granite",
|
||||
"--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_granite.jinja"),
|
||||
],
|
||||
},
|
||||
"granite-3.1-8b": {
|
||||
"model":
|
||||
"ibm-granite/granite-3.1-8b-instruct",
|
||||
"model": "ibm-granite/granite-3.1-8b-instruct",
|
||||
"arguments": [
|
||||
"--enforce-eager",
|
||||
"--no-enable-prefix-caching",
|
||||
"--tool-call-parser",
|
||||
"granite",
|
||||
],
|
||||
"supports_parallel":
|
||||
True,
|
||||
"supports_parallel": True,
|
||||
},
|
||||
"internlm": {
|
||||
"model":
|
||||
"internlm/internlm2_5-7b-chat",
|
||||
"model": "internlm/internlm2_5-7b-chat",
|
||||
"arguments": [
|
||||
"--enforce-eager", "--no-enable-prefix-caching",
|
||||
"--tool-call-parser", "internlm", "--chat-template",
|
||||
str(VLLM_PATH /
|
||||
"examples/tool_chat_template_internlm2_tool.jinja"),
|
||||
"--trust_remote_code"
|
||||
"--enforce-eager",
|
||||
"--no-enable-prefix-caching",
|
||||
"--tool-call-parser",
|
||||
"internlm",
|
||||
"--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_internlm2_tool.jinja"),
|
||||
"--trust_remote_code",
|
||||
],
|
||||
"supports_parallel":
|
||||
False,
|
||||
"supports_parallel": False,
|
||||
},
|
||||
"toolACE": {
|
||||
"model":
|
||||
"Team-ACE/ToolACE-8B",
|
||||
"model": "Team-ACE/ToolACE-8B",
|
||||
"arguments": [
|
||||
"--enforce-eager", "--no-enable-prefix-caching",
|
||||
"--tool-call-parser", "pythonic", "--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_toolace.jinja")
|
||||
"--enforce-eager",
|
||||
"--no-enable-prefix-caching",
|
||||
"--tool-call-parser",
|
||||
"pythonic",
|
||||
"--chat-template",
|
||||
str(VLLM_PATH / "examples/tool_chat_template_toolace.jinja"),
|
||||
],
|
||||
"supports_parallel":
|
||||
True,
|
||||
"supports_parallel": True,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -201,37 +213,31 @@ WEATHER_TOOL: ChatCompletionToolParam = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The city to find the weather for, "
|
||||
"e.g. 'San Francisco'"
|
||||
"type": "string",
|
||||
"description": "The city to find the weather for, "
|
||||
"e.g. 'San Francisco'",
|
||||
},
|
||||
"state": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"must the two-letter abbreviation for the state "
|
||||
"type": "string",
|
||||
"description": "must the two-letter abbreviation for the state "
|
||||
"that the city is in, e.g. 'CA' which would "
|
||||
"mean 'California'"
|
||||
"mean 'California'",
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"description": "The unit to fetch the temperature in",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
SEARCH_TOOL: ChatCompletionToolParam = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name":
|
||||
"web_search",
|
||||
"description":
|
||||
"Search the internet and get a summary of the top "
|
||||
"name": "web_search",
|
||||
"description": "Search the internet and get a summary of the top "
|
||||
"10 webpages. Should only be used if you don't know "
|
||||
"the answer to a user query, and the results are likely"
|
||||
"to be able to be found with a web search",
|
||||
@@ -239,124 +245,98 @@ SEARCH_TOOL: ChatCompletionToolParam = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"search_term": {
|
||||
"type":
|
||||
"string",
|
||||
"description":
|
||||
"The term to use in the search. This should"
|
||||
"type": "string",
|
||||
"description": "The term to use in the search. This should"
|
||||
"ideally be keywords to search for, not a"
|
||||
"natural-language question"
|
||||
"natural-language question",
|
||||
}
|
||||
},
|
||||
"required": ["search_term"]
|
||||
}
|
||||
}
|
||||
"required": ["search_term"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
MESSAGES_WITHOUT_TOOLS: list[ChatCompletionMessageParam] = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Hi! How are you?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content":
|
||||
"I'm doing great! How can I assist you?"
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Can you tell me a joke please?"
|
||||
}]
|
||||
MESSAGES_WITHOUT_TOOLS: list[ChatCompletionMessageParam] = [
|
||||
{"role": "user", "content": "Hi! How are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great! How can I assist you?"},
|
||||
{"role": "user", "content": "Can you tell me a joke please?"},
|
||||
]
|
||||
|
||||
MESSAGES_ASKING_FOR_TOOLS: list[ChatCompletionMessageParam] = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"What is the weather in Dallas, Texas in Fahrenheit?"
|
||||
}]
|
||||
MESSAGES_ASKING_FOR_TOOLS: list[ChatCompletionMessageParam] = [
|
||||
{"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"}
|
||||
]
|
||||
|
||||
MESSAGES_WITH_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"What is the weather in Dallas, Texas in Fahrenheit?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"tool_calls": [{
|
||||
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name":
|
||||
WEATHER_TOOL["function"]["name"],
|
||||
"arguments":
|
||||
'{"city": "Dallas", "state": "TX", '
|
||||
'"unit": "fahrenheit"}'
|
||||
}
|
||||
}]
|
||||
}, {
|
||||
"role":
|
||||
"tool",
|
||||
"tool_call_id":
|
||||
"chatcmpl-tool-03e6481b146e408e9523d9c956696295",
|
||||
"content":
|
||||
"The weather in Dallas is 98 degrees fahrenheit, with partly"
|
||||
"cloudy skies and a low chance of rain."
|
||||
}]
|
||||
MESSAGES_WITH_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [
|
||||
{"role": "user", "content": "What is the weather in Dallas, Texas in Fahrenheit?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": WEATHER_TOOL["function"]["name"],
|
||||
"arguments": '{"city": "Dallas", "state": "TX", '
|
||||
'"unit": "fahrenheit"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
|
||||
"content": "The weather in Dallas is 98 degrees fahrenheit, with partly"
|
||||
"cloudy skies and a low chance of rain.",
|
||||
},
|
||||
]
|
||||
|
||||
MESSAGES_ASKING_FOR_PARALLEL_TOOLS: list[ChatCompletionMessageParam] = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"What is the weather in Dallas, Texas and Orlando, Florida in "
|
||||
"Fahrenheit?"
|
||||
}]
|
||||
MESSAGES_ASKING_FOR_PARALLEL_TOOLS: list[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in Dallas, Texas and Orlando, Florida in "
|
||||
"Fahrenheit?",
|
||||
}
|
||||
]
|
||||
|
||||
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"What is the weather in Dallas, Texas and Orlando, Florida in "
|
||||
"Fahrenheit?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"tool_calls": [{
|
||||
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name":
|
||||
WEATHER_TOOL["function"]["name"],
|
||||
"arguments":
|
||||
'{"city": "Dallas", "state": "TX", '
|
||||
'"unit": "fahrenheit"}'
|
||||
}
|
||||
}, {
|
||||
"id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name":
|
||||
WEATHER_TOOL["function"]["name"],
|
||||
"arguments":
|
||||
'{"city": "Orlando", "state": "Fl", '
|
||||
'"unit": "fahrenheit"}'
|
||||
}
|
||||
}]
|
||||
}, {
|
||||
"role":
|
||||
"tool",
|
||||
"tool_call_id":
|
||||
"chatcmpl-tool-03e6481b146e408e9523d9c956696295",
|
||||
"content":
|
||||
"The weather in Dallas TX is 98 degrees fahrenheit with mostly "
|
||||
"cloudy skies and a chance of rain in the evening."
|
||||
}, {
|
||||
"role":
|
||||
"tool",
|
||||
"tool_call_id":
|
||||
"chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
|
||||
"content":
|
||||
"The weather in Orlando FL is 78 degrees fahrenheit with clear"
|
||||
"skies."
|
||||
}]
|
||||
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: list[ChatCompletionMessageParam] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in Dallas, Texas and Orlando, Florida in "
|
||||
"Fahrenheit?",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": WEATHER_TOOL["function"]["name"],
|
||||
"arguments": '{"city": "Dallas", "state": "TX", '
|
||||
'"unit": "fahrenheit"}',
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": WEATHER_TOOL["function"]["name"],
|
||||
"arguments": '{"city": "Orlando", "state": "Fl", '
|
||||
'"unit": "fahrenheit"}',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295",
|
||||
"content": "The weather in Dallas TX is 98 degrees fahrenheit with mostly "
|
||||
"cloudy skies and a chance of rain in the evening.",
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b",
|
||||
"content": "The weather in Orlando FL is 78 degrees fahrenheit with clear"
|
||||
"skies.",
|
||||
},
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user