diff --git a/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py index 99ab1e497..f29f79f72 100644 --- a/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_gigachat3_tool_parser.py @@ -13,6 +13,13 @@ from vllm.entrypoints.openai.engine.protocol import FunctionCall from vllm.tokenizers import TokenizerLike from vllm.tool_parsers import ToolParser, ToolParserManager +MSG_SEP_TOKEN = "<|message_sep|>\n\n" +ROLE_SEP_TOKEN = "<|role_sep|>\n" +EOS_TOKEN = "" +TOOL_HEADER_GIGACHAT3 = f"function call{ROLE_SEP_TOKEN}" +TOOL_HEADER_GIGACHAT31 = "<|function_call|>" + + SIMPLE_ARGS_DICT = { "action": "create", "id": "preferences", @@ -24,7 +31,10 @@ SIMPLE_FUNCTION_JSON = json.dumps( }, ensure_ascii=False, ) -SIMPLE_FUNCTION_OUTPUT = "function call" + SIMPLE_FUNCTION_JSON +SIMPLE_FUNCTION_OUTPUT_GIGACHAT3 = ( + f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{SIMPLE_FUNCTION_JSON}" +) +SIMPLE_FUNCTION_OUTPUT_GIGACHAT31 = f"{TOOL_HEADER_GIGACHAT31}{SIMPLE_FUNCTION_JSON}" SIMPLE_FUNCTION_CALL = FunctionCall( name="manage_user_memory", arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False), @@ -38,7 +48,12 @@ PARAMETERLESS_FUNCTION_JSON = json.dumps( }, ensure_ascii=False, ) -PARAMETERLESS_FUNCTION_OUTPUT = "function call" + PARAMETERLESS_FUNCTION_JSON +PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3 = ( + f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{PARAMETERLESS_FUNCTION_JSON}" +) +PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31 = ( + f"{TOOL_HEADER_GIGACHAT31}{PARAMETERLESS_FUNCTION_JSON}" +) PARAMETERLESS_FUNCTION_CALL = FunctionCall( name="manage_user_memory", arguments=json.dumps({}, ensure_ascii=False), @@ -62,17 +77,38 @@ COMPLEX_FUNCTION_JSON = json.dumps( }, ensure_ascii=False, ) -COMPLEX_FUNCTION_OUTPUT = "function call" + COMPLEX_FUNCTION_JSON +COMPLEX_FUNCTION_OUTPUT_GIGACHAT3 = ( + f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{COMPLEX_FUNCTION_JSON}" +) +COMPLEX_FUNCTION_OUTPUT_GIGACHAT31 = f"{TOOL_HEADER_GIGACHAT31}{COMPLEX_FUNCTION_JSON}" COMPLEX_FUNCTION_CALL = FunctionCall( name="manage_user_memory", arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False), ) +CONTENT_TEXT = "I'll check that for you." +MIXED_OUTPUT_GIGACHAT3 = f"{CONTENT_TEXT}{SIMPLE_FUNCTION_OUTPUT_GIGACHAT3}" +MIXED_OUTPUT_GIGACHAT31 = f"{CONTENT_TEXT}{SIMPLE_FUNCTION_OUTPUT_GIGACHAT31}" + + +@pytest.fixture(name="gigachat_tokenizer") +def fixture_gigachat_tokenizer(default_tokenizer: TokenizerLike): + default_tokenizer.add_tokens( + [ + MSG_SEP_TOKEN, + ROLE_SEP_TOKEN, + TOOL_HEADER_GIGACHAT31, + EOS_TOKEN, + ] + ) + return default_tokenizer + + @pytest.mark.parametrize("streaming", [True, False]) -def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike): +def test_no_tool_call(streaming: bool, gigachat_tokenizer: TokenizerLike): tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")( - default_tokenizer + gigachat_tokenizer ) model_output = "How can I help you today?" content, tool_calls = run_tool_extraction( @@ -85,45 +121,143 @@ def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike): TEST_CASES = [ pytest.param( True, - SIMPLE_FUNCTION_OUTPUT, + SIMPLE_FUNCTION_OUTPUT_GIGACHAT3, [SIMPLE_FUNCTION_CALL], None, - id="simple_streaming", + id="simple_streaming_gigachat3", ), pytest.param( False, - SIMPLE_FUNCTION_OUTPUT, + SIMPLE_FUNCTION_OUTPUT_GIGACHAT3, [SIMPLE_FUNCTION_CALL], None, - id="simple_nonstreaming", + id="simple_nonstreaming_gigachat3", ), pytest.param( True, - PARAMETERLESS_FUNCTION_OUTPUT, + PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3, [PARAMETERLESS_FUNCTION_CALL], None, - id="parameterless_streaming", + id="parameterless_streaming_gigachat3", ), pytest.param( False, - PARAMETERLESS_FUNCTION_OUTPUT, + PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3, [PARAMETERLESS_FUNCTION_CALL], None, - id="parameterless_nonstreaming", + id="parameterless_nonstreaming_gigachat3", ), pytest.param( True, - COMPLEX_FUNCTION_OUTPUT, + COMPLEX_FUNCTION_OUTPUT_GIGACHAT3, [COMPLEX_FUNCTION_CALL], None, - id="complex_streaming", + id="complex_streaming_gigachat3", ), pytest.param( False, - COMPLEX_FUNCTION_OUTPUT, + COMPLEX_FUNCTION_OUTPUT_GIGACHAT3, [COMPLEX_FUNCTION_CALL], None, - id="complex_nonstreaming", + id="complex_nonstreaming_gigachat3", + ), + pytest.param( + True, + MIXED_OUTPUT_GIGACHAT3, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_streaming_gigachat3", + ), + pytest.param( + False, + MIXED_OUTPUT_GIGACHAT3, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_nonstreaming_gigachat3", + ), + pytest.param( + True, + MIXED_OUTPUT_GIGACHAT3 + EOS_TOKEN, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_streaming_with_eos_gigachat3", + ), + pytest.param( + False, + MIXED_OUTPUT_GIGACHAT3 + EOS_TOKEN, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_nonstreaming_with_eos_gigachat3", + ), + pytest.param( + True, + SIMPLE_FUNCTION_OUTPUT_GIGACHAT31, + [SIMPLE_FUNCTION_CALL], + None, + id="simple_streaming_gigachat31", + ), + pytest.param( + False, + SIMPLE_FUNCTION_OUTPUT_GIGACHAT31, + [SIMPLE_FUNCTION_CALL], + None, + id="simple_nonstreaming_gigachat31", + ), + pytest.param( + True, + PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31, + [PARAMETERLESS_FUNCTION_CALL], + None, + id="parameterless_streaming_gigachat31", + ), + pytest.param( + False, + PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31, + [PARAMETERLESS_FUNCTION_CALL], + None, + id="parameterless_nonstreaming_gigachat31", + ), + pytest.param( + True, + COMPLEX_FUNCTION_OUTPUT_GIGACHAT31, + [COMPLEX_FUNCTION_CALL], + None, + id="complex_streaming_gigachat31", + ), + pytest.param( + False, + COMPLEX_FUNCTION_OUTPUT_GIGACHAT31, + [COMPLEX_FUNCTION_CALL], + None, + id="complex_nonstreaming_gigachat31", + ), + pytest.param( + True, + MIXED_OUTPUT_GIGACHAT31, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_streaming_gigachat31", + ), + pytest.param( + False, + MIXED_OUTPUT_GIGACHAT31, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_nonstreaming_gigachat31", + ), + pytest.param( + True, + MIXED_OUTPUT_GIGACHAT31 + EOS_TOKEN, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_streaming_with_eos_gigachat31", + ), + pytest.param( + False, + MIXED_OUTPUT_GIGACHAT31 + EOS_TOKEN, + [SIMPLE_FUNCTION_CALL], + CONTENT_TEXT, + id="mixed_content_nonstreaming_with_eos_gigachat31", ), ] @@ -136,14 +270,16 @@ def test_tool_call( model_output: str, expected_tool_calls: list[FunctionCall], expected_content: str | None, - default_tokenizer: TokenizerLike, + gigachat_tokenizer: TokenizerLike, ): tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")( - default_tokenizer + gigachat_tokenizer ) content, tool_calls = run_tool_extraction( tool_parser, model_output, streaming=streaming ) + if content == "": + content = None assert content == expected_content assert len(tool_calls) == len(expected_tool_calls) for actual, expected in zip(tool_calls, expected_tool_calls): @@ -154,15 +290,46 @@ def test_tool_call( assert actual_args == expected_args -def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike): +@pytest.mark.parametrize( + "model_output_deltas", + [ + pytest.param( + [ + CONTENT_TEXT[:3], + CONTENT_TEXT[3:5], + CONTENT_TEXT[5:], + MSG_SEP_TOKEN, + TOOL_HEADER_GIGACHAT3, + COMPLEX_FUNCTION_JSON[:40], + COMPLEX_FUNCTION_JSON[40:-1], + COMPLEX_FUNCTION_JSON[-1], + ], + id="gigachat3", + ), + pytest.param( + [ + CONTENT_TEXT[:3], + CONTENT_TEXT[3:5], + CONTENT_TEXT[5:], + TOOL_HEADER_GIGACHAT31, + COMPLEX_FUNCTION_JSON[:40], + COMPLEX_FUNCTION_JSON[40:-1], + COMPLEX_FUNCTION_JSON[-1], + ], + id="gigachat31", + ), + ], +) +def test_streaming_tool_call_with_large_steps( + model_output_deltas: list[str], + gigachat_tokenizer: TokenizerLike, +): + """ + Test that the closing braces are streamed correctly. + """ tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")( - default_tokenizer + gigachat_tokenizer ) - model_output_deltas = [ - "function call", - COMPLEX_FUNCTION_JSON[:40], - COMPLEX_FUNCTION_JSON[40:], - ] reconstructor = run_tool_extraction_streaming( tool_parser, model_output_deltas, diff --git a/vllm/tool_parsers/gigachat3_tool_parser.py b/vllm/tool_parsers/gigachat3_tool_parser.py index 02cdad9ed..90928f9ae 100644 --- a/vllm/tool_parsers/gigachat3_tool_parser.py +++ b/vllm/tool_parsers/gigachat3_tool_parser.py @@ -25,7 +25,12 @@ from vllm.tool_parsers.abstract_tool_parser import ToolParser logger = init_logger(__name__) REGEX_FUNCTION_CALL = re.compile( - r"function call(?:<\|role_sep\|>\n)?(\{.*)", + r"(?:function call<\|role_sep\|>\n|<\|function_call\|>)(.*)", + re.DOTALL, +) + +REGEX_CONTENT_PATTERN = re.compile( + r"^(.*?)(?:<\|message_sep\|>|<\|function_call\|>)", re.DOTALL, ) @@ -47,57 +52,67 @@ class GigaChat3ToolParser(ToolParser): self.tool_name_sent: bool = False self.tool_id: str | None = None self.prev_tool_call_arr: list[dict] = [] - self.content_buffer: str = "" - self.trigger_start = "function call{" + self.end_content: bool = False + self.streamed_args_for_tool: list[str] = [] + + def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest: + request = super().adjust_request(request) + if request.tools and request.tool_choice != "none": + request.skip_special_tokens = False + return request def extract_tool_calls( self, model_output: str, request: ChatCompletionRequest, ) -> ExtractedToolCallInformation: - match = REGEX_FUNCTION_CALL.search(model_output) - if not match: + function_call = None + content = None + if model_output.rstrip().endswith(""): + model_output = model_output[: model_output.rfind("")] + m_func = REGEX_FUNCTION_CALL.search(model_output) + if m_func: + try: + function_call = json.loads(m_func.group(1), strict=False) + if ( + isinstance(function_call, dict) + and "name" in function_call + and "arguments" in function_call + ): + if not isinstance(function_call["arguments"], dict): + function_call = None + else: + function_call = None + except json.JSONDecodeError: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output, + ) + m_content = REGEX_CONTENT_PATTERN.search(model_output) + content = m_content.group(1) if m_content else model_output + if not function_call: return ExtractedToolCallInformation( tools_called=False, tool_calls=[], - content=model_output, + content=content if content else None, ) - json_candidate = match.group(1).strip() - try: - data = json.loads(json_candidate) - except json.JSONDecodeError: - return ExtractedToolCallInformation( - tools_called=False, - tool_calls=[], - content=model_output, - ) - if not (isinstance(data, dict) and "name" in data and "arguments" in data): - return ExtractedToolCallInformation( - tools_called=False, - tool_calls=[], - content=model_output, - ) - name = data["name"] - args = data["arguments"] + name = function_call["name"] + args = function_call["arguments"] if not isinstance(args, str): - args = json.dumps(args, ensure_ascii=False) - - tool_calls = [ - ToolCall( - type="function", - function=FunctionCall( - name=name, - arguments=args, - ), - ) - ] - prefix = model_output[: match.start()] - content = prefix.rstrip() if prefix and prefix.strip() else None - + args = json.dumps(function_call["arguments"], ensure_ascii=False) return ExtractedToolCallInformation( tools_called=True, - tool_calls=tool_calls, - content=content, + tool_calls=[ + ToolCall( + type="function", + function=FunctionCall( + name=name, + arguments=args, + ), + ) + ], + content=content if content else None, ) def extract_tool_calls_streaming( @@ -110,39 +125,37 @@ class GigaChat3ToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> DeltaMessage | None: + content = None func_name = None cur_args = None + m_func = REGEX_FUNCTION_CALL.search(current_text) if not self.tool_started: - match = REGEX_FUNCTION_CALL.search(current_text) - if match: - self.tool_started = True - self.content_buffer = "" + m_content = REGEX_CONTENT_PATTERN.search(delta_text) + if m_content: + content = m_content.group(1) + self.end_content = True else: - self.content_buffer += delta_text - clean_buffer = self.content_buffer.lstrip() - is_prefix = self.trigger_start.startswith(clean_buffer) - starts_with_trigger = clean_buffer.startswith(self.trigger_start) - if is_prefix or starts_with_trigger: - return None - else: - flush_text = self.content_buffer - self.content_buffer = "" - return DeltaMessage(content=flush_text) - - match = REGEX_FUNCTION_CALL.search(current_text) - if not match: + if not self.end_content: + content = delta_text + if m_func: + self.tool_started = True + if content: + return DeltaMessage(content=content) + if not m_func: return None - json_tail = match.group(1).strip() + json_tail = m_func.group(1).strip() name_match = NAME_REGEX.search(json_tail) if name_match: func_name = name_match.group(1) args_match = ARGS_REGEX.search(json_tail) if args_match: cur_args = args_match.group(1).strip() + if cur_args.endswith(""): + cur_args = cur_args[: -len("")] if cur_args.endswith("}"): # last '}' end of json try: candidate = cur_args[:-1].strip() - json.loads(candidate) + json.loads(candidate, strict=False) cur_args = candidate except json.JSONDecodeError: pass @@ -165,11 +178,10 @@ class GigaChat3ToolParser(ToolParser): ).model_dump(exclude_none=True), ) ], - content=None, ) if cur_args is None: return None - prev_args = self.prev_tool_call_arr[0].get("arguments", "") + prev_args = self.prev_tool_call_arr[0].get("arguments_str", "") if not prev_args: delta_args = cur_args elif cur_args.startswith(prev_args): @@ -178,7 +190,15 @@ class GigaChat3ToolParser(ToolParser): return None if not delta_args: return None - self.prev_tool_call_arr[0]["arguments"] = cur_args + self.prev_tool_call_arr[0]["arguments_str"] = cur_args + try: + args_dict = json.loads(cur_args, strict=False) + self.prev_tool_call_arr[0]["arguments"] = args_dict + except json.JSONDecodeError: + self.prev_tool_call_arr[0]["arguments"] = {} + if len(self.streamed_args_for_tool) <= 0: + self.streamed_args_for_tool.append("") + self.streamed_args_for_tool[0] = cur_args return DeltaMessage( tool_calls=[ DeltaToolCall( @@ -188,5 +208,4 @@ class GigaChat3ToolParser(ToolParser): ).model_dump(exclude_none=True), ) ], - content=None, )