[Bugfix] Multiple fixes to tool streaming with hermes and mistral (#10979)

Signed-off-by: cedonley <clayton@donley.io>
This commit is contained in:
Clayton
2024-12-11 17:10:12 -08:00
committed by GitHub
parent 4e11683368
commit 7439a8b5fc
3 changed files with 69 additions and 21 deletions

View File

@@ -91,7 +91,8 @@ class Hermes2ProToolParser(ToolParser):
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"])))
arguments=json.dumps(function_call["arguments"],
ensure_ascii=False)))
for function_call in raw_function_calls
]
@@ -139,13 +140,26 @@ class Hermes2ProToolParser(ToolParser):
self.tool_call_start_token_id)
cur_tool_end_count = current_token_ids.count(
self.tool_call_end_token_id)
tool_call_portion = None
text_portion = None
# case: if we're generating text, OR rounding out a tool call
if (cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count):
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text):
logger.debug("Generating text content! skipping tool parsing.")
if delta_text != self.tool_call_end_token:
return DeltaMessage(content=delta_text)
return DeltaMessage(content=delta_text)
if self.tool_call_end_token in delta_text:
logger.debug("tool_call_end_token in delta_text")
full_text = current_text + delta_text
tool_call_portion = full_text.split(
self.tool_call_start_token)[-1].split(
self.tool_call_end_token)[0].rstrip()
delta_text = delta_text.split(
self.tool_call_end_token)[0].rstrip()
text_portion = delta_text.split(
self.tool_call_end_token)[-1].lstrip()
# case: if tool open & close tag counts don't match, we're doing
# imaginary "else" block here
@@ -184,15 +198,21 @@ class Hermes2ProToolParser(ToolParser):
# case -- the current tool call is being closed.
elif (cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count > prev_tool_end_count):
and cur_tool_end_count >= prev_tool_end_count):
if (self.prev_tool_call_arr is None
or len(self.prev_tool_call_arr) == 0):
logger.debug(
"attempting to close tool call, but no tool call")
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments")
if diff:
diff = diff.encode('utf-8').decode(
'unicode_escape') if diff is str else diff
diff = json.dumps(
diff, ensure_ascii=False
)[len(self.streamed_args_for_tool[self.current_tool_id]):]
if ('"}' not in delta_text):
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s", diff)
@@ -221,10 +241,15 @@ class Hermes2ProToolParser(ToolParser):
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
return None
except json.decoder.JSONDecodeError:
logger.debug("unable to parse JSON")
return None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if not self.current_tool_name_sent:
if (current_tool_call is None):
return None
function_name: Union[str, None] = current_tool_call.get("name")
if function_name:
self.current_tool_name_sent = True
@@ -284,13 +309,17 @@ class Hermes2ProToolParser(ToolParser):
# autocompleting the JSON
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments)
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
logger.debug("finding %s in %s", delta_text,
cur_arguments_json)
# get the location where previous args differ from current
args_delta_start_loc = cur_arguments_json.index(delta_text) \
+ len(delta_text)
if (delta_text not in cur_arguments_json[:-2]):
return None
args_delta_start_loc = cur_arguments_json[:-2]. \
rindex(delta_text) + \
len(delta_text)
# use that to find the actual delta
arguments_delta = cur_arguments_json[:args_delta_start_loc]