diff --git a/tests/tool_parsers/test_hermes_tool_parser.py b/tests/tool_parsers/test_hermes_tool_parser.py index 245f0739e..25869777d 100644 --- a/tests/tool_parsers/test_hermes_tool_parser.py +++ b/tests/tool_parsers/test_hermes_tool_parser.py @@ -152,6 +152,175 @@ def test_hermes_parser_streaming( } +def _simulate_streaming( + tokenizer: TokenizerLike, + parser: ToolParser, + request: ChatCompletionRequest, + text: str, + stream_interval: int = 1, +) -> list: + """Simulate streaming with a given stream_interval. + + Tokens are batched into chunks of `stream_interval` tokens, + mimicking how the output processor delivers them. + Returns a list of non-None DeltaMessages. + """ + tokens = tokenizer.encode(text) + previous_text = "" + delta_messages = [] + for i in range(0, len(tokens), stream_interval): + chunk_ids = tokens[i : i + stream_interval] + delta_text = tokenizer.decode(chunk_ids) + current_text = previous_text + delta_text + delta = parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=chunk_ids, + request=request, + ) + previous_text = current_text + if delta is not None: + delta_messages.append(delta) + return delta_messages + + +@pytest.mark.parametrize("stream_interval", [2, 3, 5, 8]) +def test_hermes_streaming_tool_call_with_stream_interval( + qwen_tokenizer: TokenizerLike, + any_chat_request: ChatCompletionRequest, + stream_interval: int, +) -> None: + """Tool call streaming must produce correct name + args at any interval.""" + text = ( + '{"name": "get_current_temperature", ' + '"arguments": {"location": "San Francisco", "unit": "celsius"}}' + "" + ) + parser = Hermes2ProToolParser(qwen_tokenizer) + deltas = _simulate_streaming( + qwen_tokenizer, parser, any_chat_request, text, stream_interval + ) + + # Flatten all DeltaToolCalls across all deltas. + tool_deltas = [tc for d in deltas if d.tool_calls for tc in d.tool_calls] + assert tool_deltas, "Expected at least one tool call delta" + assert tool_deltas[0].function.name == "get_current_temperature" + + # Concatenated arguments must be valid JSON matching the original. + args_str = "".join(tc.function.arguments or "" for tc in tool_deltas) + assert json.loads(args_str) == { + "location": "San Francisco", + "unit": "celsius", + } + + +@pytest.mark.parametrize("stream_interval", [2, 3, 5, 8]) +def test_hermes_streaming_content_then_tool_call_with_stream_interval( + qwen_tokenizer: TokenizerLike, + any_chat_request: ChatCompletionRequest, + stream_interval: int, +) -> None: + """Content before a tool call must be fully streamed, then tool call.""" + text = ( + "Sure, let me check the weather." + '{"name": "get_weather", ' + '"arguments": {"city": "NYC"}}' + ) + parser = Hermes2ProToolParser(qwen_tokenizer) + deltas = _simulate_streaming( + qwen_tokenizer, parser, any_chat_request, text, stream_interval + ) + + content_deltas = [d for d in deltas if d.content] + tool_deltas = [d for d in deltas if d.tool_calls] + + # Content must reconstruct the prefix. + content_str = "".join(d.content for d in content_deltas) + assert content_str == "Sure, let me check the weather." + + # Tool call must be correct. + tool_calls = [tc for d in tool_deltas for tc in d.tool_calls] + assert tool_calls[0].function.name == "get_weather" + args_str = "".join(tc.function.arguments or "" for tc in tool_calls) + assert json.loads(args_str) == {"city": "NYC"} + + +@pytest.mark.parametrize("stream_interval", [1, 2, 4]) +def test_hermes_streaming_multiple_tool_calls_with_stream_interval( + qwen_tokenizer: TokenizerLike, + any_chat_request: ChatCompletionRequest, + stream_interval: int, +) -> None: + """Multiple sequential tool calls must each be streamed correctly.""" + text = ( + '{"name": "search", "arguments": {"q": "cats"}}' + '{"name": "search", "arguments": {"q": "dogs"}}' + ) + parser = Hermes2ProToolParser(qwen_tokenizer) + deltas = _simulate_streaming( + qwen_tokenizer, parser, any_chat_request, text, stream_interval + ) + + # Flatten all DeltaToolCalls across all deltas. + all_tool_calls = [tc for d in deltas if d.tool_calls for tc in d.tool_calls] + + # Separate by tool index. + tool0 = [tc for tc in all_tool_calls if tc.index == 0] + tool1 = [tc for tc in all_tool_calls if tc.index == 1] + + assert tool0[0].function.name == "search" + args0 = "".join(tc.function.arguments or "" for tc in tool0) + assert json.loads(args0) == {"q": "cats"} + + assert tool1[0].function.name == "search" + args1 = "".join(tc.function.arguments or "" for tc in tool1) + assert json.loads(args1) == {"q": "dogs"} + + +@pytest.mark.parametrize("stream_interval", [2, 5]) +def test_hermes_streaming_boolean_args_with_stream_interval( + qwen_tokenizer: TokenizerLike, + any_chat_request: ChatCompletionRequest, + stream_interval: int, +) -> None: + """Regression test for bug #19056 with stream_interval > 1.""" + text = ( + "\n" + '{"name": "final_answer", "arguments": {"trigger": true}}\n' + "" + ) + parser = Hermes2ProToolParser(qwen_tokenizer) + deltas = _simulate_streaming( + qwen_tokenizer, parser, any_chat_request, text, stream_interval + ) + + tool_calls = [tc for d in deltas if d.tool_calls for tc in d.tool_calls] + assert tool_calls[0].function.name == "final_answer" + args_str = "".join(tc.function.arguments or "" for tc in tool_calls) + assert json.loads(args_str) == {"trigger": True} + + +@pytest.mark.parametrize("stream_interval", [2, 3, 5]) +def test_hermes_streaming_just_forward_text_with_stream_interval( + qwen_tokenizer: TokenizerLike, + any_chat_request: ChatCompletionRequest, + stream_interval: int, +) -> None: + """Plain text with no tool calls must be fully forwarded.""" + text = "This is plain text with no tool calling involved." + parser = Hermes2ProToolParser(qwen_tokenizer) + deltas = _simulate_streaming( + qwen_tokenizer, parser, any_chat_request, text, stream_interval + ) + + for d in deltas: + assert not d.tool_calls + assert "".join(d.content for d in deltas) == text + + def test_hermes_parser_non_streaming_no_tool_call( hermes_parser: ToolParser, any_chat_request: ChatCompletionRequest, @@ -218,3 +387,28 @@ def test_hermes_parser_non_streaming_tool_call_invalid_json( assert tool_call is not None assert not tool_call.tools_called + + +def test_hermes_streaming_content_and_tool_call_in_single_chunk( + qwen_tokenizer: TokenizerLike, + any_chat_request: ChatCompletionRequest, +) -> None: + """Content + complete tool call in one chunk must both be emitted.""" + text = 'Hi!{"name": "f", "arguments": {"x": 1}}' + # Use a stream_interval large enough to guarantee a single chunk. + parser = Hermes2ProToolParser(qwen_tokenizer) + deltas = _simulate_streaming( + qwen_tokenizer, + parser, + any_chat_request, + text, + stream_interval=9999, + ) + + content_parts = [d.content for d in deltas if d.content] + tool_parts = [tc for d in deltas if d.tool_calls for tc in d.tool_calls] + + assert "".join(content_parts) == "Hi!" + assert tool_parts[0].function.name == "f" + args_str = "".join(tc.function.arguments or "" for tc in tool_parts) + assert json.loads(args_str) == {"x": 1} diff --git a/vllm/tool_parsers/hermes_tool_parser.py b/vllm/tool_parsers/hermes_tool_parser.py index cca2bf9a0..4e54d75b4 100644 --- a/vllm/tool_parsers/hermes_tool_parser.py +++ b/vllm/tool_parsers/hermes_tool_parser.py @@ -4,9 +4,7 @@ import json from collections.abc import Sequence -import partial_json_parser import regex as re -from partial_json_parser.core.options import Allow from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.openai.chat_completion.protocol import ( @@ -31,6 +29,27 @@ from vllm.utils.mistral import is_mistral_tokenizer logger = init_logger(__name__) +def _partial_tag_overlap(text: str, tag: str) -> int: + """Length of the longest prefix of `tag` that matches a suffix of `text`. + + E.g. text ending in "". + Returns 0 if there is no overlap. + """ + max_check = min(len(tag) - 1, len(text)) + for k in range(max_check, 0, -1): + if text.endswith(tag[:k]): + return k + return 0 + + +def _is_valid_json(text: str) -> bool: + try: + json.loads(text) + return True + except (json.JSONDecodeError, ValueError): + return False + + class Hermes2ProToolParser(ToolParser): def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None): super().__init__(tokenizer, tools) @@ -39,13 +58,6 @@ class Hermes2ProToolParser(ToolParser): logger.error("Detected Mistral tokenizer when using a Hermes model") self.model_tokenizer = tokenizer.tokenizer - self.current_tool_name_sent: bool = False - self.prev_tool_call_arr: list[dict] = [] - self.current_tool_id: int = -1 - self.streamed_args_for_tool: list[ - str - ] = [] # map what has been streamed for each tool so far to a list - self.tool_call_start_token: str = "" self.tool_call_end_token: str = "" @@ -61,57 +73,9 @@ class Hermes2ProToolParser(ToolParser): "The model tokenizer must be passed to the ToolParser " "constructor during construction." ) - self.tool_call_start_token_ids = self.model_tokenizer.encode( - self.tool_call_start_token, add_special_tokens=False - ) - self.tool_call_end_token_ids = self.model_tokenizer.encode( - self.tool_call_end_token, add_special_tokens=False - ) - self.tool_call_start_token_array = [ - self.model_tokenizer.decode([token_id]) - for token_id in self.tool_call_start_token_ids - ] - - self.tool_call_end_token_array = [ - self.model_tokenizer.decode([token_id]) - for token_id in self.tool_call_end_token_ids - ] - - self.buffered_delta_text = "" - - # Very simple idea: when encountering tokens like <, tool, _call, >, - # <, /, tool, _call, >, store them in a buffer. - # When the last token is encountered, empty the buffer and return it. - # If a token appears in an incorrect sequence while storing in the buffer, - # return the preceding buffer along with the token. - def tool_call_delta_buffer(self, delta_text: str): - # If the sequence of tool_call_start or tool_call_end tokens is not yet - # complete, fill the buffer with the token and return "". - if ( - delta_text in self.tool_call_start_token_array - or delta_text in self.tool_call_end_token_array - ): - # If delta_text is the last token of tool_call_start_token or - # tool_call_end_token, empty the buffer and return - # the buffered text + delta_text. - if ( - delta_text == self.tool_call_start_token_array[-1] - or delta_text == self.tool_call_end_token_array[-1] - ): - buffered_text = self.buffered_delta_text - self.buffered_delta_text = "" - return buffered_text + delta_text - else: - self.buffered_delta_text = self.buffered_delta_text + delta_text - return "" - else: - if self.buffered_delta_text: - buffered_text = self.buffered_delta_text - self.buffered_delta_text = "" - return buffered_text + delta_text - else: - return delta_text + # Streaming state: what has been sent to the client. + self._sent_content_idx: int = 0 def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: request = super().adjust_request(request) @@ -174,6 +138,88 @@ class Hermes2ProToolParser(ToolParser): tools_called=False, tool_calls=[], content=model_output ) + def _extract_content(self, current_text: str) -> str | None: + """Return unsent non-tool-call text, or None. + + Holds back any suffix that could be a partial tag. + """ + if self.tool_call_start_token not in current_text: + overlap_length = _partial_tag_overlap( + current_text, self.tool_call_start_token + ) + sendable_idx = len(current_text) - overlap_length + else: + sendable_idx = current_text.index(self.tool_call_start_token) + + if sendable_idx > self._sent_content_idx: + content = current_text[self._sent_content_idx : sendable_idx] + self._sent_content_idx = sendable_idx + return content + return None + + def _extract_tool_call_jsons(self, text: str) -> list[tuple[str, bool]]: + """Extract (json_text, is_complete) for each region.""" + results: list[tuple[str, bool]] = [] + pos = 0 + while True: + start = text.find(self.tool_call_start_token, pos) + if start == -1: + break + json_start = start + len(self.tool_call_start_token) + json_end = text.find(self.tool_call_end_token, json_start) + if json_end != -1: + results.append((text[json_start:json_end].strip(), True)) + pos = json_end + len(self.tool_call_end_token) + else: + raw = text[json_start:] + # Strip partial suffix if present. + overlap = _partial_tag_overlap(raw, self.tool_call_end_token) + if overlap: + raw = raw[:-overlap] + tc_json = raw.strip() + # Valid JSON without closing tag = complete body, + # tag tokens just haven't arrived yet. + is_complete = _is_valid_json(tc_json) if tc_json else False + results.append((tc_json, is_complete)) + break + return results + + @staticmethod + def _extract_tool_name(tc_json: str) -> str | None: + """Extract tool name, or None if the name isn't complete yet.""" + match = re.search(r'"name"\s*:\s*"([^"]+)"', tc_json) + return match.group(1) if match else None + + @staticmethod + def _extract_tool_args(tc_json: str, is_complete: bool) -> str | None: + """Extract tool arguments from the tool call JSON. + + Given {"name": "f", "arguments": {"x": 1}}, returns '{"x": 1}'. + When is_complete, strips the trailing '}' that closes the outer + object (not the arguments). For partial JSON, returns as-is. + """ + match = re.search(r'"arguments"\s*:\s*', tc_json) + if not match: + return None + raw = tc_json[match.end() :] + if is_complete: + raw = raw.rstrip() + if raw.endswith("}"): + raw = raw[:-1].rstrip() + return raw + + def _compute_args_diff( + self, index: int, tc_json: str, is_complete: bool + ) -> str | None: + """Return new argument text not yet sent for tool `index`, or None.""" + args = self._extract_tool_args(tc_json, is_complete) + if args is None or len(args) <= len(self.streamed_args_for_tool[index]): + return None + diff = args[len(self.streamed_args_for_tool[index]) :] + self.streamed_args_for_tool[index] = args + self.prev_tool_call_arr[index]["arguments"] = args + return diff + def extract_tool_calls_streaming( self, previous_text: str, @@ -184,321 +230,64 @@ class Hermes2ProToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> DeltaMessage | None: - # 1. All tokens are parsed based on _text, not token_ids. - # 2. All incoming text data is processed by the tool_call_delta_buffer - # function for buffering before being used for parsing. + """Incrementally stream tool call deltas from accumulated output. - delta_text = self.tool_call_delta_buffer(delta_text) - # If the last characters of previous_text - # match self.buffered_delta_text, remove only the matching part. - if ( - len(previous_text) >= len(self.buffered_delta_text) - and previous_text[-len(self.buffered_delta_text) :] - == self.buffered_delta_text - ): - previous_text = previous_text[: -len(self.buffered_delta_text)] - current_text = previous_text + delta_text - - logger.debug("delta_text: %s", delta_text) - logger.debug("delta_token_ids: %s", delta_token_ids) - # check to see if we should be streaming a tool call - is there a - if self.tool_call_start_token not in current_text: - logger.debug("No tool call tokens found!") - return DeltaMessage(content=delta_text) + On each invocation, re-parses the full ``current_text`` to find + ```` regions, then diffs against previously sent state + to emit only new content, tool names, or argument fragments. + Returns a ``DeltaMessage`` containing either plain content (for + text preceding any tool call) or one or more ``DeltaToolCall`` + entries, or ``None`` if there is nothing new to send yet.""" try: - # figure out where we are in the parsing by counting tool call - # start & end tags - prev_tool_start_count = previous_text.count(self.tool_call_start_token) - prev_tool_end_count = previous_text.count(self.tool_call_end_token) - cur_tool_start_count = current_text.count(self.tool_call_start_token) - cur_tool_end_count = current_text.count(self.tool_call_end_token) - tool_call_portion = None - text_portion = None + # Extract any content before tool calls. + content = self._extract_content(current_text) + tool_call_jsons = self._extract_tool_call_jsons(current_text) + tool_call_deltas: list[DeltaToolCall] = [] - # case: if we're generating text, OR rounding out a tool call - if ( - cur_tool_start_count == cur_tool_end_count - and prev_tool_end_count == cur_tool_end_count - and self.tool_call_end_token not in delta_text - ): - logger.debug("Generating text content! skipping tool parsing.") - return DeltaMessage(content=delta_text) + for i, (tc_json, is_complete) in enumerate(tool_call_jsons): + if i >= len(self.prev_tool_call_arr): + self.prev_tool_call_arr.append({}) + self.streamed_args_for_tool.append("") - if self.tool_call_end_token in delta_text: - logger.debug("tool_call_end_token in delta_text") - full_text = current_text + delta_text - tool_call_portion = ( - full_text.split(self.tool_call_start_token)[-1] - .split(self.tool_call_end_token)[0] - .rstrip() - ) - delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip() - text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip() - - # case: if tool open & close tag counts don't match, we're doing - # imaginary "else" block here - # something with tools with this diff. - # flags for partial JSON parting. exported constants from - # "Allow" are handled via BIT MASK - flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR - - # case -- we're starting a new tool call - if ( - cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count > prev_tool_start_count - ): - if len(delta_token_ids) > 1: - tool_call_portion = current_text.split(self.tool_call_start_token)[ - -1 - ] - else: - tool_call_portion = None - delta = None - - text_portion = None - - # set cursors and state appropriately - self.current_tool_id += 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append("") - logger.debug("Starting on a new tool %s", self.current_tool_id) - - # case -- we're updating an existing tool call - elif ( - cur_tool_start_count > cur_tool_end_count - and cur_tool_start_count == prev_tool_start_count - ): - # get the portion of the text that's the tool call - tool_call_portion = current_text.split(self.tool_call_start_token)[-1] - text_portion = None - - # case -- the current tool call is being closed. - elif ( - cur_tool_start_count == cur_tool_end_count - and cur_tool_end_count >= prev_tool_end_count - ): - if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0: - logger.debug("attempting to close tool call, but no tool call") - return None - diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments") - if diff: - diff = ( - diff.encode("utf-8").decode("unicode_escape") - if diff is str - else diff - ) - if '"}' not in delta_text: - return None - end_loc = delta_text.rindex('"}') - diff = delta_text[:end_loc] + '"}' - logger.debug( - "Finishing tool and found diff that had not " - "been streamed yet: %s", - diff, - ) - self.streamed_args_for_tool[self.current_tool_id] += diff - return DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall(arguments=diff).model_dump( - exclude_none=True - ), - ) - ] - ) - - # case -- otherwise we're just generating text - else: - text = delta_text.replace(self.tool_call_start_token, "") - text = text.replace(self.tool_call_end_token, "") - delta = DeltaMessage(tool_calls=[], content=text) - return delta - - try: - current_tool_call = ( - partial_json_parser.loads(tool_call_portion or "{}", flags) - if tool_call_portion - else None - ) - logger.debug("Parsed tool call %s", current_tool_call) - except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug("not enough tokens to parse into JSON yet") - return None - except json.decoder.JSONDecodeError: - logger.debug("unable to parse JSON") - return None - - if current_tool_call is None: - return None - - # case - we haven't sent the tool name yet. If it's available, send - # it. otherwise, wait until it's available. - if not self.current_tool_name_sent: - function_name: str | None = current_tool_call.get("name") - if function_name: - self.current_tool_name_sent = True - return DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - type="function", - id=make_tool_call_id(), - function=DeltaFunctionCall( - name=function_name - ).model_dump(exclude_none=True), - ) - ] - ) - else: - return None - # case -- otherwise, send the tool call delta - - # if the tool call portion is None, send the delta as text - if tool_call_portion is None: - # if there's text but not tool calls, send that - - # otherwise None to skip chunk - delta = ( - DeltaMessage(content=delta_text) - if text_portion is not None - else None - ) - return delta - - # now, the nitty-gritty of tool calls - # now we have the portion to parse as tool call. - - if current_tool_call is None: - return None - - logger.debug( - "Trying to parse current tool call with ID %s", self.current_tool_id - ) - - # if we're starting a new tool call, push an empty object in as - # a placeholder for the arguments - if len(self.prev_tool_call_arr) <= self.current_tool_id: - self.prev_tool_call_arr.append({}) - - # main logic for tool parsing here - compare prev. partially-parsed - # JSON to the current partially-parsed JSON - prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( - "arguments" - ) - assert current_tool_call is not None - cur_arguments = current_tool_call.get("arguments") - - logger.debug("diffing old arguments: %s", prev_arguments) - logger.debug("against new ones: %s", cur_arguments) - - # case -- no arguments have been created yet. skip sending a delta. - if not cur_arguments and not prev_arguments: - logger.debug("Skipping text %s - no arguments", delta_text) - delta = None - - # case -- prev arguments are defined, but non are now. - # probably impossible, but not a fatal error - just keep going - elif not cur_arguments and prev_arguments: - logger.error( - "should be impossible to have arguments reset " - "mid-call. skipping streaming anything." - ) - delta = None - - # case -- we now have the first info about arguments available from - # autocompleting the JSON - elif cur_arguments and not prev_arguments: - # extract the content after {"name": ..., "arguments": - # directly from tool_call_portion as cur_arguments_json, - # since cur_arguments may differ from the original text - # due to partial JSON parsing - # for example, tool_call_portion = - # {"name": "search", "arguments": {"search_request": {" - # but cur_arguments = - # {"search_request": {}} - function_name = current_tool_call.get("name") - match = re.search( - r'\{"name":\s*"' - + re.escape(function_name) - + r'"\s*,\s*"arguments":\s*(.*)', - tool_call_portion.strip(), - re.DOTALL, - ) - if match: - cur_arguments_json = match.group(1) - else: - cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False) - - logger.debug("finding %s in %s", delta_text, cur_arguments_json) - - # get the location where previous args differ from current. - if delta_text not in cur_arguments_json: - return None - args_delta_start_loc = cur_arguments_json.rindex(delta_text) + len( - delta_text - ) - - # use that to find the actual delta - arguments_delta = cur_arguments_json[:args_delta_start_loc] - logger.debug("First tokens in arguments received: %s", arguments_delta) - - delta = DeltaMessage( - tool_calls=[ + # Stream back tool name. + if "name" not in self.prev_tool_call_arr[i]: + name = self._extract_tool_name(tc_json) + if not name: + # Can't skip to tool i+1 if i isn't ready + break + self.prev_tool_call_arr[i]["name"] = name + tool_call_deltas.append( DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta - ).model_dump(exclude_none=True), - ) - ] - ) - self.streamed_args_for_tool[self.current_tool_id] += arguments_delta - - # last case -- we have an update to existing arguments. - elif cur_arguments and prev_arguments: - # judge whether the tool_call_portion is a complete JSON - try: - json.loads(tool_call_portion) - is_complete_json = True - except Exception: - is_complete_json = False - - # if the delta_text ends with a '}' and tool_call_portion is a - # complete JSON, then the last '}' does not belong to the - # arguments, so we should trim it off - if ( - isinstance(delta_text, str) - and len(delta_text.rstrip()) >= 1 - and delta_text.rstrip()[-1] == "}" - and is_complete_json - ): - delta_text = delta_text.rstrip()[:-1] - - logger.debug("got diff %s", delta_text) - - delta = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_id, - function=DeltaFunctionCall(arguments=delta_text).model_dump( + index=i, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall(name=name).model_dump( exclude_none=True ), ) - ] + ) + + # Stream back new tool args by diffing against what was sent. + args_diff = self._compute_args_diff(i, tc_json, is_complete) + if args_diff: + tool_call_deltas.append( + DeltaToolCall( + index=i, + function=DeltaFunctionCall(arguments=args_diff).model_dump( + exclude_none=True + ), + ) + ) + + if content or tool_call_deltas: + return DeltaMessage( + content=content, + tool_calls=tool_call_deltas, ) - self.streamed_args_for_tool[self.current_tool_id] += delta_text - # handle saving the state for the current tool into - # the "prev" list for use in diffing for the next iteration - assert isinstance(current_tool_call, dict) - if self.current_tool_id == len(self.prev_tool_call_arr) - 1: - self.prev_tool_call_arr[self.current_tool_id] = current_tool_call - else: - self.prev_tool_call_arr.append(current_tool_call) - - return delta + return None except Exception: logger.exception("Error trying to handle streaming tool call.") - return None # do not stream a delta. skip this token ID. + return None diff --git a/vllm/tool_parsers/longcat_tool_parser.py b/vllm/tool_parsers/longcat_tool_parser.py index 0304f452e..ccfa94167 100644 --- a/vllm/tool_parsers/longcat_tool_parser.py +++ b/vllm/tool_parsers/longcat_tool_parser.py @@ -16,23 +16,7 @@ class LongcatFlashToolParser(Hermes2ProToolParser): self.tool_call_end_token: str = "" self.tool_call_regex = re.compile( - r"(.*?)|(.*)", + r"(.*?)" + r"|(.*)", re.DOTALL, ) - - self.tool_call_start_token_ids = self.model_tokenizer.encode( - self.tool_call_start_token, add_special_tokens=False - ) - self.tool_call_end_token_ids = self.model_tokenizer.encode( - self.tool_call_end_token, add_special_tokens=False - ) - - self.tool_call_start_token_array = [ - self.model_tokenizer.decode([token_id]) - for token_id in self.tool_call_start_token_ids - ] - - self.tool_call_end_token_array = [ - self.model_tokenizer.decode([token_id]) - for token_id in self.tool_call_end_token_ids - ]