[Misc] Add 20 regression tests for 11 tool parser bug fixes (#38172)
Signed-off-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.utils import run_tool_extraction_streaming
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionToolsParam,
|
||||
FunctionDefinition,
|
||||
@@ -26,6 +27,7 @@ from vllm.tool_parsers.deepseekv32_tool_parser import DeepSeekV32ToolParser
|
||||
# tokenizer object to be truthy (the parser checks `if not self.model_tokenizer`).
|
||||
MOCK_TOKENIZER = MagicMock()
|
||||
MOCK_TOKENIZER.get_vocab.return_value = {}
|
||||
MOCK_TOKENIZER.tokenize.return_value = []
|
||||
|
||||
|
||||
def make_parser(tools=None) -> DeepSeekV32ToolParser:
|
||||
@@ -483,6 +485,88 @@ class TestExtractToolCallsStreaming:
|
||||
assert all(not d.tool_calls for d in deltas)
|
||||
|
||||
|
||||
class TestDelimiterPreservation:
|
||||
"""Regression: fast detokenization skipping DSML delimiters (PR #33964)."""
|
||||
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
return make_parser()
|
||||
|
||||
def test_delimiter_preserved_fast_detokenization(self, parser):
|
||||
"""DSML delimiters as literal text must still be detected."""
|
||||
# Delimiters appear as regular text (fast detokenization scenario).
|
||||
model_output = (
|
||||
f"{FC_START}\n"
|
||||
f'{INV_START}get_weather">\n'
|
||||
f'{PARAM_START}location" string="true">Tokyo{PARAM_END}\n'
|
||||
f"{INV_END}\n"
|
||||
f"{FC_END}"
|
||||
)
|
||||
|
||||
# Non-streaming: parser must detect the tool call
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
assert result.tools_called
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "get_weather"
|
||||
assert json.loads(result.tool_calls[0].function.arguments) == {
|
||||
"location": "Tokyo"
|
||||
}
|
||||
|
||||
assert result.content is None
|
||||
|
||||
# With content prefix
|
||||
prefixed_output = "Here is the weather: " + model_output
|
||||
result2 = parser.extract_tool_calls(prefixed_output, None)
|
||||
assert result2.tools_called
|
||||
assert result2.content == "Here is the weather: "
|
||||
|
||||
def test_tool_detection_skip_special_tokens_false(self, parser):
|
||||
"""Regression: skip_special_tokens must be False when tools are enabled."""
|
||||
# adjust_request must set skip_special_tokens=False
|
||||
tool = make_tool_param(
|
||||
"search",
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
},
|
||||
},
|
||||
)
|
||||
request = make_request(tools=[tool])
|
||||
request.tool_choice = "auto"
|
||||
adjusted = parser.adjust_request(request)
|
||||
assert adjusted.skip_special_tokens is False
|
||||
|
||||
full_text = build_tool_call("search", {"query": "vllm documentation"})
|
||||
|
||||
# Non-streaming extraction
|
||||
non_stream_result = parser.extract_tool_calls(full_text, request)
|
||||
assert non_stream_result.tools_called
|
||||
assert len(non_stream_result.tool_calls) == 1
|
||||
assert non_stream_result.tool_calls[0].function.name == "search"
|
||||
ns_args = json.loads(non_stream_result.tool_calls[0].function.arguments)
|
||||
assert ns_args == {"query": "vllm documentation"}
|
||||
|
||||
# Streaming extraction: drive the parser line-by-line
|
||||
chunks: list[str] = []
|
||||
remaining = full_text
|
||||
while remaining:
|
||||
nl = remaining.find("\n")
|
||||
if nl == -1:
|
||||
chunks.append(remaining)
|
||||
break
|
||||
chunks.append(remaining[: nl + 1])
|
||||
remaining = remaining[nl + 1 :]
|
||||
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
parser, chunks, request, assert_one_tool_per_delta=False
|
||||
)
|
||||
assert len(reconstructor.tool_calls) == 1
|
||||
assert reconstructor.tool_calls[0].function.name == "search"
|
||||
streamed_args = json.loads(reconstructor.tool_calls[0].function.arguments)
|
||||
assert streamed_args == ns_args
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def deepseekv32_tokenizer():
|
||||
return get_tokenizer(tokenizer_name="deepseek-ai/DeepSeek-V3.2")
|
||||
|
||||
@@ -822,3 +822,108 @@ def test_extract_tool_calls_numeric_deserialization(glm4_moe_tool_parser, mock_r
|
||||
# Boolean should be deserialized as bool
|
||||
assert args["enabled"] is True
|
||||
assert isinstance(args["enabled"], bool)
|
||||
|
||||
|
||||
def test_zero_argument_tool_call(glm4_moe_tool_parser, mock_request):
|
||||
"""Regression: zero-argument tool call crash (PR #32321)."""
|
||||
model_output = """<tool_call>get_time
|
||||
</tool_call>"""
|
||||
|
||||
extracted = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=mock_request
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert extracted.tools_called
|
||||
assert len(extracted.tool_calls) == 1
|
||||
assert extracted.tool_calls[0].function.name == "get_time"
|
||||
args = json.loads(extracted.tool_calls[0].function.arguments)
|
||||
assert args == {}
|
||||
|
||||
|
||||
def test_malformed_tool_call_no_regex_match(glm4_moe_tool_parser, mock_request):
|
||||
"""Regression: malformed tool_call with no regex match (PR #32321)."""
|
||||
model_output = "<tool_call> </tool_call>"
|
||||
|
||||
extracted = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=mock_request
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert extracted.tools_called is False
|
||||
assert extracted.tool_calls == []
|
||||
|
||||
|
||||
def test_delimiter_preserved_transformers_5x(glm4_moe_tool_parser):
|
||||
"""Regression: adjust_request sets skip_special_tokens=False (PR #31622)."""
|
||||
# Tools enabled
|
||||
request_with_tools = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
) # type: ignore
|
||||
adjusted = glm4_moe_tool_parser.adjust_request(request_with_tools)
|
||||
assert adjusted.skip_special_tokens is False
|
||||
|
||||
# tool_choice="none"
|
||||
request_no_choice = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[],
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"city": {"type": "string"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
tool_choice="none",
|
||||
) # type: ignore
|
||||
adjusted_none = glm4_moe_tool_parser.adjust_request(request_no_choice)
|
||||
assert adjusted_none.skip_special_tokens is True
|
||||
|
||||
# No tools at all
|
||||
request_no_tools = ChatCompletionRequest(
|
||||
model=MODEL,
|
||||
messages=[],
|
||||
) # type: ignore
|
||||
adjusted_empty = glm4_moe_tool_parser.adjust_request(request_no_tools)
|
||||
assert adjusted_empty.skip_special_tokens is True
|
||||
|
||||
|
||||
def test_unicode_characters_preserved(glm4_moe_tool_parser, mock_request):
|
||||
"""Regression: Unicode chars must not be escaped to \\uXXXX (PR #30920)."""
|
||||
model_output = """<tool_call>send_message
|
||||
<arg_key>greeting</arg_key>
|
||||
<arg_value>你好世界</arg_value>
|
||||
<arg_key>emoji</arg_key>
|
||||
<arg_value>🎉</arg_value>
|
||||
</tool_call>"""
|
||||
|
||||
extracted = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=mock_request
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
assert extracted.tools_called
|
||||
assert len(extracted.tool_calls) == 1
|
||||
|
||||
raw_args = extracted.tool_calls[0].function.arguments
|
||||
assert "你好世界" in raw_args
|
||||
assert "🎉" in raw_args
|
||||
assert "\\u4f60" not in raw_args
|
||||
parsed_args = json.loads(raw_args)
|
||||
assert parsed_args["greeting"] == "你好世界"
|
||||
assert parsed_args["emoji"] == "🎉"
|
||||
|
||||
@@ -872,6 +872,59 @@ def test_streaming_tool_call_markers_not_leaked(kimi_k2_tool_parser):
|
||||
assert "I'll check the weather." in full_content or len(all_content) > 0
|
||||
|
||||
|
||||
def test_native_id_extracted_and_placed_on_tool_call(kimi_k2_tool_parser):
|
||||
"""Regression: parser extracts native ID onto ToolCall (PR #32768)."""
|
||||
model_output = (
|
||||
"Checking weather. "
|
||||
"<|tool_calls_section_begin|>"
|
||||
"<|tool_call_begin|>functions.get_weather:0"
|
||||
'<|tool_call_argument_begin|>{"city": "Tokyo"}'
|
||||
"<|tool_call_end|>"
|
||||
"<|tool_calls_section_end|>"
|
||||
)
|
||||
|
||||
result = kimi_k2_tool_parser.extract_tool_calls(model_output, request=None)
|
||||
assert result.tools_called
|
||||
assert len(result.tool_calls) == 1
|
||||
|
||||
tc = result.tool_calls[0]
|
||||
# Native ID from model output must be used as the tool call ID
|
||||
assert tc.id == "functions.get_weather:0"
|
||||
assert tc.function.name == "get_weather"
|
||||
assert json.loads(tc.function.arguments) == {"city": "Tokyo"}
|
||||
|
||||
|
||||
def test_multi_turn_native_id_continuity(kimi_k2_tool_parser, kimi_k2_tokenizer):
|
||||
"""Regression: native IDs from turn 1 preserved across turns (PR #32768)."""
|
||||
turn1_output = (
|
||||
"Let me check. "
|
||||
"<|tool_calls_section_begin|>"
|
||||
"<|tool_call_begin|>functions.get_weather:0"
|
||||
'<|tool_call_argument_begin|>{"city": "Beijing"}'
|
||||
"<|tool_call_end|>"
|
||||
"<|tool_calls_section_end|>"
|
||||
)
|
||||
|
||||
turn1_result = kimi_k2_tool_parser.extract_tool_calls(turn1_output, request=None)
|
||||
assert turn1_result.tools_called
|
||||
assert turn1_result.tool_calls[0].id == "functions.get_weather:0"
|
||||
|
||||
# Fresh parser for turn 2
|
||||
turn2_parser = KimiK2ToolParser(kimi_k2_tokenizer)
|
||||
turn2_output = (
|
||||
"Now let me get news. "
|
||||
"<|tool_calls_section_begin|>"
|
||||
"<|tool_call_begin|>functions.get_news:0"
|
||||
'<|tool_call_argument_begin|>{"topic": "weather in Beijing"}'
|
||||
"<|tool_call_end|>"
|
||||
"<|tool_calls_section_end|>"
|
||||
)
|
||||
|
||||
turn2_result = turn2_parser.extract_tool_calls(turn2_output, request=None)
|
||||
assert turn2_result.tools_called
|
||||
assert turn2_result.tool_calls[0].id == "functions.get_news:0"
|
||||
|
||||
|
||||
def test_streaming_multiple_tool_calls_not_leaked(kimi_k2_tool_parser):
|
||||
"""
|
||||
Test that MULTIPLE tool calls in streaming mode do not leak into content.
|
||||
|
||||
@@ -5,6 +5,10 @@ import json
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionToolsParam,
|
||||
FunctionDefinition,
|
||||
)
|
||||
from vllm.tool_parsers.minimax_m2_tool_parser import (
|
||||
MinimaxM2ToolParser,
|
||||
)
|
||||
@@ -442,3 +446,105 @@ class TestLargeChunks:
|
||||
"city": "Seattle",
|
||||
"days": "5",
|
||||
}
|
||||
|
||||
|
||||
class TestAnyOfNullableParam:
|
||||
"""Regression: anyOf nullable parameter parsing (PR #32342)."""
|
||||
|
||||
def test_anyof_nullable_param_non_null_value(self):
|
||||
"""A valid non-null string should be preserved, not collapsed to None."""
|
||||
tools = [
|
||||
ChatCompletionToolsParam(
|
||||
function=FunctionDefinition(
|
||||
name="update_profile",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nickname": {
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
]
|
||||
parser = MinimaxM2ToolParser(FakeTokenizer(), tools=tools)
|
||||
|
||||
results = _feed(
|
||||
parser,
|
||||
[
|
||||
'<minimax:tool_call><invoke name="update_profile">'
|
||||
'<parameter name="nickname">Alice</parameter>'
|
||||
"</invoke></minimax:tool_call>",
|
||||
],
|
||||
)
|
||||
tc = _collect_tool_calls(results)
|
||||
assert len(tc) == 1
|
||||
parsed = json.loads(tc[0]["arguments"])
|
||||
assert parsed["nickname"] == "Alice"
|
||||
|
||||
def test_anyof_nullable_param_null_value(self):
|
||||
"""An actual null-like value should be returned as None/null."""
|
||||
tools = [
|
||||
ChatCompletionToolsParam(
|
||||
function=FunctionDefinition(
|
||||
name="update_profile",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nickname": {
|
||||
"anyOf": [{"type": "string"}, {"type": "null"}],
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
]
|
||||
parser = MinimaxM2ToolParser(FakeTokenizer(), tools=tools)
|
||||
|
||||
results = _feed(
|
||||
parser,
|
||||
[
|
||||
'<minimax:tool_call><invoke name="update_profile">'
|
||||
'<parameter name="nickname">null</parameter>'
|
||||
"</invoke></minimax:tool_call>",
|
||||
],
|
||||
)
|
||||
tc = _collect_tool_calls(results)
|
||||
assert len(tc) == 1
|
||||
parsed = json.loads(tc[0]["arguments"])
|
||||
assert parsed["nickname"] is None
|
||||
|
||||
def test_anyof_nullable_param_object_value(self):
|
||||
"""A valid object value in anyOf with null should parse as dict."""
|
||||
tools = [
|
||||
ChatCompletionToolsParam(
|
||||
function=FunctionDefinition(
|
||||
name="update_settings",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"anyOf": [{"type": "object"}, {"type": "null"}],
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
]
|
||||
parser = MinimaxM2ToolParser(FakeTokenizer(), tools=tools)
|
||||
|
||||
results = _feed(
|
||||
parser,
|
||||
[
|
||||
'<minimax:tool_call><invoke name="update_settings">'
|
||||
'<parameter name="config">{"theme": "dark", "fontSize": 14}'
|
||||
"</parameter>"
|
||||
"</invoke></minimax:tool_call>",
|
||||
],
|
||||
)
|
||||
tc = _collect_tool_calls(results)
|
||||
assert len(tc) == 1
|
||||
parsed = json.loads(tc[0]["arguments"])
|
||||
assert parsed["config"] == {"theme": "dark", "fontSize": 14}
|
||||
assert isinstance(parsed["config"], dict)
|
||||
|
||||
@@ -890,3 +890,64 @@ def test_extract_tool_calls_streaming_pre_v11_tokenizer_one_chunk(
|
||||
assert expected_content == ""
|
||||
else:
|
||||
assert delta_message.content == expected_content
|
||||
|
||||
|
||||
def test_fast_detokenization_text_detection(mistral_tool_parser):
|
||||
"""Regression: bot_token in text but not token_ids (PR #37209)."""
|
||||
model_output = '[TOOL_CALLS]add{"a": 1, "b": 2}'
|
||||
# Token IDs that do NOT contain bot_token_id.
|
||||
fake_token_ids = list(range(99, 99 + 20))
|
||||
|
||||
# First delta: pure content, no bot token yet
|
||||
delta_message_before = mistral_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="Hello",
|
||||
delta_text="Hello",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[99],
|
||||
delta_token_ids=[99],
|
||||
request=None,
|
||||
)
|
||||
assert delta_message_before is not None
|
||||
assert delta_message_before.content == "Hello"
|
||||
assert not delta_message_before.tool_calls
|
||||
|
||||
# Second delta: bot token in text but NOT in token_ids
|
||||
delta_message = mistral_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="Hello",
|
||||
current_text="Hello" + model_output,
|
||||
delta_text=model_output,
|
||||
previous_token_ids=[99],
|
||||
current_token_ids=fake_token_ids,
|
||||
delta_token_ids=fake_token_ids[1:],
|
||||
request=None,
|
||||
)
|
||||
assert delta_message is not None
|
||||
assert delta_message.tool_calls is not None
|
||||
assert len(delta_message.tool_calls) > 0
|
||||
assert delta_message.tool_calls[0].function is not None
|
||||
assert delta_message.tool_calls[0].function.name == "add"
|
||||
|
||||
|
||||
def test_fast_detokenization_text_detection_pre_v11(
|
||||
mistral_pre_v11_tool_parser,
|
||||
):
|
||||
"""Regression: bot_token text detection for pre-v11 tokenizer (PR #37209)."""
|
||||
model_output = '[TOOL_CALLS] [{"name": "add", "arguments":{"a": 1, "b": 2}}]'
|
||||
|
||||
fake_token_ids = list(range(99, 99 + 30))
|
||||
|
||||
delta_message = mistral_pre_v11_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text=model_output,
|
||||
delta_text=model_output,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=fake_token_ids,
|
||||
delta_token_ids=fake_token_ids,
|
||||
request=None,
|
||||
)
|
||||
assert delta_message is not None
|
||||
assert delta_message.tool_calls is not None
|
||||
assert len(delta_message.tool_calls) > 0
|
||||
assert delta_message.tool_calls[0].function is not None
|
||||
assert delta_message.tool_calls[0].function.name == "add"
|
||||
|
||||
@@ -974,3 +974,157 @@ fahrenheit
|
||||
assert args["city"] == "Dallas"
|
||||
assert args["state"] == "TX"
|
||||
assert args["unit"] == "fahrenheit"
|
||||
|
||||
|
||||
def test_malformed_xml_no_gt_delimiter(qwen3_tool_parser, sample_tools):
|
||||
"""Regression: malformed XML without '>' must not crash (PR #36774)."""
|
||||
model_output = (
|
||||
"<tool_call>\n"
|
||||
"<function=get_current_weather\n"
|
||||
"<parameter=city>Dallas</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>"
|
||||
)
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
result = qwen3_tool_parser.extract_tool_calls(model_output, request=request)
|
||||
assert result is not None
|
||||
assert isinstance(result.tool_calls, list)
|
||||
assert all(tc is not None for tc in result.tool_calls)
|
||||
|
||||
|
||||
def test_none_tool_calls_filtered(qwen3_tool_parser, sample_tools):
|
||||
"""Regression: None tool calls filtered from output (PR #36774)."""
|
||||
model_output = (
|
||||
"<tool_call>\n"
|
||||
"<function=bad_func_no_gt\n"
|
||||
"</function>\n"
|
||||
"</tool_call>\n"
|
||||
"<tool_call>\n"
|
||||
"<function=get_current_weather>\n"
|
||||
"<parameter=city>Dallas</parameter>\n"
|
||||
"<parameter=state>TX</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>"
|
||||
)
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
result = qwen3_tool_parser.extract_tool_calls(model_output, request=request)
|
||||
assert all(tc is not None for tc in result.tool_calls)
|
||||
assert result.tools_called
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "get_current_weather"
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args["city"] == "Dallas"
|
||||
assert args["state"] == "TX"
|
||||
|
||||
|
||||
def test_anyof_parameter_not_double_encoded(qwen3_tokenizer):
|
||||
"""Regression: anyOf parameters must not be double-encoded (PR #36032)."""
|
||||
tools = [
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "update_record",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"anyOf": [{"type": "object"}, {"type": "null"}],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
parser = Qwen3CoderToolParser(qwen3_tokenizer, tools=tools)
|
||||
|
||||
model_output = (
|
||||
"<tool_call>\n"
|
||||
"<function=update_record>\n"
|
||||
'<parameter=data>{"key": "value", "count": 42}</parameter>\n'
|
||||
"</function>\n"
|
||||
"</tool_call>"
|
||||
)
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
|
||||
result = parser.extract_tool_calls(model_output, request=request)
|
||||
|
||||
assert result.tools_called
|
||||
assert len(result.tool_calls) == 1
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert isinstance(args["data"], dict)
|
||||
assert args["data"] == {"key": "value", "count": 42}
|
||||
|
||||
|
||||
def test_streaming_multi_param_single_chunk(
|
||||
qwen3_tool_parser, qwen3_tokenizer, sample_tools
|
||||
):
|
||||
"""Regression: speculative decode delivering multiple params at once (PR #35615)."""
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
|
||||
deltas = [
|
||||
"<tool_call>",
|
||||
"\n<function=get_current_weather>",
|
||||
"\n", # triggers json_started -> sends "{"
|
||||
# This single delta delivers all three parameters at once
|
||||
"<parameter=city>\nDallas\n</parameter>"
|
||||
"\n<parameter=state>\nTX\n</parameter>"
|
||||
"\n<parameter=unit>\nfahrenheit\n</parameter>",
|
||||
"\n</function>",
|
||||
"\n</tool_call>",
|
||||
]
|
||||
|
||||
from tests.tool_parsers.utils import (
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
qwen3_tool_parser,
|
||||
deltas,
|
||||
request,
|
||||
assert_one_tool_per_delta=False,
|
||||
)
|
||||
|
||||
assert len(reconstructor.tool_calls) == 1
|
||||
args = json.loads(reconstructor.tool_calls[0].function.arguments)
|
||||
assert args["city"] == "Dallas"
|
||||
assert args["state"] == "TX"
|
||||
assert args["unit"] == "fahrenheit"
|
||||
|
||||
|
||||
def test_no_double_serialization_string_args(qwen3_tool_parser):
|
||||
"""Regression: string arguments must not be double-serialized (PR #35615)."""
|
||||
tools = [
|
||||
ChatCompletionToolsParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "greet",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
model_output = (
|
||||
"<tool_call>\n"
|
||||
"<function=greet>\n"
|
||||
"<parameter=message>hello world</parameter>\n"
|
||||
"</function>\n"
|
||||
"</tool_call>"
|
||||
)
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools)
|
||||
result = qwen3_tool_parser.extract_tool_calls(model_output, request=request)
|
||||
|
||||
assert result.tools_called
|
||||
assert len(result.tool_calls) == 1
|
||||
raw_arguments = result.tool_calls[0].function.arguments
|
||||
args = json.loads(raw_arguments)
|
||||
assert args["message"] == "hello world"
|
||||
assert '\\"hello world\\"' not in raw_arguments
|
||||
|
||||
@@ -1431,3 +1431,140 @@ rectangle
|
||||
assert "<function=calculate_area>" not in extracted_tool_calls.content, (
|
||||
"Second tool call should not be in content"
|
||||
)
|
||||
|
||||
|
||||
def _accumulate_tool_states(delta_messages):
|
||||
"""Accumulate tool call state from a stream of DeltaMessage objects."""
|
||||
content = ""
|
||||
tool_states = {}
|
||||
for delta_message in delta_messages:
|
||||
if delta_message.content:
|
||||
content += delta_message.content
|
||||
if delta_message.tool_calls:
|
||||
for tool_call in delta_message.tool_calls:
|
||||
idx = tool_call.index
|
||||
if idx not in tool_states:
|
||||
tool_states[idx] = {
|
||||
"id": None,
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
"type": None,
|
||||
}
|
||||
if tool_call.id:
|
||||
tool_states[idx]["id"] = tool_call.id
|
||||
if tool_call.type:
|
||||
tool_states[idx]["type"] = tool_call.type
|
||||
if tool_call.function:
|
||||
if tool_call.function.name:
|
||||
tool_states[idx]["name"] = tool_call.function.name
|
||||
if tool_call.function.arguments is not None:
|
||||
tool_states[idx]["arguments"] += tool_call.function.arguments
|
||||
return content, tool_states
|
||||
|
||||
|
||||
def test_streaming_mtp_variable_chunks(
|
||||
step3p5_tool_parser, step3p5_tokenizer, sample_tools
|
||||
):
|
||||
"""Regression: MTP variable-size chunks spanning param boundaries (PR #33690)."""
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
|
||||
delta_text_chunks = [
|
||||
"<tool_call>\n<function=get_current_weather>\n<parameter=city>\n",
|
||||
"Dallas\n</parameter>\n<parameter=state>\nTX",
|
||||
"\n</parameter>\n<parameter=unit>\nfahrenheit\n</parameter>",
|
||||
"\n</function>\n</tool_call>",
|
||||
]
|
||||
|
||||
_, tool_states = _accumulate_tool_states(
|
||||
stream_delta_message_generator_from_chunks(
|
||||
step3p5_tool_parser, step3p5_tokenizer, delta_text_chunks, request
|
||||
)
|
||||
)
|
||||
|
||||
assert len(tool_states) == 1
|
||||
|
||||
state = tool_states[0]
|
||||
assert state["id"] is not None
|
||||
assert state["type"] == "function"
|
||||
assert state["name"] == "get_current_weather"
|
||||
|
||||
args = json.loads(state["arguments"])
|
||||
assert args["city"] == "Dallas"
|
||||
assert args["state"] == "TX"
|
||||
assert args["unit"] == "fahrenheit"
|
||||
|
||||
|
||||
def test_streaming_multi_token_per_step(
|
||||
step3p5_tool_parser, step3p5_tokenizer, sample_tools
|
||||
):
|
||||
"""Regression: MTP large chunks spanning multiple tool calls (PR #33690)."""
|
||||
model_output = """<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Dallas
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
TX
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
fahrenheit
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
<function=get_current_weather>
|
||||
<parameter=city>
|
||||
Orlando
|
||||
</parameter>
|
||||
<parameter=state>
|
||||
FL
|
||||
</parameter>
|
||||
<parameter=unit>
|
||||
celsius
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call>"""
|
||||
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||
|
||||
# MTP-style large chunks
|
||||
mtp_chunks = [
|
||||
(
|
||||
"<tool_call>\n<function=get_current_weather>\n"
|
||||
"<parameter=city>\nDallas\n</parameter>\n"
|
||||
"<parameter=state>\nTX"
|
||||
),
|
||||
(
|
||||
"\n</parameter>\n<parameter=unit>\nfahrenheit\n</parameter>\n"
|
||||
"</function>\n</tool_call>\n"
|
||||
"<tool_call>\n<function=get_current_weather>\n"
|
||||
"<parameter=city>\nOrlando\n</parameter>\n"
|
||||
"<parameter=state>\nFL\n</parameter>\n"
|
||||
"<parameter=unit>\ncelsius\n</parameter>\n"
|
||||
"</function>\n</tool_call>"
|
||||
),
|
||||
]
|
||||
|
||||
_, mtp_tool_states = _accumulate_tool_states(
|
||||
stream_delta_message_generator_from_chunks(
|
||||
step3p5_tool_parser, step3p5_tokenizer, mtp_chunks, request
|
||||
)
|
||||
)
|
||||
|
||||
# Token-by-token streaming (reference)
|
||||
step3p5_tool_parser_ref = Step3p5ToolParser(step3p5_tokenizer)
|
||||
_, ref_tool_states = _accumulate_tool_states(
|
||||
stream_delta_message_generator(
|
||||
step3p5_tool_parser_ref, step3p5_tokenizer, model_output, request
|
||||
)
|
||||
)
|
||||
|
||||
assert len(mtp_tool_states) == 2
|
||||
assert len(ref_tool_states) == 2
|
||||
|
||||
# MTP results must match reference
|
||||
for idx in range(2):
|
||||
assert mtp_tool_states[idx]["name"] == ref_tool_states[idx]["name"]
|
||||
mtp_args = json.loads(mtp_tool_states[idx]["arguments"])
|
||||
ref_args = json.loads(ref_tool_states[idx]["arguments"])
|
||||
assert mtp_args == ref_args
|
||||
|
||||
Reference in New Issue
Block a user