From 2d7053438a112e2be55cf6d2bde9deb8a169d0a4 Mon Sep 17 00:00:00 2001
From: wangln19 <96399074+wangln19@users.noreply.github.com>
Date: Tue, 27 Jan 2026 10:22:35 +0800
Subject: [PATCH] fix: preserve native tool call ID in multi-turn tool calling (#32768)

Signed-off-by: wanglinian
Signed-off-by: wangln19 <96399074+wangln19@users.noreply.github.com>
Signed-off-by: Roger Wang
Co-authored-by: Roger Wang
Co-authored-by: Isotr0py <2037008807@qq.com>
---
 tests/entrypoints/openai/test_chat_error.py   |   1 +
 tests/entrypoints/openai/test_serving_chat.py |   1 +
 .../openai/chat_completion/serving.py         | 160 +++++++++++++-----
 vllm/entrypoints/openai/engine/protocol.py    |   4 +
 vllm/entrypoints/openai/engine/serving.py     |   1 +
 vllm/entrypoints/openai/responses/serving.py  |  80 ++++++---
 vllm/tool_parsers/kimi_k2_tool_parser.py      |   4 +-
 7 files changed, 183 insertions(+), 68 deletions(-)

diff --git a/tests/entrypoints/openai/test_chat_error.py b/tests/entrypoints/openai/test_chat_error.py
index a62340513..d42ae2557 100644
--- a/tests/entrypoints/openai/test_chat_error.py
+++ b/tests/entrypoints/openai/test_chat_error.py
@@ -42,6 +42,7 @@ class MockModelConfig:
     tokenizer_revision = None
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
+    hf_text_config = MockHFConfig()
     logits_processor_pattern = None
     logits_processors: list[str] | None = None
     diff_sampling_param: dict | None = None
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 0f8de3435..fa29b31be 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -518,6 +518,7 @@ class MockModelConfig:
     tokenizer_revision = None
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
+    hf_text_config = MockHFConfig()
     logits_processors: list[str] | None = None
     logits_processor_pattern = None
     diff_sampling_param: dict | None = None
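Note on the two test changes above: they are load-bearing, not cosmetic. The serving layer now reads model_config.hf_text_config.model_type rather than hf_config.model_type, so a mock config without an hf_text_config attribute would raise AttributeError before the behavior under test is ever reached. A minimal, self-contained sketch of why the mocks need the extra attribute (the class names mirror the test fixtures; everything else is illustrative only):

    class MockHFConfig:
        model_type = "mock"

    class MockModelConfig:
        hf_config = MockHFConfig()
        # Without this attribute, `config.hf_text_config.model_type`
        # raises AttributeError during serving-layer construction.
        hf_text_config = MockHFConfig()

    config = MockModelConfig()
    # The new Kimi K2 check, reduced to its attribute access:
    assert (config.hf_text_config.model_type == "kimi_k2") is False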
diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
index 265cee554..6a22bece6 100644
--- a/vllm/entrypoints/openai/chat_completion/serving.py
+++ b/vllm/entrypoints/openai/chat_completion/serving.py
@@ -44,6 +44,7 @@ from vllm.entrypoints.openai.engine.protocol import (
     DeltaMessage,
     DeltaToolCall,
     ErrorResponse,
+    FunctionCall,
     PromptTokenUsageInfo,
     RequestResponseMetadata,
     ToolCall,
@@ -143,11 +144,6 @@ class OpenAIServingChat(OpenAIServing):
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.enable_force_include_usage = enable_force_include_usage
         self.default_sampling_params = self.model_config.get_diff_sampling_param()
-        if self.model_config.hf_config.model_type == "kimi_k2":
-            self.tool_call_id_type = "kimi_k2"
-        else:
-            self.tool_call_id_type = "random"
-
         self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
         if self.use_harmony:
             if "stop_token_ids" not in self.default_sampling_params:
@@ -156,6 +152,16 @@
                     get_stop_tokens_for_assistant_actions()
                 )

+        # Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
+        hf_overrides = getattr(self.model_config, "hf_overrides", None)
+        if self.model_config.hf_text_config.model_type == "kimi_k2" or (
+            isinstance(hf_overrides, dict)
+            and hf_overrides.get("model_type") == "kimi_k2"
+        ):
+            self.tool_call_id_type = "kimi_k2"
+        else:
+            self.tool_call_id_type = "random"
+
         # NOTE(woosuk): While OpenAI's chat completion API supports browsing
         # for some models, currently vLLM doesn't support it. Please use the
         # Responses API instead.
@@ -247,8 +253,8 @@
             # because of issues with pydantic we need to potentially
             # re-serialize the tool_calls field of the request
             # for more info: see comment in `maybe_serialize_tool_calls`
-            maybe_serialize_tool_calls(request)
-            truncate_tool_call_ids(request)
+            maybe_serialize_tool_calls(request)  # type: ignore[arg-type]
+            truncate_tool_call_ids(request)  # type: ignore[arg-type]
             validate_request_params(request)

         # Check if tool parsing is unavailable (common condition)
@@ -454,6 +460,7 @@

         # Streaming response
         tokenizer = self.renderer.tokenizer
+        assert tokenizer is not None
         if request.stream:
             return self.chat_completion_stream_generator(
@@ -632,9 +639,11 @@
         request_id: str,
         model_name: str,
         conversation: list[ConversationMessage],
-        tokenizer: TokenizerLike | None,
+        tokenizer: TokenizerLike,
         request_metadata: RequestResponseMetadata,
     ) -> AsyncGenerator[str, None]:
+        from vllm.tokenizers.mistral import MistralTokenizer
+
         created_time = int(time.time())
         chunk_object_type: Final = "chat.completion.chunk"
         first_iteration = True
@@ -698,7 +707,7 @@
                     )
                     reasoning_parser = self.reasoning_parser(
                         tokenizer,
-                        chat_template_kwargs=chat_template_kwargs,  # type: ignore[call-arg]
+                        chat_template_kwargs=chat_template_kwargs or {},  # type: ignore[call-arg]
                     )
                 except RuntimeError as e:
                     logger.exception("Error in reasoning parser creation.")
@@ -955,8 +964,17 @@
                                 index=i,
                             )
                         else:
+                            # Generate ID based on tokenizer type
+                            if isinstance(tokenizer, MistralTokenizer):
+                                tool_call_id = MistralToolCall.generate_random_id()
+                            else:
+                                tool_call_id = make_tool_call_id(
+                                    id_type=self.tool_call_id_type,
+                                    func_name=tool_choice_function_name,
+                                    idx=history_tool_call_cnt,
+                                )
                             delta_tool_call = DeltaToolCall(
-                                id=make_tool_call_id(),
+                                id=tool_call_id,
                                 type="function",
                                 function=DeltaFunctionCall(
                                     name=tool_choice_function_name,
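For reference, the ID-type selection added to the constructor in the @@ -156 hunk above reduces to the standalone sketch below. SimpleNamespace stands in for vLLM's ModelConfig; only the attributes the patch actually reads are assumed:

    from types import SimpleNamespace

    def resolve_tool_call_id_type(model_config) -> str:
        # Prefer the real text config's model_type, but allow tests to
        # force "kimi_k2" through hf_overrides (mirrors the patch logic).
        hf_overrides = getattr(model_config, "hf_overrides", None)
        if model_config.hf_text_config.model_type == "kimi_k2" or (
            isinstance(hf_overrides, dict)
            and hf_overrides.get("model_type") == "kimi_k2"
        ):
            return "kimi_k2"
        return "random"

    cfg = SimpleNamespace(
        hf_text_config=SimpleNamespace(model_type="llama"),
        hf_overrides={"model_type": "kimi_k2"},  # test-style override wins
    )
    assert resolve_tool_call_id_type(cfg) == "kimi_k2"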
@@ -1387,9 +1405,11 @@
         request_id: str,
         model_name: str,
         conversation: list[ConversationMessage],
-        tokenizer: TokenizerLike | None,
+        tokenizer: TokenizerLike,
         request_metadata: RequestResponseMetadata,
     ) -> ErrorResponse | ChatCompletionResponse:
+        from vllm.tokenizers.mistral import MistralTokenizer
+
         created_time = int(time.time())
         final_res: RequestOutput | None = None

@@ -1524,39 +1544,85 @@
             tool_call_class = (
                 MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
             )
-            if (not self.enable_auto_tools or not self.tool_parser) and (
+            if self.use_harmony:
+                # Harmony models already have parsed content and tool_calls
+                # through parse_chat_output. Respect its output directly.
+                message = ChatMessage(
+                    role=role,
+                    reasoning=reasoning,
+                    content=content,
+                    tool_calls=tool_calls if tool_calls else [],
+                )
+
+            elif (not self.enable_auto_tools or not self.tool_parser) and (
                 not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
                 and request.tool_choice != "required"
             ):
                 message = ChatMessage(role=role, reasoning=reasoning, content=content)
-
             # if the request uses tools and specified a tool choice
             elif (
                 request.tool_choice
                 and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
             ):
                 assert tool_calls is not None and len(tool_calls) > 0
+                tool_call_class_items = []
+                for idx, tc in enumerate(tool_calls):
+                    # Use native ID if available (e.g., Kimi K2),
+                    # otherwise generate ID with correct id_type
+                    if tc.id:
+                        tool_call_class_items.append(
+                            tool_call_class(id=tc.id, function=tc)
+                        )
+                    else:
+                        # Generate ID using the correct format (kimi_k2 or random),
+                        # but leave it to the class if it's Mistral to preserve
+                        # 9-char IDs
+                        if isinstance(tokenizer, MistralTokenizer):
+                            tool_call_class_items.append(tool_call_class(function=tc))
+                        else:
+                            generated_id = make_tool_call_id(
+                                id_type=self.tool_call_id_type,
+                                func_name=tc.name,
+                                idx=history_tool_call_cnt + idx,
+                            )
+                            tool_call_class_items.append(
+                                tool_call_class(id=generated_id, function=tc)
+                            )
+                    history_tool_call_cnt += 1
                 message = ChatMessage(
                     role=role,
                     reasoning=reasoning,
                     content="",
-                    tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
+                    tool_calls=tool_call_class_items,
                 )

             elif request.tool_choice and request.tool_choice == "required":
                 tool_call_class_items = []
                 assert tool_calls is not None and len(tool_calls) > 0

-                for tool_call in tool_calls:
-                    tool_call_class_items.append(
-                        tool_call_class(
-                            id=make_tool_call_id(
+                for idx, tool_call in enumerate(tool_calls):
+                    # Use native ID if available,
+                    # otherwise generate ID with correct id_type
+                    if tool_call.id:
+                        tool_call_class_items.append(
+                            tool_call_class(id=tool_call.id, function=tool_call)
+                        )
+                    else:
+                        # Generate ID using the correct format (kimi_k2 or random),
+                        # but leave it to the class if it's Mistral to preserve
+                        # 9-char IDs
+                        if isinstance(tokenizer, MistralTokenizer):
+                            tool_call_class_items.append(
+                                tool_call_class(function=tool_call)
+                            )
+                        else:
+                            generated_id = make_tool_call_id(
                                 id_type=self.tool_call_id_type,
                                 func_name=tool_call.name,
-                                idx=history_tool_call_cnt,
-                            ),
-                            function=tool_call,
-                        )
-                    )
+                                idx=history_tool_call_cnt + idx,
+                            )
+                            tool_call_class_items.append(
+                                tool_call_class(id=generated_id, function=tool_call)
+                            )
                     history_tool_call_cnt += 1

                 message = ChatMessage(
                     role=role,
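The two rewritten branches above (named tool choice and tool_choice="required"), plus the auto-tools branch in the next hunk, all apply the same three-step ID policy. A standalone sketch of that policy follows; the concrete "functions.{name}:{idx}" and "chatcmpl-tool-..." formats are assumptions about make_tool_call_id's output, and the Mistral stub only mimics the 9-character constraint the patch comments mention:

    import secrets
    import string
    import uuid

    def mistral_random_id() -> str:
        # Stand-in for MistralToolCall.generate_random_id(): 9 alphanumerics.
        alphabet = string.ascii_letters + string.digits
        return "".join(secrets.choice(alphabet) for _ in range(9))

    def pick_tool_call_id(native_id, func_name, idx, id_type, is_mistral) -> str:
        if native_id:      # 1) preserve the parser's native ID (e.g., Kimi K2)
            return native_id
        if is_mistral:     # 2) Mistral keeps its own 9-char format
            return mistral_random_id()
        if id_type == "kimi_k2":  # 3a) assumed Kimi K2-style generated ID
            return f"functions.{func_name}:{idx}"
        return f"chatcmpl-tool-{uuid.uuid4().hex}"  # 3b) assumed default format

    assert pick_tool_call_id(
        "functions.get_weather:0", "get_weather", 0, "kimi_k2", False
    ) == "functions.get_weather:0"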
@@ -1582,17 +1648,35 @@
             # call. The same is not true for named function calls
             auto_tools_called = tool_calls is not None and len(tool_calls) > 0
             if tool_calls:
+                tool_call_items = []
+                for idx, tc in enumerate(tool_calls):
+                    # Use native ID if available (e.g., Kimi K2),
+                    # otherwise generate ID with correct id_type
+                    if tc.id:
+                        tool_call_items.append(
+                            tool_call_class(id=tc.id, function=tc)
+                        )
+                    else:
+                        # Generate ID using the correct format (kimi_k2 or random),
+                        # but leave it to the class if it's Mistral to preserve
+                        # 9-char IDs
+                        if isinstance(tokenizer, MistralTokenizer):
+                            tool_call_items.append(tool_call_class(function=tc))
+                        else:
+                            generated_id = make_tool_call_id(
+                                id_type=self.tool_call_id_type,
+                                func_name=tc.name,
+                                idx=history_tool_call_cnt + idx,
+                            )
+                            tool_call_items.append(
+                                tool_call_class(id=generated_id, function=tc)
+                            )
+                    history_tool_call_cnt += 1
                 message = ChatMessage(
                     role=role,
                     reasoning=reasoning,
                     content=content,
-                    tool_calls=[
-                        ToolCall(
-                            function=tc,
-                            type="function",
-                        )
-                        for tc in tool_calls
-                    ],
+                    tool_calls=tool_call_items,
                 )

             else:
@@ -1701,13 +1785,11 @@
             elif choice.message.tool_calls:
                 # For tool calls, log the function name and arguments
                 tool_call_descriptions = []
-                for tc in choice.message.tool_calls:
-                    if hasattr(tc.function, "name") and hasattr(
-                        tc.function, "arguments"
-                    ):
-                        tool_call_descriptions.append(
-                            f"{tc.function.name}({tc.function.arguments})"
-                        )
+                for tc in choice.message.tool_calls:  # type: ignore
+                    function_call: FunctionCall = tc.function  # type: ignore
+                    tool_call_descriptions.append(
+                        f"{function_call.name}({function_call.arguments})"
+                    )
                 tool_calls_str = ", ".join(tool_call_descriptions)
                 output_text = f"[tool_calls: {tool_calls_str}]"

@@ -1895,7 +1977,7 @@
         # because of issues with pydantic we need to potentially
         # re-serialize the tool_calls field of the request
         # for more info: see comment in `maybe_serialize_tool_calls`
-        maybe_serialize_tool_calls(request)
+        maybe_serialize_tool_calls(request)  # type: ignore[arg-type]

         # Add system message.
         # NOTE: In Chat Completion API, browsing is enabled by default
@@ -1913,7 +1995,7 @@
         # Add developer message.
         if request.tools:
             dev_msg = get_developer_message(
-                tools=request.tools if should_include_tools else None
+                tools=request.tools if should_include_tools else None  # type: ignore[arg-type]
             )
             messages.append(dev_msg)

diff --git a/vllm/entrypoints/openai/engine/protocol.py b/vllm/entrypoints/openai/engine/protocol.py
index e64cfd1c5..e491f9399 100644
--- a/vllm/entrypoints/openai/engine/protocol.py
+++ b/vllm/entrypoints/openai/engine/protocol.py
@@ -218,6 +218,10 @@ def get_logits_processors(


 class FunctionCall(OpenAIBaseModel):
+    # Internal field to preserve native tool call ID from tool parser.
+    # Excluded from serialization to maintain OpenAI API compatibility
+    # (function object should only contain 'name' and 'arguments').
+    id: str | None = Field(default=None, exclude=True)
     name: str
     arguments: str
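The exclude=True field above is what keeps the wire format intact: the native ID rides along on the in-memory FunctionCall but is dropped whenever the model is serialized, so clients still see only name and arguments. A self-contained illustration (plain pydantic BaseModel standing in for vLLM's OpenAIBaseModel):

    from pydantic import BaseModel, Field

    class FunctionCall(BaseModel):
        # Internal-only: never serialized, so the OpenAI-shaped payload
        # contains exactly {"name", "arguments"}.
        id: str | None = Field(default=None, exclude=True)
        name: str
        arguments: str

    fc = FunctionCall(id="functions.get_weather:0", name="get_weather", arguments="{}")
    assert fc.id == "functions.get_weather:0"  # available internally
    assert fc.model_dump() == {"name": "get_weather", "arguments": "{}"}  # id excluded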
diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py
index e05c287a0..243900262 100644
--- a/vllm/entrypoints/openai/engine/serving.py
+++ b/vllm/entrypoints/openai/engine/serving.py
@@ -1525,6 +1525,7 @@ class OpenAIServing:
                 # extract_tool_calls() returns a list of tool calls.
                 function_calls.extend(
                     FunctionCall(
+                        id=tool_call.id,
                         name=tool_call.function.name,
                         arguments=tool_call.function.arguments,
                     )
diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py
index 9fa748f87..702167a24 100644
--- a/vllm/entrypoints/openai/responses/serving.py
+++ b/vllm/entrypoints/openai/responses/serving.py
@@ -63,6 +63,7 @@ from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (
     ChatCompletionMessageParam,
     ChatTemplateContentFormatOption,
+    make_tool_call_id,
 )
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.mcp.tool_server import ToolServer
@@ -250,6 +251,17 @@
             self.default_sampling_params["stop_token_ids"].extend(
                 get_stop_tokens_for_assistant_actions()
             )
+
+        # Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
+        hf_overrides = getattr(self.model_config, "hf_overrides", None)
+        if self.model_config.hf_text_config.model_type == "kimi_k2" or (
+            isinstance(hf_overrides, dict)
+            and hf_overrides.get("model_type") == "kimi_k2"
+        ):
+            self.tool_call_id_type = "kimi_k2"
+        else:
+            self.tool_call_id_type = "random"
+
         self.enable_auto_tools = enable_auto_tools
         # set up tool use
         self.tool_parser = self._get_tool_parser(
@@ -954,25 +966,28 @@
             enable_auto_tools=self.enable_auto_tools,
             tool_parser_cls=self.tool_parser,
         )
-        if content:
-            output_text = ResponseOutputText(
-                text=content,
-                annotations=[],  # TODO
-                type="output_text",
-                logprobs=(
-                    self._create_response_logprobs(
-                        token_ids=final_output.token_ids,
-                        logprobs=final_output.logprobs,
-                        tokenizer=tokenizer,
-                        top_logprobs=request.top_logprobs,
-                    )
-                    if request.is_include_output_logprobs()
-                    else None
-                ),
-            )
+
+        if content or (self.use_harmony and tool_calls):
+            res_text_part = None
+            if content:
+                res_text_part = ResponseOutputText(
+                    text=content,
+                    annotations=[],  # TODO
+                    type="output_text",
+                    logprobs=(
+                        self._create_response_logprobs(
+                            token_ids=final_output.token_ids,
+                            logprobs=final_output.logprobs,
+                            tokenizer=tokenizer,
+                            top_logprobs=request.top_logprobs,
+                        )
+                        if request.is_include_output_logprobs()
+                        else None
+                    ),
+                )
             message_item = ResponseOutputMessage(
                 id=f"msg_{random_uuid()}",
-                content=[output_text],
+                content=[res_text_part] if res_text_part else [],
                 role="assistant",
                 status="completed",
                 type="message",
@@ -984,17 +999,28 @@
         if message_item:
             outputs.append(message_item)
         if tool_calls:
-            tool_call_items = [
-                ResponseFunctionToolCall(
-                    id=f"fc_{random_uuid()}",
-                    call_id=f"call_{random_uuid()}",
-                    type="function_call",
-                    status="completed",
-                    name=tool_call.name,
-                    arguments=tool_call.arguments,
+            # We use a simple counter for history_tool_call_count because
+            # we don't track the history of tool calls in the Responses API yet.
+            # This means that the tool call index will start from 0 for each
+            # request.
+            tool_call_items = []
+            for history_tool_call_cnt, tool_call in enumerate(tool_calls):
+                tool_call_items.append(
+                    ResponseFunctionToolCall(
+                        id=f"fc_{random_uuid()}",
+                        call_id=tool_call.id
+                        if tool_call.id
+                        else make_tool_call_id(
+                            id_type=self.tool_call_id_type,
+                            func_name=tool_call.name,
+                            idx=history_tool_call_cnt,
+                        ),
+                        type="function_call",
+                        status="completed",
+                        name=tool_call.name,
+                        arguments=tool_call.arguments,
+                    )
                 )
-                for tool_call in tool_calls
-            ]
             outputs.extend(tool_call_items)

         return outputs
diff --git a/vllm/tool_parsers/kimi_k2_tool_parser.py b/vllm/tool_parsers/kimi_k2_tool_parser.py
index 354ed412b..ed4795215 100644
--- a/vllm/tool_parsers/kimi_k2_tool_parser.py
+++ b/vllm/tool_parsers/kimi_k2_tool_parser.py
@@ -448,7 +448,7 @@ class KimiK2ToolParser(ToolParser):
             if current_tool_call_matches:
                 tool_id, tool_args = current_tool_call_matches.groups()
                 tool_name = tool_id.split(":")[0].split(".")[-1]
-                current_tool_call["id"] = tool_id
+                current_tool_call["id"] = tool_id.strip()
                 current_tool_call["name"] = tool_name
                 current_tool_call["arguments"] = tool_args
             else:
@@ -458,7 +458,7 @@
                 if current_tool_call_name_matches:
                     (tool_id_str,) = current_tool_call_name_matches.groups()
                     tool_name = tool_id_str.split(":")[0].split(".")[-1]
-                    current_tool_call["id"] = tool_id_str
+                    current_tool_call["id"] = tool_id_str.strip()
                     current_tool_call["name"] = tool_name
                     current_tool_call["arguments"] = ""
                 else:
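Taken together, the pieces above close the loop for multi-turn Kimi K2 tool calling: the parser strips stray whitespace from the captured native ID, engine/serving.py forwards it on FunctionCall.id, and both the Chat Completion and Responses paths echo it back instead of minting a random replacement, so the follow-up turn's tool message can reference the exact ID the model emitted. A compressed sketch of that flow; the regex and ID formats here are illustrative assumptions modeled on the Kimi K2 functional-token style, not the actual parser code:

    import re

    # 1) Parse: capture the native tool call ID, stripping stray whitespace
    #    (mirrors the tool_id.strip() fix in kimi_k2_tool_parser.py).
    raw = "functions.get_weather:0 "  # trailing space from generation
    m = re.match(r"(functions\.\w+:\d+)\s*", raw)
    native_id = m.group(1).strip() if m else None

    # 2) Serve: prefer the native ID over a freshly generated one (mirrors
    #    `call_id=tool_call.id if tool_call.id else make_tool_call_id(...)`).
    call_id = native_id or "chatcmpl-tool-deadbeef"  # fallback format assumed

    # 3) Round-trip: the next turn's tool result references the same ID, so
    #    the model can match the result to the call it made.
    tool_result = {"role": "tool", "tool_call_id": call_id, "content": "22C"}
    assert tool_result["tool_call_id"] == "functions.get_weather:0"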