diff --git a/tests/tool_parsers/test_deepseekv32_tool_parser.py b/tests/tool_parsers/test_deepseekv32_tool_parser.py new file mode 100644 index 000000000..14462da5b --- /dev/null +++ b/tests/tool_parsers/test_deepseekv32_tool_parser.py @@ -0,0 +1,476 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Unit tests for DeepSeekV32ToolParser. + +These tests use a minimal mock tokenizer so no real model weights are required. +""" + +import json +from unittest.mock import MagicMock + +import pytest + +from vllm.tool_parsers.deepseekv32_tool_parser import DeepSeekV32ToolParser + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +# Token IDs are not used by the V32 parser logic, so we only need the +# tokenizer object to be truthy (the parser checks `if not self.model_tokenizer`). +MOCK_TOKENIZER = MagicMock() +MOCK_TOKENIZER.get_vocab.return_value = {} + + +def make_parser() -> DeepSeekV32ToolParser: + return DeepSeekV32ToolParser(MOCK_TOKENIZER) + + +def make_tool_param(name: str, params: dict) -> MagicMock: + """Build a mock tool matching the ChatCompletionToolsParam shape.""" + tool = MagicMock() + tool.function.name = name + tool.function.parameters = params + return tool + + +def make_request(tools=None) -> MagicMock: + req = MagicMock() + req.tools = tools + return req + + +# Shorthand for the DSML tokens used throughout +FC_START = "<|DSML|function_calls>" +FC_END = "" +INV_START = '<|DSML|invoke name="' +INV_END = "" +PARAM_START = '<|DSML|parameter name="' +PARAM_END = "" + + +def build_tool_call(func_name: str, params: dict[str, str]) -> str: + """Build a complete model-output tool call string.""" + param_strs = "".join( + f'{PARAM_START}{k}" string="true">{v}{PARAM_END}' for k, v in params.items() + ) + return f'{FC_START}\n{INV_START}{func_name}">\n{param_strs}\n{INV_END}\n{FC_END}' + + +# --------------------------------------------------------------------------- +# Tests: DeepSeekV32ToolParser._convert_param_value +# --------------------------------------------------------------------------- + + +class TestConvertParamValue: + @pytest.fixture + def parser(self): + return make_parser() + + def test_null(self, parser): + assert parser._convert_param_value("null", "string") is None + assert parser._convert_param_value("NULL", "integer") is None + + def test_string(self, parser): + assert parser._convert_param_value("hello", "string") == "hello" + + def test_integer_valid(self, parser): + assert parser._convert_param_value("42", "integer") == 42 + + def test_integer_invalid_falls_back_to_str(self, parser): + assert parser._convert_param_value("abc", "int") == "abc" + + def test_number_float(self, parser): + assert parser._convert_param_value("3.14", "number") == pytest.approx(3.14) + + def test_number_whole_returns_int(self, parser): + assert parser._convert_param_value("5.0", "number") == 5 + assert isinstance(parser._convert_param_value("5.0", "number"), int) + + def test_boolean_true(self, parser): + assert parser._convert_param_value("true", "boolean") is True + assert parser._convert_param_value("1", "bool") is True + + def test_boolean_false(self, parser): + assert parser._convert_param_value("false", "boolean") is False + assert parser._convert_param_value("False", "bool") is False + + def test_object_valid_json(self, parser): + assert parser._convert_param_value('{"k": 1}', "object") == {"k": 1} + + def test_object_invalid_json_falls_back(self, parser): + assert parser._convert_param_value("not-json", "object") == "not-json" + + def test_array_valid_json(self, parser): + assert parser._convert_param_value("[1, 2]", "array") == [1, 2] + + def test_unknown_type_tries_json_then_string(self, parser): + assert parser._convert_param_value("123", "unknown") == 123 + assert parser._convert_param_value("hello", "unknown") == "hello" + + +# --------------------------------------------------------------------------- +# Tests: extract_tool_calls (non-streaming) +# --------------------------------------------------------------------------- + + +class TestExtractToolCalls: + @pytest.fixture + def parser(self): + return make_parser() + + def test_no_tool_call(self, parser): + result = parser.extract_tool_calls("just some text", None) + assert not result.tools_called + assert result.tool_calls == [] + assert result.content == "just some text" + + def test_single_tool_no_params(self, parser): + model_output = f'{FC_START}\n{INV_START}get_time">\n{INV_END}\n{FC_END}' + result = parser.extract_tool_calls(model_output, None) + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].function.name == "get_time" + assert json.loads(result.tool_calls[0].function.arguments) == {} + + def test_single_tool_with_params(self, parser): + model_output = build_tool_call( + "get_weather", {"location": "SF", "date": "2024-01-16"} + ) + result = parser.extract_tool_calls(model_output, None) + assert result.tools_called + assert len(result.tool_calls) == 1 + tc = result.tool_calls[0] + assert tc.function.name == "get_weather" + assert json.loads(tc.function.arguments) == { + "location": "SF", + "date": "2024-01-16", + } + + def test_content_before_tool_call(self, parser): + model_output = "Sure, let me check! " + build_tool_call( + "get_weather", {"location": "NYC"} + ) + result = parser.extract_tool_calls(model_output, None) + assert result.tools_called + assert result.content == "Sure, let me check! " + + def test_no_content_prefix_returns_none(self, parser): + model_output = build_tool_call("get_weather", {"location": "NYC"}) + result = parser.extract_tool_calls(model_output, None) + assert result.tools_called + assert result.content is None + + def test_multiple_tools(self, parser): + model_output = ( + f"{FC_START}\n" + f'{INV_START}get_weather">\n' + f'{PARAM_START}location" string="true">SF{PARAM_END}\n' + f"{INV_END}\n" + f'{INV_START}get_weather">\n' + f'{PARAM_START}location" string="true">NYC{PARAM_END}\n' + f"{INV_END}\n" + f"{FC_END}" + ) + result = parser.extract_tool_calls(model_output, None) + assert result.tools_called + assert len(result.tool_calls) == 2 + assert json.loads(result.tool_calls[0].function.arguments) == {"location": "SF"} + assert json.loads(result.tool_calls[1].function.arguments) == { + "location": "NYC" + } + + +# --------------------------------------------------------------------------- +# Tests: extract_tool_calls_streaming +# --------------------------------------------------------------------------- + + +class TestExtractToolCallsStreaming: + """Simulate character-by-character streaming and verify reconstructed args.""" + + @pytest.fixture + def parser(self): + return make_parser() + + def _stream(self, parser, full_text: str, request=None): + """Drive the parser line-by-line and collect non-None deltas. + + Real tokenizers emit multi-character chunks, not individual characters. + Streaming character-by-character would never deliver the full sentinel + token (e.g. '|DSML|') in a single delta, so we split on newlines to + ensure each sentinel always lands in one chunk. + """ + if request is None: + request = make_request() + # Split into lines, preserving the trailing newline in each chunk. + chunks: list[str] = [] + remaining = full_text + while remaining: + nl = remaining.find("\n") + if nl == -1: + chunks.append(remaining) + break + chunks.append(remaining[: nl + 1]) + remaining = remaining[nl + 1 :] + + deltas = [] + prev = "" + for chunk in chunks: + curr = prev + chunk + result = parser.extract_tool_calls_streaming( + previous_text=prev, + current_text=curr, + delta_text=chunk, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[1], + request=request, + ) + prev = curr + if result is not None: + deltas.append(result) + return deltas + + def _reconstruct_args(self, deltas, tool_index=0) -> str: + """Concatenate all argument fragments for a given tool index.""" + fragments = [] + for d in deltas: + if d.tool_calls: + for tc in d.tool_calls: + if tc.index == tool_index and tc.function and tc.function.arguments: + fragments.append(tc.function.arguments) + return "".join(fragments) + + def test_plain_content_no_tool(self, parser): + full_text = "Hello, world!" + deltas = self._stream(parser, full_text) + content = "".join(d.content for d in deltas if d.content is not None) + assert "Hello, world!" in content + assert all(not d.tool_calls for d in deltas) + + def test_single_tool_streaming(self, parser): + full_text = build_tool_call("get_weather", {"location": "SF"}) + deltas = self._stream(parser, full_text) + args_str = self._reconstruct_args(deltas) + assert json.loads(args_str) == {"location": "SF"} + + def test_tool_name_emitted(self, parser): + full_text = build_tool_call("my_func", {"x": "1"}) + deltas = self._stream(parser, full_text) + func_names = [ + tc.function.name + for d in deltas + if d.tool_calls + for tc in d.tool_calls + if tc.function and tc.function.name + ] + assert any("my_func" in n for n in func_names) + + def test_content_before_tool_call_streaming(self, parser): + full_text = "Thinking... " + build_tool_call("fn", {"a": "b"}) + deltas = self._stream(parser, full_text) + content = "".join(d.content for d in deltas if d.content is not None) + assert "Thinking" in content + + def test_type_conversion_in_streaming(self, parser): + tool = make_tool_param( + "add", + { + "type": "object", + "properties": { + "x": {"type": "integer"}, + "y": {"type": "integer"}, + }, + }, + ) + request = make_request(tools=[tool]) + full_text = build_tool_call("add", {"x": "3", "y": "4"}) + deltas = self._stream(parser, full_text, request=request) + args_str = self._reconstruct_args(deltas) + assert json.loads(args_str) == {"x": 3, "y": 4} + + def test_multiple_tools_streaming(self, parser): + full_text = ( + f"{FC_START}\n" + f'{INV_START}func_a">\n' + f'{PARAM_START}p" string="true">v1{PARAM_END}\n' + f"{INV_END}\n" + f'{INV_START}func_b">\n' + f'{PARAM_START}q" string="true">v2{PARAM_END}\n' + f"{INV_END}\n" + f"{FC_END}" + ) + deltas = self._stream(parser, full_text) + + # Collect function names by index + names_by_index: dict[int, str] = {} + for d in deltas: + if d.tool_calls: + for tc in d.tool_calls: + if tc.function and tc.function.name: + names_by_index[tc.index] = tc.function.name + + assert names_by_index.get(0) == "func_a" + assert names_by_index.get(1) == "func_b" + + assert json.loads(self._reconstruct_args(deltas, tool_index=0)) == {"p": "v1"} + assert json.loads(self._reconstruct_args(deltas, tool_index=1)) == {"q": "v2"} + + def test_state_reset_on_new_stream(self, parser): + """A second stream (previous_text == '') must reset state cleanly.""" + full_text = build_tool_call("fn", {"k": "v"}) + # First stream + self._stream(parser, full_text) + # Second stream - should produce identical results + deltas2 = self._stream(parser, full_text) + assert json.loads(self._reconstruct_args(deltas2)) == {"k": "v"} + + def test_empty_arguments_streaming(self, parser): + """Invoke block with zero parameters should produce empty JSON.""" + full_text = f'{FC_START}\n{INV_START}get_time">\n{INV_END}\n{FC_END}' + deltas = self._stream(parser, full_text) + args_str = self._reconstruct_args(deltas) + assert json.loads(args_str) == {} + + def test_unique_tool_call_ids(self, parser): + """Each tool call in a parallel stream must get a distinct id.""" + full_text = ( + f"{FC_START}\n" + f'{INV_START}fn_a">\n' + f'{PARAM_START}x" string="true">1{PARAM_END}\n' + f"{INV_END}\n" + f'{INV_START}fn_b">\n' + f'{PARAM_START}y" string="true">2{PARAM_END}\n' + f"{INV_END}\n" + f"{FC_END}" + ) + deltas = self._stream(parser, full_text) + ids = [ + tc.id + for d in deltas + if d.tool_calls + for tc in d.tool_calls + if tc.id is not None + ] + assert len(ids) == 2 + assert ids[0] != ids[1] + + def test_eos_after_tool_calls(self, parser): + """EOS token (empty delta_text, non-empty delta_token_ids) returns + a non-None DeltaMessage so the serving framework can finalize.""" + full_text = build_tool_call("fn", {"k": "v"}) + # Drive through the full text first + deltas = self._stream(parser, full_text) + assert any(d.tool_calls for d in deltas) + # Now simulate EOS: empty delta_text, but token ids present + prev = full_text + result = parser.extract_tool_calls_streaming( + previous_text=prev, + current_text=prev, + delta_text="", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[2], # EOS token id + request=make_request(), + ) + assert result is not None + + def test_streaming_matches_non_streaming(self, parser): + """Streaming and non-streaming must produce the same result.""" + full_text = build_tool_call( + "get_weather", {"location": "SF", "date": "2024-01-16"} + ) + # Non-streaming + non_stream = parser.extract_tool_calls(full_text, None) + assert non_stream.tools_called + ns_name = non_stream.tool_calls[0].function.name + ns_args = json.loads(non_stream.tool_calls[0].function.arguments) + # Streaming + deltas = self._stream(parser, full_text) + s_names = [ + tc.function.name + for d in deltas + if d.tool_calls + for tc in d.tool_calls + if tc.function and tc.function.name + ] + s_args = json.loads(self._reconstruct_args(deltas)) + assert s_names[0] == ns_name + assert s_args == ns_args + + def _stream_chunked(self, parser, full_text: str, chunk_size: int, request=None): + """Drive the parser with fixed-size chunks (simulates stream interval). + + Unlike ``_stream`` which splits on newlines, this splits the text + into ``chunk_size``-character pieces so the start token can be + split across chunks — exactly what happens with stream interval > 1. + """ + if request is None: + request = make_request() + chunks = [ + full_text[i : i + chunk_size] for i in range(0, len(full_text), chunk_size) + ] + deltas = [] + prev = "" + for chunk in chunks: + curr = prev + chunk + result = parser.extract_tool_calls_streaming( + previous_text=prev, + current_text=curr, + delta_text=chunk, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[1], + request=request, + ) + prev = curr + if result is not None: + deltas.append(result) + return deltas + + def test_single_tool_chunked_stream_interval(self, parser): + """Start token split across chunks (stream interval > 1).""" + full_text = build_tool_call("get_weather", {"location": "SF"}) + # Use a chunk size that splits the start token + deltas = self._stream_chunked(parser, full_text, chunk_size=5) + args_str = self._reconstruct_args(deltas) + assert json.loads(args_str) == {"location": "SF"} + + def test_content_before_tool_chunked(self, parser): + """Content before tool call with chunked streaming.""" + full_text = "Thinking... " + build_tool_call("fn", {"a": "b"}) + deltas = self._stream_chunked(parser, full_text, chunk_size=7) + content = "".join(d.content for d in deltas if d.content is not None) + assert "Thinking" in content + args_str = self._reconstruct_args(deltas) + assert json.loads(args_str) == {"a": "b"} + + def test_multiple_tools_chunked(self, parser): + """Multiple tools with chunked streaming.""" + full_text = ( + f"{FC_START}\n" + f'{INV_START}func_a">\n' + f'{PARAM_START}p" string="true">v1{PARAM_END}\n' + f"{INV_END}\n" + f'{INV_START}func_b">\n' + f'{PARAM_START}q" string="true">v2{PARAM_END}\n' + f"{INV_END}\n" + f"{FC_END}" + ) + deltas = self._stream_chunked(parser, full_text, chunk_size=10) + assert json.loads(self._reconstruct_args(deltas, tool_index=0)) == {"p": "v1"} + assert json.loads(self._reconstruct_args(deltas, tool_index=1)) == {"q": "v2"} + + def test_no_emission_while_incomplete(self, parser): + """No tool calls should be emitted until an invoke block completes.""" + # Stream only a partial invoke (no closing tag) + partial_text = ( + f"{FC_START}\n" + f'{INV_START}fn">\n' + f'{PARAM_START}k" string="true">val{PARAM_END}\n' + ) + deltas = self._stream(parser, partial_text) + # Should have no tool call deltas yet + assert all(not d.tool_calls for d in deltas) diff --git a/vllm/tool_parsers/deepseekv32_tool_parser.py b/vllm/tool_parsers/deepseekv32_tool_parser.py index 30e23ed9f..cb39a16fd 100644 --- a/vllm/tool_parsers/deepseekv32_tool_parser.py +++ b/vllm/tool_parsers/deepseekv32_tool_parser.py @@ -48,41 +48,12 @@ class DeepSeekV32ToolParser(ToolParser): self.prev_tool_call_arr: list[dict] = [] - # Sentinel tokens - self.dsml_token: str = "|DSML|" - self.dsml_start_check: str = "<" + self.dsml_token + # Sentinel token self.tool_call_start_token: str = "<|DSML|function_calls>" - self.tool_call_end_token: str = "" - self.invoke_start_prefix: str = "<|DSML|invoke name=" - self.invoke_end_token: str = "" - self.parameter_prefix: str = "<|DSML|parameter name=" - self.parameter_end_token: str = "" - # Streaming state variables - self.current_tool_name_sent: bool = False - # Override base class type - we use string IDs for tool calls - self.current_tool_id: str | None = None # type: ignore - self.streamed_args_for_tool: list[str] = [] + # Streaming state self.is_tool_call_started: bool = False - self.failed_count: int = 0 - - # Initialize streaming state variables self.current_tool_index: int = 0 - self.invoke_index: int = 0 - self.header_sent: bool = False - self.current_function_name: str | None = None - self.current_param_name: str | None = None - self.current_param_value: str = "" - self.param_count: int = 0 - self.in_param: bool = False - self.in_function: bool = False - self.json_started: bool = False - self.json_closed: bool = False - self.accumulated_params: dict = {} - self.streaming_request: ChatCompletionRequest | None = None - - # Enhanced streaming state - reset for each new message - self._reset_streaming_state() # Regex patterns for complete parsing self.tool_call_complete_regex = re.compile( @@ -106,10 +77,6 @@ class DeepSeekV32ToolParser(ToolParser): "vLLM Successfully import tool parser %s !", self.__class__.__name__ ) - def _generate_tool_call_id(self) -> str: - """Generate a unique tool call ID.""" - return f"call_{uuid.uuid4().hex[:24]}" - def adjust_request(self, request): request = super().adjust_request(request) if request.tools and request.tool_choice != "none": @@ -122,33 +89,77 @@ class DeepSeekV32ToolParser(ToolParser): request.skip_special_tokens = False return request - def _reset_streaming_state(self): - """Reset all streaming state.""" - self.current_tool_index = 0 - self.invoke_index = 0 - self.is_tool_call_started = False - self.header_sent = False - self.current_tool_id = None - self.current_function_name = None - self.current_param_name = None - self.current_param_value = "" - self.param_count = 0 - self.in_param = False - self.in_function = False - self.json_started = False - self.json_closed = False - # Store accumulated parameters for type conversion - self.accumulated_params = {} - self.streaming_request = None - # Clear previous tool call history to avoid state pollution - self.prev_tool_call_arr.clear() + def _generate_tool_call_id(self) -> str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:24]}" - def _parse_invoke_params(self, invoke_str: str) -> dict | None: + def _parse_invoke_params(self, invoke_str: str) -> dict: param_dict = dict() for param_name, param_val in self.parameter_complete_regex.findall(invoke_str): param_dict[param_name] = param_val return param_dict + def _convert_param_value(self, value: str, param_type: str) -> Any: + """Convert parameter value to the correct type.""" + if value.lower() == "null": + return None + + param_type = param_type.lower() + if param_type in ["string", "str", "text"]: + return value + elif param_type in ["integer", "int"]: + try: + return int(value) + except (ValueError, TypeError): + return value + elif param_type in ["number", "float"]: + try: + val = float(value) + return val if val != int(val) else int(val) + except (ValueError, TypeError): + return value + elif param_type in ["boolean", "bool"]: + return value.lower() in ["true", "1"] + elif param_type in ["object", "array"]: + try: + return json.loads(value) + except json.JSONDecodeError: + return value + else: + # Try JSON parse first, fallback to string + try: + return json.loads(value) + except json.JSONDecodeError: + return value + + def _convert_params_with_schema( + self, + function_name: str, + param_dict: dict[str, str], + request: ChatCompletionRequest | None, + ) -> dict[str, Any]: + """Convert raw string param values using the tool schema types.""" + param_config: dict = {} + if request and request.tools: + for tool in request.tools: + if ( + hasattr(tool, "function") + and tool.function.name == function_name + and hasattr(tool.function, "parameters") + ): + schema = tool.function.parameters + if isinstance(schema, dict) and "properties" in schema: + param_config = schema["properties"] + break + + converted: dict[str, Any] = {} + for name, value in param_dict.items(): + param_type = "string" + if name in param_config and isinstance(param_config[name], dict): + param_type = param_config[name].get("type", "string") + converted[name] = self._convert_param_value(value, param_type) + return converted + def extract_tool_calls( self, model_output: str, @@ -200,56 +211,55 @@ class DeepSeekV32ToolParser(ToolParser): tools_called=False, tool_calls=[], content=model_output ) - def _extract_name(self, name_str: str) -> str: - """Extract name from quoted string.""" - name_str = name_str.strip() - if ( - name_str.startswith('"') - and name_str.endswith('"') - or name_str.startswith("'") - and name_str.endswith("'") - ): - return name_str[1:-1] - return name_str + def _reset_streaming_state(self): + """Reset all streaming state.""" + self.current_tool_index = 0 + self.is_tool_call_started = False + self.prev_tool_call_arr.clear() + self.streamed_args_for_tool.clear() - def _extract_param_name(self, input_str: str) -> str: - """Extract param name""" - start = input_str.find('"') + 1 - end = input_str.find('"', start) - return input_str[start:end] if start > 0 and end > start else input_str + def _extract_delta_tool_calls( + self, + current_text: str, + request: ChatCompletionRequest | None, + ) -> list[DeltaToolCall]: + """Extract DeltaToolCalls from newly completed blocks. - def _convert_param_value(self, value: str, param_type: str) -> Any: - """Convert parameter value to the correct type.""" - if value.lower() == "null": - return None + Tracks progress via ``current_tool_index`` so each block is + extracted exactly once across successive streaming calls. + """ + complete_invokes = self.invoke_complete_regex.findall(current_text) + delta_tool_calls: list[DeltaToolCall] = [] - param_type = param_type.lower() - if param_type in ["string", "str", "text"]: - return value - elif param_type in ["integer", "int"]: - try: - return int(value) - except (ValueError, TypeError): - return value - elif param_type in ["number", "float"]: - try: - val = float(value) - return val if val != int(val) else int(val) - except (ValueError, TypeError): - return value - elif param_type in ["boolean", "bool"]: - return value.lower() in ["true", "1"] - elif param_type in ["object", "array"]: - try: - return json.loads(value) - except json.JSONDecodeError: - return value - else: - # Try JSON parse first, fallback to string - try: - return json.loads(value) - except json.JSONDecodeError: - return value + while len(complete_invokes) > self.current_tool_index: + invoke_name, invoke_body = complete_invokes[self.current_tool_index] + param_dict = self._parse_invoke_params(invoke_body) + + converted = self._convert_params_with_schema( + invoke_name, param_dict, request + ) + args_json = json.dumps(converted, ensure_ascii=False) + idx = self.current_tool_index + self.current_tool_index += 1 + + self.prev_tool_call_arr.append( + {"name": invoke_name, "arguments": converted} + ) + self.streamed_args_for_tool.append(args_json) + + delta_tool_calls.append( + DeltaToolCall( + index=idx, + id=self._generate_tool_call_id(), + function=DeltaFunctionCall( + name=invoke_name, + arguments=args_json, + ), + type="function", + ) + ) + + return delta_tool_calls def extract_tool_calls_streaming( self, @@ -261,345 +271,44 @@ class DeepSeekV32ToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> DeltaMessage | None: - """Extract tool calls from streaming model output.""" + """Extract tool calls from streaming model output. - # Store request for type conversion + Uses a buffer-until-complete-invoke strategy: tokens are buffered + until a complete invoke block is available, then parsed and emitted + in one shot. + """ + + # First chunk of a new stream — reset state from prior request. if not previous_text: self._reset_streaming_state() - self.streaming_request = request - # If no delta text, return None unless it's an EOS token after tools - if not delta_text: - # Check if this is an EOS token after all tool calls are complete - if delta_token_ids: - # Count complete tool calls - complete_calls = len( - self.tool_call_complete_regex.findall(current_text) - ) - - # If we have completed tool calls and populated prev_tool_call_arr - if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: - # Check if all tool calls are closed - open_calls = current_text.count( - self.tool_call_start_token - ) - current_text.count(self.tool_call_end_token) - if open_calls == 0: - # Return empty delta for finish_reason processing - return DeltaMessage(content="") - elif not self.is_tool_call_started and current_text: - # This is a regular content response that's now complete - return DeltaMessage(content="") - return None - - # Check if we need to advance to next tool - if self.json_closed and not self.in_function: - # Check if this tool call has ended - invoke_ends = current_text.count(self.invoke_end_token) - if invoke_ends > self.current_tool_index: - # This tool has ended, advance to next - self.current_tool_index += 1 - self.header_sent = False - self.param_count = 0 - self.json_started = False - self.json_closed = False - self.in_function = False # Now we can safely set this to False - self.accumulated_params = {} - # Continue processing next tool - return None - - # Handle normal content before tool calls - if not self.is_tool_call_started: - # Check if tool call is starting - if self.dsml_token in current_text: - self.is_tool_call_started = True - # Return any content before the tool call - if self.dsml_start_check in delta_text: - content_before = delta_text[ - : delta_text.index(self.dsml_start_check) - ] - if content_before: - return DeltaMessage(content=content_before) - return None - else: - # Check if we're between tool calls - skip whitespace - if ( - current_text.rstrip().endswith(self.tool_call_end_token) - and delta_text.strip() == "" - ): - # We just ended a tool call, skip whitespace - return None - # Normal content, no tool call - if delta_text.endswith("<"): - return DeltaMessage(content=delta_text[:-1]) - if previous_text and previous_text.endswith("<"): - return DeltaMessage(content="<" + delta_text) - return DeltaMessage(content=delta_text) - - # Check if we're between tool calls (waiting for next one) - invoke_starts_count = current_text.count(self.invoke_start_prefix) - if self.current_tool_index >= invoke_starts_count: - # We're past all tool calls, shouldn't be here - return None - - # Find the current tool call portion - invoke_start_positions: list[int] = [] - idx = 0 - while True: - idx = current_text.find(self.invoke_start_prefix, idx) - if idx == -1: - break - invoke_start_positions.append(idx) - idx += len(self.invoke_start_prefix) - - if self.current_tool_index >= len(invoke_start_positions): - # No more tool calls to process yet - return None - - invoke_start_idx = invoke_start_positions[self.current_tool_index] - # Find where this tool call ends (or current position if not ended yet) - invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx) - if invoke_end_idx == -1: - tool_text = current_text[invoke_start_idx:] + # Detect whether we've entered the tool-call region. + # Use current_text (not delta_text) since the start token may + # be split across chunks. + content_before = None + if self.is_tool_call_started: + pass + elif self.tool_call_start_token in current_text: + # Tool-call region found, capture any plain text before it. + self.is_tool_call_started = True + start_idx = current_text.index(self.tool_call_start_token) + content_before = current_text[len(previous_text) : start_idx] or None else: - tool_text = current_text[ - invoke_start_idx : invoke_end_idx + len(self.invoke_end_token) - ] + # Still in plain-text region, forward as content. + return DeltaMessage(content=delta_text) if delta_text else None - # Looking for function header - if not self.header_sent: - if self.invoke_start_prefix in tool_text: - func_start = tool_text.find(self.invoke_start_prefix) + len( - self.invoke_start_prefix - ) - # Find the end quote for the function name - func_end = tool_text.find(">", func_start) + # Inside tool-call region: emit any newly completed invokes. + delta_tool_calls = self._extract_delta_tool_calls(current_text, request) - if func_end != -1: - # Found complete function name - function_name_raw = tool_text[func_start:func_end] - self.current_function_name = self._extract_name(function_name_raw) - self.current_tool_id = self._generate_tool_call_id() - self.header_sent = True - self.in_function = True + if delta_tool_calls or content_before: + return DeltaMessage( + content=content_before, + tool_calls=delta_tool_calls, + ) - # Add to prev_tool_call_arr immediately when we detect a tool call - # Each tool call should be recorded regardless of function name - # Ensure we don't add the same tool call index multiple times - if len(self.prev_tool_call_arr) <= self.current_tool_index: - self.prev_tool_call_arr.append( - { - "name": self.current_function_name, - "arguments": "{}", # Placeholder, will be updated later - } - ) - - # Send header with function info - return DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - id=self.current_tool_id, - function=DeltaFunctionCall( - name=self.current_function_name, arguments="" - ), - type="function", - ) - ] - ) - return None - - # We've sent header, now handle function body - if self.in_function: - # Send opening brace if not sent yet - if self.in_function and not self.json_started: - self.json_started = True - return DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="{"), - ) - ] - ) - - # Make sure json_started is set if we're processing parameters - if not self.json_started: - self.json_started = True - - # Check for function end in accumulated text - if not self.json_closed and self.invoke_end_token in tool_text: - # Count total parameters in the tool text - total_param_count = tool_text.count(self.parameter_prefix) - - # Only close JSON if all parameters have been processed - if self.param_count >= total_param_count: - # Close JSON - self.json_closed = True - - # Extract complete tool call - # Find the invoke content - invoke_start = tool_text.find(self.invoke_start_prefix) + len( - self.invoke_start_prefix - ) - invoke_content_end = tool_text.find( - self.invoke_end_token, invoke_start - ) - if invoke_content_end != -1: - invoke_content = tool_text[invoke_start:invoke_content_end] - # Parse to get the complete arguments - try: - invoke_params = self._parse_invoke_params(invoke_content) - if invoke_params and self.current_tool_index < len( - self.prev_tool_call_arr - ): - # Update existing entry in prev_tool_call_arr - self.prev_tool_call_arr[self.current_tool_index][ - "arguments" - ] = json.dumps(invoke_params, ensure_ascii=False) - except Exception: - pass # Ignore parsing errors during streaming - - result = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="}"), - ) - ] - ) - - # Reset state for next tool - self.json_closed = True - self.in_function = False - self.accumulated_params = {} - - logger.debug("[M2_STREAMING] Tool call completed") - - return result - else: - # Don't close JSON yet, continue processing parameters - return None - - # Look for parameters - # Find all parameter starts - param_starts = [] - idx = 0 - while True: - idx = tool_text.find(self.parameter_prefix, idx) - if idx == -1: - break - param_starts.append(idx) - idx += len(self.parameter_prefix) - - # Check if we should start a new parameter - if ( - not self.in_param - and self.param_count < len(param_starts) - and len(param_starts) > self.param_count - ): - # Process the next parameter - param_idx = param_starts[self.param_count] - param_start = param_idx + len(self.parameter_prefix) - remaining = tool_text[param_start:] - - if ">" in remaining: - # We have the complete parameter name - name_end = remaining.find(">") - param_name_raw = remaining[:name_end] - self.current_param_name = self._extract_param_name(param_name_raw) - - # Find the parameter value - value_start = param_start + name_end + 1 - value_text = tool_text[value_start:] - if value_text.startswith("\n"): - value_text = value_text[1:] - - # Find where this parameter ends - param_end_idx = value_text.find(self.parameter_end_token) - if param_end_idx == -1: - # No closing tag, look for next parameter or function end - next_param_idx = value_text.find(self.parameter_prefix) - func_end_idx = value_text.find(self.invoke_end_token) - - if next_param_idx != -1 and ( - func_end_idx == -1 or next_param_idx < func_end_idx - ): - param_end_idx = next_param_idx - elif func_end_idx != -1: - param_end_idx = func_end_idx - else: - # Neither found, check if tool call is complete - if self.invoke_end_token in tool_text: - # Tool call and parameter is complete - param_end_idx = len(value_text) - else: - # Still streaming, wait for more content - return None - - if param_end_idx != -1: - # Complete parameter found - param_value = value_text[:param_end_idx] - if param_value.endswith("\n"): - param_value = param_value[:-1] - - # Store raw value for later processing - self.accumulated_params[self.current_param_name] = param_value - - # Get parameter configuration for type conversion - param_config = {} - if self.streaming_request and self.streaming_request.tools: - for tool in self.streaming_request.tools: - if ( - hasattr(tool, "function") - and tool.function.name == self.current_function_name - and hasattr(tool.function, "parameters") - ): - params = tool.function.parameters - if ( - isinstance(params, dict) - and "properties" in params - ): - param_config = params["properties"] - break - - # Get parameter type - param_type = "string" - if ( - self.current_param_name in param_config - and isinstance(param_config[self.current_param_name], dict) - and "type" in param_config[self.current_param_name] - ): - param_type = param_config[self.current_param_name]["type"] - - # Convert param value to appropriate type - converted_value = self._convert_param_value( - param_value, param_type - ) - - # Build JSON fragment based on the converted type - # Use json.dumps to properly serialize the value - serialized_value = json.dumps( - converted_value, ensure_ascii=False - ) - - if self.param_count == 0: - json_fragment = ( - f'"{self.current_param_name}": {serialized_value}' - ) - else: - json_fragment = ( - f', "{self.current_param_name}": {serialized_value}' - ) - - self.param_count += 1 - - return DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments=json_fragment), - ) - ] - ) + # Empty delta with token ids means EOS or closing tag; return + # non-None so the serving framework can finalize finish_reason. + if not delta_text and delta_token_ids and self.prev_tool_call_arr: + return DeltaMessage(content="") return None