Fix edge case Mistral tool parser (#30724)
Signed-off-by: Joachim Studnia <joachim@mistral.ai> Signed-off-by: Joachim Studnia <studniajoachim@gmail.com> Signed-off-by: juliendenize <julien.denize@mistral.ai> Signed-off-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: juliendenize <julien.denize@mistral.ai> Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -281,6 +281,8 @@ def test_extract_tool_calls_pre_v11_tokenizer(
|
||||
"single_tool_add",
|
||||
"single_tool_weather",
|
||||
"multiple_tool_calls",
|
||||
"complex",
|
||||
"wrong_json",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
@@ -326,6 +328,36 @@ def test_extract_tool_calls_pre_v11_tokenizer(
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
# Complex
|
||||
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="bash",
|
||||
arguments=json.dumps(
|
||||
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
|
||||
)[:-2],
|
||||
)
|
||||
)
|
||||
],
|
||||
"hi{hi",
|
||||
),
|
||||
(
|
||||
# Wrong json
|
||||
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="bash",
|
||||
arguments=json.dumps(
|
||||
{"command": "print(\"hello world!\")\nre.compile(r'{}')"}
|
||||
),
|
||||
)
|
||||
)
|
||||
],
|
||||
"hi{hi",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(
|
||||
@@ -673,7 +705,7 @@ def test_extract_tool_calls_streaming(
|
||||
),
|
||||
(
|
||||
# Complex
|
||||
"""[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
|
||||
"""hi{hi[TOOL_CALLS]bash{"command": "print(\\"hello world!\\")\\nre.compile(r\'{}\')"}""", # noqa: E501
|
||||
[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
@@ -684,7 +716,7 @@ def test_extract_tool_calls_streaming(
|
||||
)
|
||||
)
|
||||
],
|
||||
"",
|
||||
"hi{hi",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user