[BUGFIX] Fix regex pattern for Mistral Tool Call (#29918)
Signed-off-by: juliendenize <julien.denize@mistral.ai>
(cherry picked from commit 1b1e35aaf9)
This commit is contained in:
committed by
Kevin H. Luu
parent
9057fc2f1b
commit
6a6108511f
@@ -315,3 +315,38 @@ def test_mistral_function_call_nested_json():
|
|||||||
assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict
|
assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict
|
||||||
# No additional content outside the tool call should be returned.
|
# No additional content outside the tool call should be returned.
|
||||||
assert parsed.content is None
|
assert parsed.content is None
|
||||||
|
|
||||||
|
# multiple calls
|
||||||
|
multiple_args_dict = [
|
||||||
|
{
|
||||||
|
"city": "Dallas",
|
||||||
|
"state": "TX",
|
||||||
|
"unit": "fahrenheit",
|
||||||
|
"sub_dict": {"foo": "bar", "inner": {"x": 1, "y": 2}},
|
||||||
|
},
|
||||||
|
{},
|
||||||
|
{"a": 0},
|
||||||
|
{"a": 1, "b": "c"},
|
||||||
|
]
|
||||||
|
names = ["get_current_weather", "get_current_weather_2", "random", "random_2"]
|
||||||
|
|
||||||
|
model_output = "".join(
|
||||||
|
[
|
||||||
|
f"{parser.bot_token}{name}{json.dumps(args)}"
|
||||||
|
for name, args in zip(names, multiple_args_dict)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
parsed = parser.extract_tool_calls(model_output, None)
|
||||||
|
|
||||||
|
# Assertions: every tool call is detected and its full nested JSON is parsed
|
||||||
|
# without truncation.
|
||||||
|
assert parsed.tools_called
|
||||||
|
assert len(parsed.tool_calls) == len(multiple_args_dict)
|
||||||
|
|
||||||
|
for i, tool_call in enumerate(parsed.tool_calls):
|
||||||
|
assert MistralToolCall.is_valid_id(tool_call.id)
|
||||||
|
assert tool_call.function.name == names[i]
|
||||||
|
assert json.loads(tool_call.function.arguments) == multiple_args_dict[i]
|
||||||
|
# No additional content outside the tool calls should be returned.
|
||||||
|
assert parsed.content is None
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class MistralToolParser(ToolParser):
|
|||||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||||
if _is_fn_name_regex_support(self.model_tokenizer):
|
if _is_fn_name_regex_support(self.model_tokenizer):
|
||||||
self.fn_name_regex = re.compile(
|
self.fn_name_regex = re.compile(
|
||||||
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)?", re.DOTALL
|
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.fn_name_regex = None
|
self.fn_name_regex = None
|
||||||
|
|||||||
Reference in New Issue
Block a user