fix: preserve native tool call ID in multi-turn tool calling (#32768)

Signed-off-by: wanglinian <wanglinian@stu.pku.edu.cn>
Signed-off-by: wangln19 <96399074+wangln19@users.noreply.github.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
wangln19
2026-01-27 10:22:35 +08:00
committed by GitHub
parent 5a93b9162b
commit 2d7053438a
7 changed files with 183 additions and 68 deletions

View File

@@ -42,6 +42,7 @@ class MockModelConfig:
tokenizer_revision = None tokenizer_revision = None
multimodal_config = MultiModalConfig() multimodal_config = MultiModalConfig()
hf_config = MockHFConfig() hf_config = MockHFConfig()
hf_text_config = MockHFConfig()
logits_processor_pattern = None logits_processor_pattern = None
logits_processors: list[str] | None = None logits_processors: list[str] | None = None
diff_sampling_param: dict | None = None diff_sampling_param: dict | None = None

View File

@@ -518,6 +518,7 @@ class MockModelConfig:
tokenizer_revision = None tokenizer_revision = None
multimodal_config = MultiModalConfig() multimodal_config = MultiModalConfig()
hf_config = MockHFConfig() hf_config = MockHFConfig()
hf_text_config = MockHFConfig()
logits_processors: list[str] | None = None logits_processors: list[str] | None = None
logits_processor_pattern = None logits_processor_pattern = None
diff_sampling_param: dict | None = None diff_sampling_param: dict | None = None

View File

@@ -44,6 +44,7 @@ from vllm.entrypoints.openai.engine.protocol import (
DeltaMessage, DeltaMessage,
DeltaToolCall, DeltaToolCall,
ErrorResponse, ErrorResponse,
FunctionCall,
PromptTokenUsageInfo, PromptTokenUsageInfo,
RequestResponseMetadata, RequestResponseMetadata,
ToolCall, ToolCall,
@@ -143,11 +144,6 @@ class OpenAIServingChat(OpenAIServing):
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_force_include_usage = enable_force_include_usage self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = self.model_config.get_diff_sampling_param() self.default_sampling_params = self.model_config.get_diff_sampling_param()
if self.model_config.hf_config.model_type == "kimi_k2":
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss" self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
if self.use_harmony: if self.use_harmony:
if "stop_token_ids" not in self.default_sampling_params: if "stop_token_ids" not in self.default_sampling_params:
@@ -156,6 +152,16 @@ class OpenAIServingChat(OpenAIServing):
get_stop_tokens_for_assistant_actions() get_stop_tokens_for_assistant_actions()
) )
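# Select the tool call ID format: use "kimi_k2" when the text config's model_type
# is "kimi_k2", or when an hf_overrides dict sets model_type to "kimi_k2" (this lets
# tests mock the model type via overrides); otherwise fall back to "random".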
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
# NOTE(woosuk): While OpenAI's chat completion API supports browsing # NOTE(woosuk): While OpenAI's chat completion API supports browsing
# for some models, currently vLLM doesn't support it. Please use the # for some models, currently vLLM doesn't support it. Please use the
# Responses API instead. # Responses API instead.
@@ -247,8 +253,8 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially # because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request # re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls` # for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) maybe_serialize_tool_calls(request) # type: ignore[arg-type]
truncate_tool_call_ids(request) truncate_tool_call_ids(request) # type: ignore[arg-type]
validate_request_params(request) validate_request_params(request)
# Check if tool parsing is unavailable (common condition) # Check if tool parsing is unavailable (common condition)
@@ -454,6 +460,7 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response # Streaming response
tokenizer = self.renderer.tokenizer tokenizer = self.renderer.tokenizer
assert tokenizer is not None
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
@@ -632,9 +639,11 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: TokenizerLike | None, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time()) created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk" chunk_object_type: Final = "chat.completion.chunk"
first_iteration = True first_iteration = True
@@ -698,7 +707,7 @@ class OpenAIServingChat(OpenAIServing):
) )
reasoning_parser = self.reasoning_parser( reasoning_parser = self.reasoning_parser(
tokenizer, tokenizer,
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg] chat_template_kwargs=chat_template_kwargs or {}, # type: ignore[call-arg]
) )
except RuntimeError as e: except RuntimeError as e:
logger.exception("Error in reasoning parser creation.") logger.exception("Error in reasoning parser creation.")
@@ -955,8 +964,17 @@ class OpenAIServingChat(OpenAIServing):
index=i, index=i,
) )
else: else:
# Generate ID based on tokenizer type
if isinstance(tokenizer, MistralTokenizer):
tool_call_id = MistralToolCall.generate_random_id()
else:
tool_call_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_choice_function_name,
idx=history_tool_call_cnt,
)
delta_tool_call = DeltaToolCall( delta_tool_call = DeltaToolCall(
id=make_tool_call_id(), id=tool_call_id,
type="function", type="function",
function=DeltaFunctionCall( function=DeltaFunctionCall(
name=tool_choice_function_name, name=tool_choice_function_name,
@@ -1387,9 +1405,11 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: TokenizerLike | None, tokenizer: TokenizerLike,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse: ) -> ErrorResponse | ChatCompletionResponse:
from vllm.tokenizers.mistral import MistralTokenizer
created_time = int(time.time()) created_time = int(time.time())
final_res: RequestOutput | None = None final_res: RequestOutput | None = None
@@ -1524,39 +1544,85 @@ class OpenAIServingChat(OpenAIServing):
tool_call_class = ( tool_call_class = (
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
) )
if (not self.enable_auto_tools or not self.tool_parser) and ( if self.use_harmony:
# Harmony models already have parsed content and tool_calls
# through parse_chat_output. Respect its output directly.
message = ChatMessage(
role=role,
reasoning=reasoning,
content=content,
tool_calls=tool_calls if tool_calls else [],
)
elif (not self.enable_auto_tools or not self.tool_parser) and (
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam) not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
and request.tool_choice != "required" and request.tool_choice != "required"
): ):
message = ChatMessage(role=role, reasoning=reasoning, content=content) message = ChatMessage(role=role, reasoning=reasoning, content=content)
# if the request uses tools and specified a tool choice
elif ( elif (
request.tool_choice request.tool_choice
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
): ):
assert tool_calls is not None and len(tool_calls) > 0 assert tool_calls is not None and len(tool_calls) > 0
tool_call_class_items = []
for idx, tc in enumerate(tool_calls):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if tc.id:
tool_call_class_items.append(
tool_call_class(id=tc.id, function=tc)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_class_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tc.name,
idx=history_tool_call_cnt + idx,
)
tool_call_class_items.append(
tool_call_class(id=generated_id, function=tc)
)
history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning=reasoning, reasoning=reasoning,
content="", content="",
tool_calls=[tool_call_class(function=tc) for tc in tool_calls], tool_calls=tool_call_class_items,
) )
elif request.tool_choice and request.tool_choice == "required": elif request.tool_choice and request.tool_choice == "required":
tool_call_class_items = [] tool_call_class_items = []
assert tool_calls is not None and len(tool_calls) > 0 assert tool_calls is not None and len(tool_calls) > 0
for tool_call in tool_calls: for idx, tool_call in enumerate(tool_calls):
tool_call_class_items.append( # Use native ID if available,
tool_call_class( # otherwise generate ID with correct id_type
id=make_tool_call_id( if tool_call.id:
tool_call_class_items.append(
tool_call_class(id=tool_call.id, function=tool_call)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_class_items.append(
tool_call_class(function=tool_call)
)
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type, id_type=self.tool_call_id_type,
func_name=tool_call.name, func_name=tool_call.name,
idx=history_tool_call_cnt, idx=history_tool_call_cnt + idx,
), )
function=tool_call, tool_call_class_items.append(
) tool_call_class(id=generated_id, function=tool_call)
) )
history_tool_call_cnt += 1 history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
@@ -1582,17 +1648,35 @@ class OpenAIServingChat(OpenAIServing):
# call. The same is not true for named function calls # call. The same is not true for named function calls
auto_tools_called = tool_calls is not None and len(tool_calls) > 0 auto_tools_called = tool_calls is not None and len(tool_calls) > 0
if tool_calls: if tool_calls:
tool_call_items = []
for idx, tc in enumerate(tool_calls):
# Use native ID if available (e.g., Kimi K2),
# otherwise generate ID with correct id_type
if tc.id:
tool_call_items.append(
tool_call_class(id=tc.id, function=tc)
)
else:
# Generate ID using the correct format (kimi_k2 or random),
# but leave it to the class if it's Mistral to preserve
# 9-char IDs
if isinstance(tokenizer, MistralTokenizer):
tool_call_items.append(tool_call_class(function=tc))
else:
generated_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tc.name,
idx=history_tool_call_cnt + idx,
)
tool_call_items.append(
tool_call_class(id=generated_id, function=tc)
)
history_tool_call_cnt += 1
message = ChatMessage( message = ChatMessage(
role=role, role=role,
reasoning=reasoning, reasoning=reasoning,
content=content, content=content,
tool_calls=[ tool_calls=tool_call_items,
ToolCall(
function=tc,
type="function",
)
for tc in tool_calls
],
) )
else: else:
@@ -1701,13 +1785,11 @@ class OpenAIServingChat(OpenAIServing):
elif choice.message.tool_calls: elif choice.message.tool_calls:
# For tool calls, log the function name and arguments # For tool calls, log the function name and arguments
tool_call_descriptions = [] tool_call_descriptions = []
for tc in choice.message.tool_calls: for tc in choice.message.tool_calls: # type: ignore
if hasattr(tc.function, "name") and hasattr( function_call: FunctionCall = tc.function # type: ignore
tc.function, "arguments" tool_call_descriptions.append(
): f"{function_call.name}({function_call.arguments})"
tool_call_descriptions.append( )
f"{tc.function.name}({tc.function.arguments})"
)
tool_calls_str = ", ".join(tool_call_descriptions) tool_calls_str = ", ".join(tool_call_descriptions)
output_text = f"[tool_calls: {tool_calls_str}]" output_text = f"[tool_calls: {tool_calls_str}]"
@@ -1895,7 +1977,7 @@ class OpenAIServingChat(OpenAIServing):
# because of issues with pydantic we need to potentially # because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request # re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls` # for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request) maybe_serialize_tool_calls(request) # type: ignore[arg-type]
# Add system message. # Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default # NOTE: In Chat Completion API, browsing is enabled by default
@@ -1913,7 +1995,7 @@ class OpenAIServingChat(OpenAIServing):
# Add developer message. # Add developer message.
if request.tools: if request.tools:
dev_msg = get_developer_message( dev_msg = get_developer_message(
tools=request.tools if should_include_tools else None tools=request.tools if should_include_tools else None # type: ignore[arg-type]
) )
messages.append(dev_msg) messages.append(dev_msg)

View File

@@ -218,6 +218,10 @@ def get_logits_processors(
class FunctionCall(OpenAIBaseModel): class FunctionCall(OpenAIBaseModel):
# Internal field to preserve native tool call ID from tool parser.
# Excluded from serialization to maintain OpenAI API compatibility
# (function object should only contain 'name' and 'arguments').
id: str | None = Field(default=None, exclude=True)
name: str name: str
arguments: str arguments: str

View File

@@ -1525,6 +1525,7 @@ class OpenAIServing:
# extract_tool_calls() returns a list of tool calls. # extract_tool_calls() returns a list of tool calls.
function_calls.extend( function_calls.extend(
FunctionCall( FunctionCall(
id=tool_call.id,
name=tool_call.function.name, name=tool_call.function.name,
arguments=tool_call.function.arguments, arguments=tool_call.function.arguments,
) )

View File

@@ -63,6 +63,7 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
make_tool_call_id,
) )
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import ToolServer from vllm.entrypoints.mcp.tool_server import ToolServer
@@ -250,6 +251,17 @@ class OpenAIServingResponses(OpenAIServing):
self.default_sampling_params["stop_token_ids"].extend( self.default_sampling_params["stop_token_ids"].extend(
get_stop_tokens_for_assistant_actions() get_stop_tokens_for_assistant_actions()
) )
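# Select the tool call ID format: use "kimi_k2" when the text config's model_type
# is "kimi_k2", or when an hf_overrides dict sets model_type to "kimi_k2" (this lets
# tests mock the model type via overrides); otherwise fall back to "random".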
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.enable_auto_tools = enable_auto_tools self.enable_auto_tools = enable_auto_tools
# set up tool use # set up tool use
self.tool_parser = self._get_tool_parser( self.tool_parser = self._get_tool_parser(
@@ -954,25 +966,28 @@ class OpenAIServingResponses(OpenAIServing):
enable_auto_tools=self.enable_auto_tools, enable_auto_tools=self.enable_auto_tools,
tool_parser_cls=self.tool_parser, tool_parser_cls=self.tool_parser,
) )
if content:
output_text = ResponseOutputText( if content or (self.use_harmony and tool_calls):
text=content, res_text_part = None
annotations=[], # TODO if content:
type="output_text", res_text_part = ResponseOutputText(
logprobs=( text=content,
self._create_response_logprobs( annotations=[], # TODO
token_ids=final_output.token_ids, type="output_text",
logprobs=final_output.logprobs, logprobs=(
tokenizer=tokenizer, self._create_response_logprobs(
top_logprobs=request.top_logprobs, token_ids=final_output.token_ids,
) logprobs=final_output.logprobs,
if request.is_include_output_logprobs() tokenizer=tokenizer,
else None top_logprobs=request.top_logprobs,
), )
) if request.is_include_output_logprobs()
else None
),
)
message_item = ResponseOutputMessage( message_item = ResponseOutputMessage(
id=f"msg_{random_uuid()}", id=f"msg_{random_uuid()}",
content=[output_text], content=[res_text_part] if res_text_part else [],
role="assistant", role="assistant",
status="completed", status="completed",
type="message", type="message",
@@ -984,17 +999,28 @@ class OpenAIServingResponses(OpenAIServing):
if message_item: if message_item:
outputs.append(message_item) outputs.append(message_item)
if tool_calls: if tool_calls:
tool_call_items = [ # We use a simple counter for history_tool_call_count because
ResponseFunctionToolCall( # we don't track the history of tool calls in the Responses API yet.
id=f"fc_{random_uuid()}", # This means that the tool call index will start from 0 for each
call_id=f"call_{random_uuid()}", # request.
type="function_call", tool_call_items = []
status="completed", for history_tool_call_cnt, tool_call in enumerate(tool_calls):
name=tool_call.name, tool_call_items.append(
arguments=tool_call.arguments, ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=tool_call.id
if tool_call.id
else make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt,
),
type="function_call",
status="completed",
name=tool_call.name,
arguments=tool_call.arguments,
)
) )
for tool_call in tool_calls
]
outputs.extend(tool_call_items) outputs.extend(tool_call_items)
return outputs return outputs

View File

@@ -448,7 +448,7 @@ class KimiK2ToolParser(ToolParser):
if current_tool_call_matches: if current_tool_call_matches:
tool_id, tool_args = current_tool_call_matches.groups() tool_id, tool_args = current_tool_call_matches.groups()
tool_name = tool_id.split(":")[0].split(".")[-1] tool_name = tool_id.split(":")[0].split(".")[-1]
current_tool_call["id"] = tool_id current_tool_call["id"] = tool_id.strip()
current_tool_call["name"] = tool_name current_tool_call["name"] = tool_name
current_tool_call["arguments"] = tool_args current_tool_call["arguments"] = tool_args
else: else:
@@ -458,7 +458,7 @@ class KimiK2ToolParser(ToolParser):
if current_tool_call_name_matches: if current_tool_call_name_matches:
(tool_id_str,) = current_tool_call_name_matches.groups() (tool_id_str,) = current_tool_call_name_matches.groups()
tool_name = tool_id_str.split(":")[0].split(".")[-1] tool_name = tool_id_str.split(":")[0].split(".")[-1]
current_tool_call["id"] = tool_id_str current_tool_call["id"] = tool_id_str.strip()
current_tool_call["name"] = tool_name current_tool_call["name"] = tool_name
current_tool_call["arguments"] = "" current_tool_call["arguments"] = ""
else: else: