[Misc] Add 20 regression tests for 11 tool parser bug fixes (#38172)
Signed-off-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -1431,3 +1431,140 @@ rectangle
|
||||
assert "<function=calculate_area>" not in extracted_tool_calls.content, (
|
||||
"Second tool call should not be in content"
|
||||
)
|
||||
|
||||
|
||||
def _accumulate_tool_states(delta_messages):
    """Accumulate tool call state from a stream of DeltaMessage objects.

    Args:
        delta_messages: Iterable of streamed deltas. Each delta is read for
            ``content`` and ``tool_calls``; each tool call is read for
            ``index``, ``id``, ``type`` and ``function`` (``name`` and
            ``arguments``). ``None`` entries are skipped.

    Returns:
        A ``(content, tool_states)`` tuple. ``content`` is the concatenation
        of every non-empty ``content`` fragment in stream order.
        ``tool_states`` maps each tool call index to a dict with the keys
        ``"id"``, ``"name"``, ``"arguments"`` and ``"type"``: the last truthy
        id/name/type seen for that index, and all argument fragments
        concatenated in stream order.
    """
    content_parts = []
    tool_states = {}
    for delta_message in delta_messages:
        # NOTE(review): a parser may presumably signal "nothing to emit" with
        # None; tolerate that instead of failing with AttributeError.
        if delta_message is None:
            continue
        if delta_message.content:
            content_parts.append(delta_message.content)
        for tool_call in delta_message.tool_calls or ():
            state = tool_states.setdefault(
                tool_call.index,
                {"id": None, "name": None, "arguments": "", "type": None},
            )
            if tool_call.id:
                state["id"] = tool_call.id
            if tool_call.type:
                state["type"] = tool_call.type
            function = tool_call.function
            if not function:
                continue
            if function.name:
                state["name"] = function.name
            # An empty string is a valid (no-op) fragment; only None means
            # "no arguments in this delta".
            if function.arguments is not None:
                state["arguments"] += function.arguments
    return "".join(content_parts), tool_states
|
||||
|
||||
|
||||
def test_streaming_mtp_variable_chunks(
    step3p5_tool_parser, step3p5_tokenizer, sample_tools
):
    """Regression: MTP variable-size chunks spanning param boundaries (PR #33690).

    Streams one ``get_current_weather`` tool call as four text chunks whose
    boundaries fall inside parameter values and between closing/opening
    parameter tags, then checks that the accumulated stream yields exactly
    one complete tool call with all three arguments.
    """
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)

    # Chunk boundaries are deliberately misaligned with the tag structure:
    # the second chunk ends mid-parameter ("TX" without its closing tag).
    delta_text_chunks = [
        "<tool_call>\n<function=get_current_weather>\n<parameter=city>\n",
        "Dallas\n</parameter>\n<parameter=state>\nTX",
        "\n</parameter>\n<parameter=unit>\nfahrenheit\n</parameter>",
        "\n</function>\n</tool_call>",
    ]

    _, tool_states = _accumulate_tool_states(
        stream_delta_message_generator_from_chunks(
            step3p5_tool_parser, step3p5_tokenizer, delta_text_chunks, request
        )
    )

    assert len(tool_states) == 1

    state = tool_states[0]
    assert state["id"] is not None
    assert state["type"] == "function"
    assert state["name"] == "get_current_weather"

    # The concatenated argument fragments must form valid JSON.
    args = json.loads(state["arguments"])
    assert args["city"] == "Dallas"
    assert args["state"] == "TX"
    assert args["unit"] == "fahrenheit"
|
||||
|
||||
|
||||
def test_streaming_multi_token_per_step(
    step3p5_tool_parser, step3p5_tokenizer, sample_tools
):
    """Regression: MTP large chunks spanning multiple tool calls (PR #33690).

    Feeds the same two-tool-call output through the parser twice: once as two
    large chunks (the second chunk finishes the first call and contains the
    entire second call) and once through ``stream_delta_message_generator``
    as the reference. Both runs must produce two tool calls with matching
    names and arguments.
    """
    model_output = """<tool_call>
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>
<tool_call>
<function=get_current_weather>
<parameter=city>
Orlando
</parameter>
<parameter=state>
FL
</parameter>
<parameter=unit>
celsius
</parameter>
</function>
</tool_call>"""

    request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)

    # MTP-style large chunks
    mtp_chunks = [
        (
            "<tool_call>\n<function=get_current_weather>\n"
            "<parameter=city>\nDallas\n</parameter>\n"
            "<parameter=state>\nTX"
        ),
        (
            "\n</parameter>\n<parameter=unit>\nfahrenheit\n</parameter>\n"
            "</function>\n</tool_call>\n"
            "<tool_call>\n<function=get_current_weather>\n"
            "<parameter=city>\nOrlando\n</parameter>\n"
            "<parameter=state>\nFL\n</parameter>\n"
            "<parameter=unit>\ncelsius\n</parameter>\n"
            "</function>\n</tool_call>"
        ),
    ]

    # Guard the fixture itself: the comparison below is only meaningful if
    # both streams carry exactly the same text.
    assert "".join(mtp_chunks) == model_output, (
        "MTP chunks must concatenate to the reference model output"
    )

    _, mtp_tool_states = _accumulate_tool_states(
        stream_delta_message_generator_from_chunks(
            step3p5_tool_parser, step3p5_tokenizer, mtp_chunks, request
        )
    )

    # Token-by-token streaming (reference)
    # NOTE(review): a fresh parser is built here, presumably because the
    # fixture instance carries streaming state from the MTP run above.
    step3p5_tool_parser_ref = Step3p5ToolParser(step3p5_tokenizer)
    _, ref_tool_states = _accumulate_tool_states(
        stream_delta_message_generator(
            step3p5_tool_parser_ref, step3p5_tokenizer, model_output, request
        )
    )

    assert len(mtp_tool_states) == 2
    assert len(ref_tool_states) == 2

    # MTP results must match reference
    for idx in range(2):
        assert mtp_tool_states[idx]["name"] == ref_tool_states[idx]["name"]
        mtp_args = json.loads(mtp_tool_states[idx]["arguments"])
        ref_args = json.loads(ref_tool_states[idx]["arguments"])
        assert mtp_args == ref_args
|
||||
|
||||
Reference in New Issue
Block a user