[Frontend] Added support for HermesToolParser for models without special tokens (#16890)
Signed-off-by: minpeter <kali2005611@gmail.com>
This commit is contained in:
@@ -52,14 +52,51 @@ class Hermes2ProToolParser(ToolParser):
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
if (self.tool_call_start_token_id is None
|
||||
or self.tool_call_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Hermes 2 Pro Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
self.tool_call_start_token_ids = self.model_tokenizer.encode(
|
||||
self.tool_call_start_token, add_special_tokens=False)
|
||||
self.tool_call_end_token_ids = self.model_tokenizer.encode(
|
||||
self.tool_call_end_token, add_special_tokens=False)
|
||||
|
||||
self.tool_call_start_token_array = [
|
||||
self.model_tokenizer.decode([token_id])
|
||||
for token_id in self.tool_call_start_token_ids
|
||||
]
|
||||
|
||||
self.tool_call_end_token_array = [
|
||||
self.model_tokenizer.decode([token_id])
|
||||
for token_id in self.tool_call_end_token_ids
|
||||
]
|
||||
|
||||
self.buffered_delta_text = ""
|
||||
|
||||
# Very simple idea: when encountering tokens like <, tool, _call, >,
|
||||
# <, /, tool, _call, >, store them in a buffer.
|
||||
# When the last token is encountered, empty the buffer and return it.
|
||||
# If a token appears in an incorrect sequence while storing in the buffer,
|
||||
# return the preceding buffer along with the token.
|
||||
def tool_call_delta_buffer(self, delta_text: str):
|
||||
# If the sequence of tool_call_start or tool_call_end tokens is not yet
|
||||
# complete, fill the buffer with the token and return "".
|
||||
if (delta_text in self.tool_call_start_token_array
|
||||
or delta_text in self.tool_call_end_token_array):
|
||||
# If delta_text is the last token of tool_call_start_token or
|
||||
# tool_call_end_token, empty the buffer and return
|
||||
# the buffered text + delta_text.
|
||||
if (delta_text == self.tool_call_start_token_array[-1]
|
||||
or delta_text == self.tool_call_end_token_array[-1]):
|
||||
buffered_text = self.buffered_delta_text
|
||||
self.buffered_delta_text = ""
|
||||
return buffered_text + delta_text
|
||||
else:
|
||||
self.buffered_delta_text = self.buffered_delta_text + delta_text
|
||||
return ""
|
||||
else:
|
||||
if self.buffered_delta_text:
|
||||
buffered_text = self.buffered_delta_text
|
||||
self.buffered_delta_text = ""
|
||||
return buffered_text + delta_text
|
||||
else:
|
||||
return delta_text
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
@@ -124,11 +161,23 @@ class Hermes2ProToolParser(ToolParser):
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
# 1. All tokens are parsed based on _text, not token_ids.
|
||||
# 2. All incoming text data is processed by the tool_call_delta_buffer
|
||||
# function for buffering before being used for parsing.
|
||||
|
||||
delta_text = self.tool_call_delta_buffer(delta_text)
|
||||
# If the last characters of previous_text
|
||||
# match self.buffered_delta_text, remove only the matching part.
|
||||
if (len(previous_text) >= len(self.buffered_delta_text)
|
||||
and previous_text[-len(self.buffered_delta_text):]
|
||||
== self.buffered_delta_text):
|
||||
previous_text = previous_text[:-len(self.buffered_delta_text)]
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
|
||||
if self.tool_call_start_token_id not in current_token_ids:
|
||||
if self.tool_call_start_token not in current_text:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
@@ -136,14 +185,12 @@ class Hermes2ProToolParser(ToolParser):
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
prev_tool_end_count = previous_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
cur_tool_start_count = current_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
cur_tool_end_count = current_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
prev_tool_start_count = previous_text.count(
|
||||
self.tool_call_start_token)
|
||||
prev_tool_end_count = previous_text.count(self.tool_call_end_token)
|
||||
cur_tool_start_count = current_text.count(
|
||||
self.tool_call_start_token)
|
||||
cur_tool_end_count = current_text.count(self.tool_call_end_token)
|
||||
tool_call_portion = None
|
||||
text_portion = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user