[BUGFIX] Mistral tool call parser v11+ (#30332)
Signed-off-by: juliendenize <julien.denize@mistral.ai>
This commit is contained in:
@@ -99,12 +99,7 @@ class MistralToolParser(ToolParser):
         self.bot_token = "[TOOL_CALLS]"
         self.bot_token_id = self.vocab.get(self.bot_token)
         self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
-        if not _is_pre_v11_tokeniser(self.model_tokenizer):
-            self.fn_name_regex = re.compile(
-                r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
-            )
-        else:
-            self.fn_name_regex = None
+        self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer)

         if self.bot_token_id is None:
             raise RuntimeError(
@@ -148,23 +143,24 @@ class MistralToolParser(ToolParser):
         tool_content = model_output.replace(self.bot_token, "").strip()

         try:
-            # we first try to directly load the json as parsing very nested
-            # jsons is difficult
             try:
-                if self.fn_name_regex:
+                if not self._is_pre_v11:
                     function_call_arr = []
                     for single_tool_content in model_output.split(self.bot_token):
-                        matches = self.fn_name_regex.findall(single_tool_content)
+                        if "{" not in single_tool_content:
+                            continue

-                        for match in matches:
-                            fn_name = match[0]
-                            args = match[1]
+                        end_name = single_tool_content.find("{")
+                        fn_name, args = (
+                            single_tool_content[:end_name],
+                            single_tool_content[end_name:],
+                        )

-                            # fn_name is encoded outside serialized json dump
-                            # only arguments are serialized
-                            function_call_arr.append(
-                                {"name": fn_name, "arguments": json.loads(args)}
-                            )
+                        # fn_name is encoded outside serialized json dump
+                        # only arguments are serialized
+                        function_call_arr.append(
+                            {"name": fn_name, "arguments": json.loads(args)}
+                        )
                 else:
                     function_call_arr = json.loads(tool_content)
             except json.JSONDecodeError:
Reference in New Issue
Block a user