[Bugfix] Fix Hermes tool parser when stream interval > 1 (#38168)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
This commit is contained in:
@@ -152,6 +152,175 @@ def test_hermes_parser_streaming(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _simulate_streaming(
|
||||||
|
tokenizer: TokenizerLike,
|
||||||
|
parser: ToolParser,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
text: str,
|
||||||
|
stream_interval: int = 1,
|
||||||
|
) -> list:
|
||||||
|
"""Simulate streaming with a given stream_interval.
|
||||||
|
|
||||||
|
Tokens are batched into chunks of `stream_interval` tokens,
|
||||||
|
mimicking how the output processor delivers them.
|
||||||
|
Returns a list of non-None DeltaMessages.
|
||||||
|
"""
|
||||||
|
tokens = tokenizer.encode(text)
|
||||||
|
previous_text = ""
|
||||||
|
delta_messages = []
|
||||||
|
for i in range(0, len(tokens), stream_interval):
|
||||||
|
chunk_ids = tokens[i : i + stream_interval]
|
||||||
|
delta_text = tokenizer.decode(chunk_ids)
|
||||||
|
current_text = previous_text + delta_text
|
||||||
|
delta = parser.extract_tool_calls_streaming(
|
||||||
|
previous_text=previous_text,
|
||||||
|
current_text=current_text,
|
||||||
|
delta_text=delta_text,
|
||||||
|
previous_token_ids=[],
|
||||||
|
current_token_ids=[],
|
||||||
|
delta_token_ids=chunk_ids,
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
previous_text = current_text
|
||||||
|
if delta is not None:
|
||||||
|
delta_messages.append(delta)
|
||||||
|
return delta_messages
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("stream_interval", [2, 3, 5, 8])
|
||||||
|
def test_hermes_streaming_tool_call_with_stream_interval(
|
||||||
|
qwen_tokenizer: TokenizerLike,
|
||||||
|
any_chat_request: ChatCompletionRequest,
|
||||||
|
stream_interval: int,
|
||||||
|
) -> None:
|
||||||
|
"""Tool call streaming must produce correct name + args at any interval."""
|
||||||
|
text = (
|
||||||
|
'<tool_call>{"name": "get_current_temperature", '
|
||||||
|
'"arguments": {"location": "San Francisco", "unit": "celsius"}}'
|
||||||
|
"</tool_call>"
|
||||||
|
)
|
||||||
|
parser = Hermes2ProToolParser(qwen_tokenizer)
|
||||||
|
deltas = _simulate_streaming(
|
||||||
|
qwen_tokenizer, parser, any_chat_request, text, stream_interval
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flatten all DeltaToolCalls across all deltas.
|
||||||
|
tool_deltas = [tc for d in deltas if d.tool_calls for tc in d.tool_calls]
|
||||||
|
assert tool_deltas, "Expected at least one tool call delta"
|
||||||
|
assert tool_deltas[0].function.name == "get_current_temperature"
|
||||||
|
|
||||||
|
# Concatenated arguments must be valid JSON matching the original.
|
||||||
|
args_str = "".join(tc.function.arguments or "" for tc in tool_deltas)
|
||||||
|
assert json.loads(args_str) == {
|
||||||
|
"location": "San Francisco",
|
||||||
|
"unit": "celsius",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("stream_interval", [2, 3, 5, 8])
|
||||||
|
def test_hermes_streaming_content_then_tool_call_with_stream_interval(
|
||||||
|
qwen_tokenizer: TokenizerLike,
|
||||||
|
any_chat_request: ChatCompletionRequest,
|
||||||
|
stream_interval: int,
|
||||||
|
) -> None:
|
||||||
|
"""Content before a tool call must be fully streamed, then tool call."""
|
||||||
|
text = (
|
||||||
|
"Sure, let me check the weather."
|
||||||
|
'<tool_call>{"name": "get_weather", '
|
||||||
|
'"arguments": {"city": "NYC"}}</tool_call>'
|
||||||
|
)
|
||||||
|
parser = Hermes2ProToolParser(qwen_tokenizer)
|
||||||
|
deltas = _simulate_streaming(
|
||||||
|
qwen_tokenizer, parser, any_chat_request, text, stream_interval
|
||||||
|
)
|
||||||
|
|
||||||
|
content_deltas = [d for d in deltas if d.content]
|
||||||
|
tool_deltas = [d for d in deltas if d.tool_calls]
|
||||||
|
|
||||||
|
# Content must reconstruct the prefix.
|
||||||
|
content_str = "".join(d.content for d in content_deltas)
|
||||||
|
assert content_str == "Sure, let me check the weather."
|
||||||
|
|
||||||
|
# Tool call must be correct.
|
||||||
|
tool_calls = [tc for d in tool_deltas for tc in d.tool_calls]
|
||||||
|
assert tool_calls[0].function.name == "get_weather"
|
||||||
|
args_str = "".join(tc.function.arguments or "" for tc in tool_calls)
|
||||||
|
assert json.loads(args_str) == {"city": "NYC"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("stream_interval", [1, 2, 4])
|
||||||
|
def test_hermes_streaming_multiple_tool_calls_with_stream_interval(
|
||||||
|
qwen_tokenizer: TokenizerLike,
|
||||||
|
any_chat_request: ChatCompletionRequest,
|
||||||
|
stream_interval: int,
|
||||||
|
) -> None:
|
||||||
|
"""Multiple sequential tool calls must each be streamed correctly."""
|
||||||
|
text = (
|
||||||
|
'<tool_call>{"name": "search", "arguments": {"q": "cats"}}</tool_call>'
|
||||||
|
'<tool_call>{"name": "search", "arguments": {"q": "dogs"}}</tool_call>'
|
||||||
|
)
|
||||||
|
parser = Hermes2ProToolParser(qwen_tokenizer)
|
||||||
|
deltas = _simulate_streaming(
|
||||||
|
qwen_tokenizer, parser, any_chat_request, text, stream_interval
|
||||||
|
)
|
||||||
|
|
||||||
|
# Flatten all DeltaToolCalls across all deltas.
|
||||||
|
all_tool_calls = [tc for d in deltas if d.tool_calls for tc in d.tool_calls]
|
||||||
|
|
||||||
|
# Separate by tool index.
|
||||||
|
tool0 = [tc for tc in all_tool_calls if tc.index == 0]
|
||||||
|
tool1 = [tc for tc in all_tool_calls if tc.index == 1]
|
||||||
|
|
||||||
|
assert tool0[0].function.name == "search"
|
||||||
|
args0 = "".join(tc.function.arguments or "" for tc in tool0)
|
||||||
|
assert json.loads(args0) == {"q": "cats"}
|
||||||
|
|
||||||
|
assert tool1[0].function.name == "search"
|
||||||
|
args1 = "".join(tc.function.arguments or "" for tc in tool1)
|
||||||
|
assert json.loads(args1) == {"q": "dogs"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("stream_interval", [2, 5])
|
||||||
|
def test_hermes_streaming_boolean_args_with_stream_interval(
|
||||||
|
qwen_tokenizer: TokenizerLike,
|
||||||
|
any_chat_request: ChatCompletionRequest,
|
||||||
|
stream_interval: int,
|
||||||
|
) -> None:
|
||||||
|
"""Regression test for bug #19056 with stream_interval > 1."""
|
||||||
|
text = (
|
||||||
|
"<tool_call>\n"
|
||||||
|
'{"name": "final_answer", "arguments": {"trigger": true}}\n'
|
||||||
|
"</tool_call>"
|
||||||
|
)
|
||||||
|
parser = Hermes2ProToolParser(qwen_tokenizer)
|
||||||
|
deltas = _simulate_streaming(
|
||||||
|
qwen_tokenizer, parser, any_chat_request, text, stream_interval
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_calls = [tc for d in deltas if d.tool_calls for tc in d.tool_calls]
|
||||||
|
assert tool_calls[0].function.name == "final_answer"
|
||||||
|
args_str = "".join(tc.function.arguments or "" for tc in tool_calls)
|
||||||
|
assert json.loads(args_str) == {"trigger": True}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("stream_interval", [2, 3, 5])
|
||||||
|
def test_hermes_streaming_just_forward_text_with_stream_interval(
|
||||||
|
qwen_tokenizer: TokenizerLike,
|
||||||
|
any_chat_request: ChatCompletionRequest,
|
||||||
|
stream_interval: int,
|
||||||
|
) -> None:
|
||||||
|
"""Plain text with no tool calls must be fully forwarded."""
|
||||||
|
text = "This is plain text with no tool calling involved."
|
||||||
|
parser = Hermes2ProToolParser(qwen_tokenizer)
|
||||||
|
deltas = _simulate_streaming(
|
||||||
|
qwen_tokenizer, parser, any_chat_request, text, stream_interval
|
||||||
|
)
|
||||||
|
|
||||||
|
for d in deltas:
|
||||||
|
assert not d.tool_calls
|
||||||
|
assert "".join(d.content for d in deltas) == text
|
||||||
|
|
||||||
|
|
||||||
def test_hermes_parser_non_streaming_no_tool_call(
|
def test_hermes_parser_non_streaming_no_tool_call(
|
||||||
hermes_parser: ToolParser,
|
hermes_parser: ToolParser,
|
||||||
any_chat_request: ChatCompletionRequest,
|
any_chat_request: ChatCompletionRequest,
|
||||||
@@ -218,3 +387,28 @@ def test_hermes_parser_non_streaming_tool_call_invalid_json(
|
|||||||
|
|
||||||
assert tool_call is not None
|
assert tool_call is not None
|
||||||
assert not tool_call.tools_called
|
assert not tool_call.tools_called
|
||||||
|
|
||||||
|
|
||||||
|
def test_hermes_streaming_content_and_tool_call_in_single_chunk(
|
||||||
|
qwen_tokenizer: TokenizerLike,
|
||||||
|
any_chat_request: ChatCompletionRequest,
|
||||||
|
) -> None:
|
||||||
|
"""Content + complete tool call in one chunk must both be emitted."""
|
||||||
|
text = 'Hi!<tool_call>{"name": "f", "arguments": {"x": 1}}</tool_call>'
|
||||||
|
# Use a stream_interval large enough to guarantee a single chunk.
|
||||||
|
parser = Hermes2ProToolParser(qwen_tokenizer)
|
||||||
|
deltas = _simulate_streaming(
|
||||||
|
qwen_tokenizer,
|
||||||
|
parser,
|
||||||
|
any_chat_request,
|
||||||
|
text,
|
||||||
|
stream_interval=9999,
|
||||||
|
)
|
||||||
|
|
||||||
|
content_parts = [d.content for d in deltas if d.content]
|
||||||
|
tool_parts = [tc for d in deltas if d.tool_calls for tc in d.tool_calls]
|
||||||
|
|
||||||
|
assert "".join(content_parts) == "Hi!"
|
||||||
|
assert tool_parts[0].function.name == "f"
|
||||||
|
args_str = "".join(tc.function.arguments or "" for tc in tool_parts)
|
||||||
|
assert json.loads(args_str) == {"x": 1}
|
||||||
|
|||||||
@@ -4,9 +4,7 @@
|
|||||||
import json
|
import json
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
import partial_json_parser
|
|
||||||
import regex as re
|
import regex as re
|
||||||
from partial_json_parser.core.options import Allow
|
|
||||||
|
|
||||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||||
@@ -31,6 +29,27 @@ from vllm.utils.mistral import is_mistral_tokenizer
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _partial_tag_overlap(text: str, tag: str) -> int:
|
||||||
|
"""Length of the longest prefix of `tag` that matches a suffix of `text`.
|
||||||
|
|
||||||
|
E.g. text ending in "<tool_" returns 6 when tag is "<tool_call>".
|
||||||
|
Returns 0 if there is no overlap.
|
||||||
|
"""
|
||||||
|
max_check = min(len(tag) - 1, len(text))
|
||||||
|
for k in range(max_check, 0, -1):
|
||||||
|
if text.endswith(tag[:k]):
|
||||||
|
return k
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _is_valid_json(text: str) -> bool:
|
||||||
|
try:
|
||||||
|
json.loads(text)
|
||||||
|
return True
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class Hermes2ProToolParser(ToolParser):
|
class Hermes2ProToolParser(ToolParser):
|
||||||
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
|
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
|
||||||
super().__init__(tokenizer, tools)
|
super().__init__(tokenizer, tools)
|
||||||
@@ -39,13 +58,6 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
logger.error("Detected Mistral tokenizer when using a Hermes model")
|
logger.error("Detected Mistral tokenizer when using a Hermes model")
|
||||||
self.model_tokenizer = tokenizer.tokenizer
|
self.model_tokenizer = tokenizer.tokenizer
|
||||||
|
|
||||||
self.current_tool_name_sent: bool = False
|
|
||||||
self.prev_tool_call_arr: list[dict] = []
|
|
||||||
self.current_tool_id: int = -1
|
|
||||||
self.streamed_args_for_tool: list[
|
|
||||||
str
|
|
||||||
] = [] # map what has been streamed for each tool so far to a list
|
|
||||||
|
|
||||||
self.tool_call_start_token: str = "<tool_call>"
|
self.tool_call_start_token: str = "<tool_call>"
|
||||||
self.tool_call_end_token: str = "</tool_call>"
|
self.tool_call_end_token: str = "</tool_call>"
|
||||||
|
|
||||||
@@ -61,57 +73,9 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
"The model tokenizer must be passed to the ToolParser "
|
"The model tokenizer must be passed to the ToolParser "
|
||||||
"constructor during construction."
|
"constructor during construction."
|
||||||
)
|
)
|
||||||
self.tool_call_start_token_ids = self.model_tokenizer.encode(
|
|
||||||
self.tool_call_start_token, add_special_tokens=False
|
|
||||||
)
|
|
||||||
self.tool_call_end_token_ids = self.model_tokenizer.encode(
|
|
||||||
self.tool_call_end_token, add_special_tokens=False
|
|
||||||
)
|
|
||||||
|
|
||||||
self.tool_call_start_token_array = [
|
# Streaming state: what has been sent to the client.
|
||||||
self.model_tokenizer.decode([token_id])
|
self._sent_content_idx: int = 0
|
||||||
for token_id in self.tool_call_start_token_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
self.tool_call_end_token_array = [
|
|
||||||
self.model_tokenizer.decode([token_id])
|
|
||||||
for token_id in self.tool_call_end_token_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
self.buffered_delta_text = ""
|
|
||||||
|
|
||||||
# Very simple idea: when encountering tokens like <, tool, _call, >,
|
|
||||||
# <, /, tool, _call, >, store them in a buffer.
|
|
||||||
# When the last token is encountered, empty the buffer and return it.
|
|
||||||
# If a token appears in an incorrect sequence while storing in the buffer,
|
|
||||||
# return the preceding buffer along with the token.
|
|
||||||
def tool_call_delta_buffer(self, delta_text: str):
|
|
||||||
# If the sequence of tool_call_start or tool_call_end tokens is not yet
|
|
||||||
# complete, fill the buffer with the token and return "".
|
|
||||||
if (
|
|
||||||
delta_text in self.tool_call_start_token_array
|
|
||||||
or delta_text in self.tool_call_end_token_array
|
|
||||||
):
|
|
||||||
# If delta_text is the last token of tool_call_start_token or
|
|
||||||
# tool_call_end_token, empty the buffer and return
|
|
||||||
# the buffered text + delta_text.
|
|
||||||
if (
|
|
||||||
delta_text == self.tool_call_start_token_array[-1]
|
|
||||||
or delta_text == self.tool_call_end_token_array[-1]
|
|
||||||
):
|
|
||||||
buffered_text = self.buffered_delta_text
|
|
||||||
self.buffered_delta_text = ""
|
|
||||||
return buffered_text + delta_text
|
|
||||||
else:
|
|
||||||
self.buffered_delta_text = self.buffered_delta_text + delta_text
|
|
||||||
return ""
|
|
||||||
else:
|
|
||||||
if self.buffered_delta_text:
|
|
||||||
buffered_text = self.buffered_delta_text
|
|
||||||
self.buffered_delta_text = ""
|
|
||||||
return buffered_text + delta_text
|
|
||||||
else:
|
|
||||||
return delta_text
|
|
||||||
|
|
||||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||||
request = super().adjust_request(request)
|
request = super().adjust_request(request)
|
||||||
@@ -174,6 +138,88 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
tools_called=False, tool_calls=[], content=model_output
|
tools_called=False, tool_calls=[], content=model_output
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _extract_content(self, current_text: str) -> str | None:
|
||||||
|
"""Return unsent non-tool-call text, or None.
|
||||||
|
|
||||||
|
Holds back any suffix that could be a partial <tool_call> tag.
|
||||||
|
"""
|
||||||
|
if self.tool_call_start_token not in current_text:
|
||||||
|
overlap_length = _partial_tag_overlap(
|
||||||
|
current_text, self.tool_call_start_token
|
||||||
|
)
|
||||||
|
sendable_idx = len(current_text) - overlap_length
|
||||||
|
else:
|
||||||
|
sendable_idx = current_text.index(self.tool_call_start_token)
|
||||||
|
|
||||||
|
if sendable_idx > self._sent_content_idx:
|
||||||
|
content = current_text[self._sent_content_idx : sendable_idx]
|
||||||
|
self._sent_content_idx = sendable_idx
|
||||||
|
return content
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _extract_tool_call_jsons(self, text: str) -> list[tuple[str, bool]]:
|
||||||
|
"""Extract (json_text, is_complete) for each <tool_call> region."""
|
||||||
|
results: list[tuple[str, bool]] = []
|
||||||
|
pos = 0
|
||||||
|
while True:
|
||||||
|
start = text.find(self.tool_call_start_token, pos)
|
||||||
|
if start == -1:
|
||||||
|
break
|
||||||
|
json_start = start + len(self.tool_call_start_token)
|
||||||
|
json_end = text.find(self.tool_call_end_token, json_start)
|
||||||
|
if json_end != -1:
|
||||||
|
results.append((text[json_start:json_end].strip(), True))
|
||||||
|
pos = json_end + len(self.tool_call_end_token)
|
||||||
|
else:
|
||||||
|
raw = text[json_start:]
|
||||||
|
# Strip partial </tool_call> suffix if present.
|
||||||
|
overlap = _partial_tag_overlap(raw, self.tool_call_end_token)
|
||||||
|
if overlap:
|
||||||
|
raw = raw[:-overlap]
|
||||||
|
tc_json = raw.strip()
|
||||||
|
# Valid JSON without closing tag = complete body,
|
||||||
|
# tag tokens just haven't arrived yet.
|
||||||
|
is_complete = _is_valid_json(tc_json) if tc_json else False
|
||||||
|
results.append((tc_json, is_complete))
|
||||||
|
break
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_tool_name(tc_json: str) -> str | None:
|
||||||
|
"""Extract tool name, or None if the name isn't complete yet."""
|
||||||
|
match = re.search(r'"name"\s*:\s*"([^"]+)"', tc_json)
|
||||||
|
return match.group(1) if match else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_tool_args(tc_json: str, is_complete: bool) -> str | None:
|
||||||
|
"""Extract tool arguments from the tool call JSON.
|
||||||
|
|
||||||
|
Given {"name": "f", "arguments": {"x": 1}}, returns '{"x": 1}'.
|
||||||
|
When is_complete, strips the trailing '}' that closes the outer
|
||||||
|
object (not the arguments). For partial JSON, returns as-is.
|
||||||
|
"""
|
||||||
|
match = re.search(r'"arguments"\s*:\s*', tc_json)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
raw = tc_json[match.end() :]
|
||||||
|
if is_complete:
|
||||||
|
raw = raw.rstrip()
|
||||||
|
if raw.endswith("}"):
|
||||||
|
raw = raw[:-1].rstrip()
|
||||||
|
return raw
|
||||||
|
|
||||||
|
def _compute_args_diff(
|
||||||
|
self, index: int, tc_json: str, is_complete: bool
|
||||||
|
) -> str | None:
|
||||||
|
"""Return new argument text not yet sent for tool `index`, or None."""
|
||||||
|
args = self._extract_tool_args(tc_json, is_complete)
|
||||||
|
if args is None or len(args) <= len(self.streamed_args_for_tool[index]):
|
||||||
|
return None
|
||||||
|
diff = args[len(self.streamed_args_for_tool[index]) :]
|
||||||
|
self.streamed_args_for_tool[index] = args
|
||||||
|
self.prev_tool_call_arr[index]["arguments"] = args
|
||||||
|
return diff
|
||||||
|
|
||||||
def extract_tool_calls_streaming(
|
def extract_tool_calls_streaming(
|
||||||
self,
|
self,
|
||||||
previous_text: str,
|
previous_text: str,
|
||||||
@@ -184,321 +230,64 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
delta_token_ids: Sequence[int],
|
delta_token_ids: Sequence[int],
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> DeltaMessage | None:
|
) -> DeltaMessage | None:
|
||||||
# 1. All tokens are parsed based on _text, not token_ids.
|
"""Incrementally stream tool call deltas from accumulated output.
|
||||||
# 2. All incoming text data is processed by the tool_call_delta_buffer
|
|
||||||
# function for buffering before being used for parsing.
|
|
||||||
|
|
||||||
delta_text = self.tool_call_delta_buffer(delta_text)
|
On each invocation, re-parses the full ``current_text`` to find
|
||||||
# If the last characters of previous_text
|
``<tool_call>`` regions, then diffs against previously sent state
|
||||||
# match self.buffered_delta_text, remove only the matching part.
|
to emit only new content, tool names, or argument fragments.
|
||||||
if (
|
|
||||||
len(previous_text) >= len(self.buffered_delta_text)
|
|
||||||
and previous_text[-len(self.buffered_delta_text) :]
|
|
||||||
== self.buffered_delta_text
|
|
||||||
):
|
|
||||||
previous_text = previous_text[: -len(self.buffered_delta_text)]
|
|
||||||
current_text = previous_text + delta_text
|
|
||||||
|
|
||||||
logger.debug("delta_text: %s", delta_text)
|
|
||||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
|
||||||
# check to see if we should be streaming a tool call - is there a
|
|
||||||
if self.tool_call_start_token not in current_text:
|
|
||||||
logger.debug("No tool call tokens found!")
|
|
||||||
return DeltaMessage(content=delta_text)
|
|
||||||
|
|
||||||
|
Returns a ``DeltaMessage`` containing either plain content (for
|
||||||
|
text preceding any tool call) or one or more ``DeltaToolCall``
|
||||||
|
entries, or ``None`` if there is nothing new to send yet."""
|
||||||
try:
|
try:
|
||||||
# figure out where we are in the parsing by counting tool call
|
# Extract any content before tool calls.
|
||||||
# start & end tags
|
content = self._extract_content(current_text)
|
||||||
prev_tool_start_count = previous_text.count(self.tool_call_start_token)
|
tool_call_jsons = self._extract_tool_call_jsons(current_text)
|
||||||
prev_tool_end_count = previous_text.count(self.tool_call_end_token)
|
tool_call_deltas: list[DeltaToolCall] = []
|
||||||
cur_tool_start_count = current_text.count(self.tool_call_start_token)
|
|
||||||
cur_tool_end_count = current_text.count(self.tool_call_end_token)
|
|
||||||
tool_call_portion = None
|
|
||||||
text_portion = None
|
|
||||||
|
|
||||||
# case: if we're generating text, OR rounding out a tool call
|
for i, (tc_json, is_complete) in enumerate(tool_call_jsons):
|
||||||
if (
|
if i >= len(self.prev_tool_call_arr):
|
||||||
cur_tool_start_count == cur_tool_end_count
|
self.prev_tool_call_arr.append({})
|
||||||
and prev_tool_end_count == cur_tool_end_count
|
self.streamed_args_for_tool.append("")
|
||||||
and self.tool_call_end_token not in delta_text
|
|
||||||
):
|
|
||||||
logger.debug("Generating text content! skipping tool parsing.")
|
|
||||||
return DeltaMessage(content=delta_text)
|
|
||||||
|
|
||||||
if self.tool_call_end_token in delta_text:
|
# Stream back tool name.
|
||||||
logger.debug("tool_call_end_token in delta_text")
|
if "name" not in self.prev_tool_call_arr[i]:
|
||||||
full_text = current_text + delta_text
|
name = self._extract_tool_name(tc_json)
|
||||||
tool_call_portion = (
|
if not name:
|
||||||
full_text.split(self.tool_call_start_token)[-1]
|
# Can't skip to tool i+1 if i isn't ready
|
||||||
.split(self.tool_call_end_token)[0]
|
break
|
||||||
.rstrip()
|
self.prev_tool_call_arr[i]["name"] = name
|
||||||
)
|
tool_call_deltas.append(
|
||||||
delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
|
|
||||||
text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()
|
|
||||||
|
|
||||||
# case: if tool open & close tag counts don't match, we're doing
|
|
||||||
# imaginary "else" block here
|
|
||||||
# something with tools with this diff.
|
|
||||||
# flags for partial JSON parting. exported constants from
|
|
||||||
# "Allow" are handled via BIT MASK
|
|
||||||
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
|
|
||||||
|
|
||||||
# case -- we're starting a new tool call
|
|
||||||
if (
|
|
||||||
cur_tool_start_count > cur_tool_end_count
|
|
||||||
and cur_tool_start_count > prev_tool_start_count
|
|
||||||
):
|
|
||||||
if len(delta_token_ids) > 1:
|
|
||||||
tool_call_portion = current_text.split(self.tool_call_start_token)[
|
|
||||||
-1
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
tool_call_portion = None
|
|
||||||
delta = None
|
|
||||||
|
|
||||||
text_portion = None
|
|
||||||
|
|
||||||
# set cursors and state appropriately
|
|
||||||
self.current_tool_id += 1
|
|
||||||
self.current_tool_name_sent = False
|
|
||||||
self.streamed_args_for_tool.append("")
|
|
||||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
|
||||||
|
|
||||||
# case -- we're updating an existing tool call
|
|
||||||
elif (
|
|
||||||
cur_tool_start_count > cur_tool_end_count
|
|
||||||
and cur_tool_start_count == prev_tool_start_count
|
|
||||||
):
|
|
||||||
# get the portion of the text that's the tool call
|
|
||||||
tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
|
|
||||||
text_portion = None
|
|
||||||
|
|
||||||
# case -- the current tool call is being closed.
|
|
||||||
elif (
|
|
||||||
cur_tool_start_count == cur_tool_end_count
|
|
||||||
and cur_tool_end_count >= prev_tool_end_count
|
|
||||||
):
|
|
||||||
if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
|
|
||||||
logger.debug("attempting to close tool call, but no tool call")
|
|
||||||
return None
|
|
||||||
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
|
|
||||||
if diff:
|
|
||||||
diff = (
|
|
||||||
diff.encode("utf-8").decode("unicode_escape")
|
|
||||||
if diff is str
|
|
||||||
else diff
|
|
||||||
)
|
|
||||||
if '"}' not in delta_text:
|
|
||||||
return None
|
|
||||||
end_loc = delta_text.rindex('"}')
|
|
||||||
diff = delta_text[:end_loc] + '"}'
|
|
||||||
logger.debug(
|
|
||||||
"Finishing tool and found diff that had not "
|
|
||||||
"been streamed yet: %s",
|
|
||||||
diff,
|
|
||||||
)
|
|
||||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
|
||||||
return DeltaMessage(
|
|
||||||
tool_calls=[
|
|
||||||
DeltaToolCall(
|
|
||||||
index=self.current_tool_id,
|
|
||||||
function=DeltaFunctionCall(arguments=diff).model_dump(
|
|
||||||
exclude_none=True
|
|
||||||
),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# case -- otherwise we're just generating text
|
|
||||||
else:
|
|
||||||
text = delta_text.replace(self.tool_call_start_token, "")
|
|
||||||
text = text.replace(self.tool_call_end_token, "")
|
|
||||||
delta = DeltaMessage(tool_calls=[], content=text)
|
|
||||||
return delta
|
|
||||||
|
|
||||||
try:
|
|
||||||
current_tool_call = (
|
|
||||||
partial_json_parser.loads(tool_call_portion or "{}", flags)
|
|
||||||
if tool_call_portion
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
logger.debug("Parsed tool call %s", current_tool_call)
|
|
||||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
|
||||||
logger.debug("not enough tokens to parse into JSON yet")
|
|
||||||
return None
|
|
||||||
except json.decoder.JSONDecodeError:
|
|
||||||
logger.debug("unable to parse JSON")
|
|
||||||
return None
|
|
||||||
|
|
||||||
if current_tool_call is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# case - we haven't sent the tool name yet. If it's available, send
|
|
||||||
# it. otherwise, wait until it's available.
|
|
||||||
if not self.current_tool_name_sent:
|
|
||||||
function_name: str | None = current_tool_call.get("name")
|
|
||||||
if function_name:
|
|
||||||
self.current_tool_name_sent = True
|
|
||||||
return DeltaMessage(
|
|
||||||
tool_calls=[
|
|
||||||
DeltaToolCall(
|
|
||||||
index=self.current_tool_id,
|
|
||||||
type="function",
|
|
||||||
id=make_tool_call_id(),
|
|
||||||
function=DeltaFunctionCall(
|
|
||||||
name=function_name
|
|
||||||
).model_dump(exclude_none=True),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
# case -- otherwise, send the tool call delta
|
|
||||||
|
|
||||||
# if the tool call portion is None, send the delta as text
|
|
||||||
if tool_call_portion is None:
|
|
||||||
# if there's text but not tool calls, send that -
|
|
||||||
# otherwise None to skip chunk
|
|
||||||
delta = (
|
|
||||||
DeltaMessage(content=delta_text)
|
|
||||||
if text_portion is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
return delta
|
|
||||||
|
|
||||||
# now, the nitty-gritty of tool calls
|
|
||||||
# now we have the portion to parse as tool call.
|
|
||||||
|
|
||||||
if current_tool_call is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"Trying to parse current tool call with ID %s", self.current_tool_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# if we're starting a new tool call, push an empty object in as
|
|
||||||
# a placeholder for the arguments
|
|
||||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
|
||||||
self.prev_tool_call_arr.append({})
|
|
||||||
|
|
||||||
# main logic for tool parsing here - compare prev. partially-parsed
|
|
||||||
# JSON to the current partially-parsed JSON
|
|
||||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
|
||||||
"arguments"
|
|
||||||
)
|
|
||||||
assert current_tool_call is not None
|
|
||||||
cur_arguments = current_tool_call.get("arguments")
|
|
||||||
|
|
||||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
|
||||||
logger.debug("against new ones: %s", cur_arguments)
|
|
||||||
|
|
||||||
# case -- no arguments have been created yet. skip sending a delta.
|
|
||||||
if not cur_arguments and not prev_arguments:
|
|
||||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
|
||||||
delta = None
|
|
||||||
|
|
||||||
# case -- prev arguments are defined, but non are now.
|
|
||||||
# probably impossible, but not a fatal error - just keep going
|
|
||||||
elif not cur_arguments and prev_arguments:
|
|
||||||
logger.error(
|
|
||||||
"should be impossible to have arguments reset "
|
|
||||||
"mid-call. skipping streaming anything."
|
|
||||||
)
|
|
||||||
delta = None
|
|
||||||
|
|
||||||
# case -- we now have the first info about arguments available from
|
|
||||||
# autocompleting the JSON
|
|
||||||
elif cur_arguments and not prev_arguments:
|
|
||||||
# extract the content after {"name": ..., "arguments":
|
|
||||||
# directly from tool_call_portion as cur_arguments_json,
|
|
||||||
# since cur_arguments may differ from the original text
|
|
||||||
# due to partial JSON parsing
|
|
||||||
# for example, tool_call_portion =
|
|
||||||
# {"name": "search", "arguments": {"search_request": {"
|
|
||||||
# but cur_arguments =
|
|
||||||
# {"search_request": {}}
|
|
||||||
function_name = current_tool_call.get("name")
|
|
||||||
match = re.search(
|
|
||||||
r'\{"name":\s*"'
|
|
||||||
+ re.escape(function_name)
|
|
||||||
+ r'"\s*,\s*"arguments":\s*(.*)',
|
|
||||||
tool_call_portion.strip(),
|
|
||||||
re.DOTALL,
|
|
||||||
)
|
|
||||||
if match:
|
|
||||||
cur_arguments_json = match.group(1)
|
|
||||||
else:
|
|
||||||
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
|
|
||||||
|
|
||||||
logger.debug("finding %s in %s", delta_text, cur_arguments_json)
|
|
||||||
|
|
||||||
# get the location where previous args differ from current.
|
|
||||||
if delta_text not in cur_arguments_json:
|
|
||||||
return None
|
|
||||||
args_delta_start_loc = cur_arguments_json.rindex(delta_text) + len(
|
|
||||||
delta_text
|
|
||||||
)
|
|
||||||
|
|
||||||
# use that to find the actual delta
|
|
||||||
arguments_delta = cur_arguments_json[:args_delta_start_loc]
|
|
||||||
logger.debug("First tokens in arguments received: %s", arguments_delta)
|
|
||||||
|
|
||||||
delta = DeltaMessage(
|
|
||||||
tool_calls=[
|
|
||||||
DeltaToolCall(
|
DeltaToolCall(
|
||||||
index=self.current_tool_id,
|
index=i,
|
||||||
function=DeltaFunctionCall(
|
type="function",
|
||||||
arguments=arguments_delta
|
id=make_tool_call_id(),
|
||||||
).model_dump(exclude_none=True),
|
function=DeltaFunctionCall(name=name).model_dump(
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
|
|
||||||
|
|
||||||
# last case -- we have an update to existing arguments.
|
|
||||||
elif cur_arguments and prev_arguments:
|
|
||||||
# judge whether the tool_call_portion is a complete JSON
|
|
||||||
try:
|
|
||||||
json.loads(tool_call_portion)
|
|
||||||
is_complete_json = True
|
|
||||||
except Exception:
|
|
||||||
is_complete_json = False
|
|
||||||
|
|
||||||
# if the delta_text ends with a '}' and tool_call_portion is a
|
|
||||||
# complete JSON, then the last '}' does not belong to the
|
|
||||||
# arguments, so we should trim it off
|
|
||||||
if (
|
|
||||||
isinstance(delta_text, str)
|
|
||||||
and len(delta_text.rstrip()) >= 1
|
|
||||||
and delta_text.rstrip()[-1] == "}"
|
|
||||||
and is_complete_json
|
|
||||||
):
|
|
||||||
delta_text = delta_text.rstrip()[:-1]
|
|
||||||
|
|
||||||
logger.debug("got diff %s", delta_text)
|
|
||||||
|
|
||||||
delta = DeltaMessage(
|
|
||||||
tool_calls=[
|
|
||||||
DeltaToolCall(
|
|
||||||
index=self.current_tool_id,
|
|
||||||
function=DeltaFunctionCall(arguments=delta_text).model_dump(
|
|
||||||
exclude_none=True
|
exclude_none=True
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
]
|
)
|
||||||
|
|
||||||
|
# Stream back new tool args by diffing against what was sent.
|
||||||
|
args_diff = self._compute_args_diff(i, tc_json, is_complete)
|
||||||
|
if args_diff:
|
||||||
|
tool_call_deltas.append(
|
||||||
|
DeltaToolCall(
|
||||||
|
index=i,
|
||||||
|
function=DeltaFunctionCall(arguments=args_diff).model_dump(
|
||||||
|
exclude_none=True
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if content or tool_call_deltas:
|
||||||
|
return DeltaMessage(
|
||||||
|
content=content,
|
||||||
|
tool_calls=tool_call_deltas,
|
||||||
)
|
)
|
||||||
self.streamed_args_for_tool[self.current_tool_id] += delta_text
|
|
||||||
|
|
||||||
# handle saving the state for the current tool into
|
return None
|
||||||
# the "prev" list for use in diffing for the next iteration
|
|
||||||
assert isinstance(current_tool_call, dict)
|
|
||||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
|
||||||
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
|
|
||||||
else:
|
|
||||||
self.prev_tool_call_arr.append(current_tool_call)
|
|
||||||
|
|
||||||
return delta
|
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error trying to handle streaming tool call.")
|
logger.exception("Error trying to handle streaming tool call.")
|
||||||
return None # do not stream a delta. skip this token ID.
|
return None
|
||||||
|
|||||||
@@ -16,23 +16,7 @@ class LongcatFlashToolParser(Hermes2ProToolParser):
|
|||||||
self.tool_call_end_token: str = "</longcat_tool_call>"
|
self.tool_call_end_token: str = "</longcat_tool_call>"
|
||||||
|
|
||||||
self.tool_call_regex = re.compile(
|
self.tool_call_regex = re.compile(
|
||||||
r"<longcat_tool_call>(.*?)</longcat_tool_call>|<longcat_tool_call>(.*)",
|
r"<longcat_tool_call>(.*?)</longcat_tool_call>"
|
||||||
|
r"|<longcat_tool_call>(.*)",
|
||||||
re.DOTALL,
|
re.DOTALL,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.tool_call_start_token_ids = self.model_tokenizer.encode(
|
|
||||||
self.tool_call_start_token, add_special_tokens=False
|
|
||||||
)
|
|
||||||
self.tool_call_end_token_ids = self.model_tokenizer.encode(
|
|
||||||
self.tool_call_end_token, add_special_tokens=False
|
|
||||||
)
|
|
||||||
|
|
||||||
self.tool_call_start_token_array = [
|
|
||||||
self.model_tokenizer.decode([token_id])
|
|
||||||
for token_id in self.tool_call_start_token_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
self.tool_call_end_token_array = [
|
|
||||||
self.model_tokenizer.decode([token_id])
|
|
||||||
for token_id in self.tool_call_end_token_ids
|
|
||||||
]
|
|
||||||
|
|||||||
Reference in New Issue
Block a user