fix: preserve native tool call ID in multi-turn tool calling (#32768)
Signed-off-by: wanglinian <wanglinian@stu.pku.edu.cn>
Signed-off-by: wangln19 <96399074+wangln19@users.noreply.github.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Isotr0py <2037008807@qq.com>
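The root cause: vLLM previously regenerated every tool call ID when building a response, so models that emit their own structured IDs (e.g. Kimi K2) lost the link between a tool call and its result on the next turn. The rule this commit applies at every site that materializes tool calls, as a minimal self-contained sketch (resolve_tool_call_id and the exact "functions.{name}:{n}" format are illustrative assumptions; the real helper is make_tool_call_id in the diff below):

    import uuid

    def resolve_tool_call_id(
        native_id: str | None,   # ID parsed out of the model output, if any
        id_type: str,            # "kimi_k2" or "random", chosen from model type
        func_name: str,
        history_cnt: int,        # tool calls seen in earlier turns
        idx: int,                # position within the current response
    ) -> str:
        if native_id:
            # Preserve the model's own ID so multi-turn tool results
            # can be matched back to the call that produced them.
            return native_id
        if id_type == "kimi_k2":
            # Deterministic, index-based format (assumed here).
            return f"functions.{func_name}:{history_cnt + idx}"
        return f"chatcmpl-tool-{uuid.uuid4().hex}"
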
@@ -44,6 +44,7 @@ from vllm.entrypoints.openai.engine.protocol import (
     DeltaMessage,
     DeltaToolCall,
     ErrorResponse,
+    FunctionCall,
     PromptTokenUsageInfo,
     RequestResponseMetadata,
     ToolCall,
@@ -143,11 +144,6 @@ class OpenAIServingChat(OpenAIServing):
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.enable_force_include_usage = enable_force_include_usage
         self.default_sampling_params = self.model_config.get_diff_sampling_param()
-        if self.model_config.hf_config.model_type == "kimi_k2":
-            self.tool_call_id_type = "kimi_k2"
-        else:
-            self.tool_call_id_type = "random"
-
         self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
         if self.use_harmony:
             if "stop_token_ids" not in self.default_sampling_params:
@@ -156,6 +152,16 @@ class OpenAIServingChat(OpenAIServing):
                 get_stop_tokens_for_assistant_actions()
             )
 
+        # Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
+        hf_overrides = getattr(self.model_config, "hf_overrides", None)
+        if self.model_config.hf_text_config.model_type == "kimi_k2" or (
+            isinstance(hf_overrides, dict)
+            and hf_overrides.get("model_type") == "kimi_k2"
+        ):
+            self.tool_call_id_type = "kimi_k2"
+        else:
+            self.tool_call_id_type = "random"
+
         # NOTE(woosuk): While OpenAI's chat completion API supports browsing
         # for some models, currently vLLM doesn't support it. Please use the
         # Responses API instead.
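The hf_overrides branch exists so tests can exercise the Kimi K2 ID path without loading a real Kimi K2 checkpoint. A minimal sketch of the detection logic above, with SimpleNamespace standing in for vLLM's ModelConfig (the test harness shown is hypothetical):

    from types import SimpleNamespace

    def detect_id_type(model_config) -> str:
        # Mirrors the constructor logic: the real model type wins,
        # but a dict-based hf_overrides can force the kimi_k2 path.
        hf_overrides = getattr(model_config, "hf_overrides", None)
        if model_config.hf_text_config.model_type == "kimi_k2" or (
            isinstance(hf_overrides, dict)
            and hf_overrides.get("model_type") == "kimi_k2"
        ):
            return "kimi_k2"
        return "random"

    plain = SimpleNamespace(hf_text_config=SimpleNamespace(model_type="llama"))
    mocked = SimpleNamespace(
        hf_text_config=SimpleNamespace(model_type="llama"),
        hf_overrides={"model_type": "kimi_k2"},
    )
    assert detect_id_type(plain) == "random"
    assert detect_id_type(mocked) == "kimi_k2"
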
@@ -247,8 +253,8 @@ class OpenAIServingChat(OpenAIServing):
             # because of issues with pydantic we need to potentially
             # re-serialize the tool_calls field of the request
             # for more info: see comment in `maybe_serialize_tool_calls`
-            maybe_serialize_tool_calls(request)
-            truncate_tool_call_ids(request)
+            maybe_serialize_tool_calls(request)  # type: ignore[arg-type]
+            truncate_tool_call_ids(request)  # type: ignore[arg-type]
             validate_request_params(request)
 
             # Check if tool parsing is unavailable (common condition)
@@ -454,6 +460,7 @@ class OpenAIServingChat(OpenAIServing):
 
         # Streaming response
         tokenizer = self.renderer.tokenizer
+        assert tokenizer is not None
 
         if request.stream:
             return self.chat_completion_stream_generator(
@@ -632,9 +639,11 @@ class OpenAIServingChat(OpenAIServing):
         request_id: str,
         model_name: str,
         conversation: list[ConversationMessage],
-        tokenizer: TokenizerLike | None,
+        tokenizer: TokenizerLike,
         request_metadata: RequestResponseMetadata,
     ) -> AsyncGenerator[str, None]:
+        from vllm.tokenizers.mistral import MistralTokenizer
+
         created_time = int(time.time())
         chunk_object_type: Final = "chat.completion.chunk"
         first_iteration = True
@@ -698,7 +707,7 @@ class OpenAIServingChat(OpenAIServing):
                 )
                 reasoning_parser = self.reasoning_parser(
                     tokenizer,
-                    chat_template_kwargs=chat_template_kwargs,  # type: ignore[call-arg]
+                    chat_template_kwargs=chat_template_kwargs or {},  # type: ignore[call-arg]
                 )
             except RuntimeError as e:
                 logger.exception("Error in reasoning parser creation.")
@@ -955,8 +964,17 @@ class OpenAIServingChat(OpenAIServing):
                             index=i,
                         )
                     else:
+                        # Generate ID based on tokenizer type
+                        if isinstance(tokenizer, MistralTokenizer):
+                            tool_call_id = MistralToolCall.generate_random_id()
+                        else:
+                            tool_call_id = make_tool_call_id(
+                                id_type=self.tool_call_id_type,
+                                func_name=tool_choice_function_name,
+                                idx=history_tool_call_cnt,
+                            )
                         delta_tool_call = DeltaToolCall(
-                            id=make_tool_call_id(),
+                            id=tool_call_id,
                             type="function",
                             function=DeltaFunctionCall(
                                 name=tool_choice_function_name,
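The MistralTokenizer special case keeps Mistral's ID contract intact: its tool-call protocol expects 9-character alphanumeric IDs, which is why generation is deferred to MistralToolCall.generate_random_id rather than make_tool_call_id. A sketch of what such a generator plausibly looks like (the actual implementation in vllm.tokenizers.mistral may differ):

    import random
    import string

    def generate_random_id(length: int = 9) -> str:
        # 9 alphanumeric chars, per the "9-char IDs" comments in this diff.
        alphabet = string.ascii_letters + string.digits
        return "".join(random.choices(alphabet, k=length))
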
@@ -1387,9 +1405,11 @@ class OpenAIServingChat(OpenAIServing):
         request_id: str,
         model_name: str,
         conversation: list[ConversationMessage],
-        tokenizer: TokenizerLike | None,
+        tokenizer: TokenizerLike,
         request_metadata: RequestResponseMetadata,
     ) -> ErrorResponse | ChatCompletionResponse:
+        from vllm.tokenizers.mistral import MistralTokenizer
+
         created_time = int(time.time())
         final_res: RequestOutput | None = None
 
@@ -1524,39 +1544,85 @@ class OpenAIServingChat(OpenAIServing):
             tool_call_class = (
                 MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
             )
-            if (not self.enable_auto_tools or not self.tool_parser) and (
+            if self.use_harmony:
+                # Harmony models already have parsed content and tool_calls
+                # through parse_chat_output. Respect its output directly.
+                message = ChatMessage(
+                    role=role,
+                    reasoning=reasoning,
+                    content=content,
+                    tool_calls=tool_calls if tool_calls else [],
+                )
+
+            elif (not self.enable_auto_tools or not self.tool_parser) and (
                 not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
                 and request.tool_choice != "required"
             ):
                 message = ChatMessage(role=role, reasoning=reasoning, content=content)
 
             # if the request uses tools and specified a tool choice
             elif (
                 request.tool_choice
                 and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
             ):
                 assert tool_calls is not None and len(tool_calls) > 0
+                tool_call_class_items = []
+                for idx, tc in enumerate(tool_calls):
+                    # Use native ID if available (e.g., Kimi K2),
+                    # otherwise generate ID with correct id_type
+                    if tc.id:
+                        tool_call_class_items.append(
+                            tool_call_class(id=tc.id, function=tc)
+                        )
+                    else:
+                        # Generate ID using the correct format (kimi_k2 or random),
+                        # but leave it to the class if it's Mistral to preserve
+                        # 9-char IDs
+                        if isinstance(tokenizer, MistralTokenizer):
+                            tool_call_class_items.append(tool_call_class(function=tc))
+                        else:
+                            generated_id = make_tool_call_id(
+                                id_type=self.tool_call_id_type,
+                                func_name=tc.name,
+                                idx=history_tool_call_cnt + idx,
+                            )
+                            tool_call_class_items.append(
+                                tool_call_class(id=generated_id, function=tc)
+                            )
+                history_tool_call_cnt += 1
                 message = ChatMessage(
                     role=role,
                     reasoning=reasoning,
                     content="",
-                    tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
+                    tool_calls=tool_call_class_items,
                 )
 
             elif request.tool_choice and request.tool_choice == "required":
                 tool_call_class_items = []
                 assert tool_calls is not None and len(tool_calls) > 0
-                for tool_call in tool_calls:
-                    tool_call_class_items.append(
-                        tool_call_class(
-                            id=make_tool_call_id(
+                for idx, tool_call in enumerate(tool_calls):
+                    # Use native ID if available,
+                    # otherwise generate ID with correct id_type
+                    if tool_call.id:
+                        tool_call_class_items.append(
+                            tool_call_class(id=tool_call.id, function=tool_call)
+                        )
+                    else:
+                        # Generate ID using the correct format (kimi_k2 or random),
+                        # but leave it to the class if it's Mistral to preserve
+                        # 9-char IDs
+                        if isinstance(tokenizer, MistralTokenizer):
+                            tool_call_class_items.append(
+                                tool_call_class(function=tool_call)
+                            )
+                        else:
+                            generated_id = make_tool_call_id(
                                 id_type=self.tool_call_id_type,
                                 func_name=tool_call.name,
-                                idx=history_tool_call_cnt,
-                            ),
-                            function=tool_call,
-                        )
-                    )
+                                idx=history_tool_call_cnt + idx,
+                            )
+                            tool_call_class_items.append(
+                                tool_call_class(id=generated_id, function=tool_call)
+                            )
+                history_tool_call_cnt += 1
                 message = ChatMessage(
                     role=role,
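Throughout the new branches, generated IDs use idx=history_tool_call_cnt + idx, so numbering continues across turns instead of restarting at zero and colliding with IDs issued earlier in the conversation. A worked example, assuming a kimi_k2-style "functions.{name}:{n}" format:

    history_tool_call_cnt = 2  # two tool calls already issued in earlier turns
    new_calls = ["lookup_flight", "book_hotel"]
    ids = [
        f"functions.{name}:{history_tool_call_cnt + idx}"
        for idx, name in enumerate(new_calls)
    ]
    print(ids)  # ['functions.lookup_flight:2', 'functions.book_hotel:3']
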
@@ -1582,17 +1648,35 @@ class OpenAIServingChat(OpenAIServing):
             # call. The same is not true for named function calls
             auto_tools_called = tool_calls is not None and len(tool_calls) > 0
             if tool_calls:
+                tool_call_items = []
+                for idx, tc in enumerate(tool_calls):
+                    # Use native ID if available (e.g., Kimi K2),
+                    # otherwise generate ID with correct id_type
+                    if tc.id:
+                        tool_call_items.append(
+                            tool_call_class(id=tc.id, function=tc)
+                        )
+                    else:
+                        # Generate ID using the correct format (kimi_k2 or random),
+                        # but leave it to the class if it's Mistral to preserve
+                        # 9-char IDs
+                        if isinstance(tokenizer, MistralTokenizer):
+                            tool_call_items.append(tool_call_class(function=tc))
+                        else:
+                            generated_id = make_tool_call_id(
+                                id_type=self.tool_call_id_type,
+                                func_name=tc.name,
+                                idx=history_tool_call_cnt + idx,
+                            )
+                            tool_call_items.append(
+                                tool_call_class(id=generated_id, function=tc)
+                            )
+                history_tool_call_cnt += 1
                 message = ChatMessage(
                     role=role,
                     reasoning=reasoning,
                     content=content,
-                    tool_calls=[
-                        ToolCall(
-                            function=tc,
-                            type="function",
-                        )
-                        for tc in tool_calls
-                    ],
+                    tool_calls=tool_call_items,
                 )
 
             else:
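End to end, the preserved ID is what lets a follow-up request line up tool results with the calls that produced them. An illustrative multi-turn payload against the OpenAI-compatible API (model name and ID format are examples):

    followup = {
        "model": "moonshotai/Kimi-K2-Instruct",
        "messages": [
            {"role": "user", "content": "What's the weather in Paris?"},
            {
                "role": "assistant",
                "tool_calls": [{
                    "id": "functions.get_weather:0",  # native ID, now preserved
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": '{"city": "Paris"}',
                    },
                }],
            },
            {
                "role": "tool",
                "tool_call_id": "functions.get_weather:0",  # must match the call
                "content": '{"temp_c": 18}',
            },
        ],
    }

Had the server replaced the native ID with a random one, the model could no longer associate the tool message with its own call on the next generation.
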
@@ -1701,13 +1785,11 @@ class OpenAIServingChat(OpenAIServing):
                 elif choice.message.tool_calls:
                     # For tool calls, log the function name and arguments
                     tool_call_descriptions = []
-                    for tc in choice.message.tool_calls:
-                        if hasattr(tc.function, "name") and hasattr(
-                            tc.function, "arguments"
-                        ):
-                            tool_call_descriptions.append(
-                                f"{tc.function.name}({tc.function.arguments})"
-                            )
+                    for tc in choice.message.tool_calls:  # type: ignore
+                        function_call: FunctionCall = tc.function  # type: ignore
+                        tool_call_descriptions.append(
+                            f"{function_call.name}({function_call.arguments})"
+                        )
                     tool_calls_str = ", ".join(tool_call_descriptions)
                     output_text = f"[tool_calls: {tool_calls_str}]"
 
@@ -1895,7 +1977,7 @@ class OpenAIServingChat(OpenAIServing):
             # because of issues with pydantic we need to potentially
             # re-serialize the tool_calls field of the request
             # for more info: see comment in `maybe_serialize_tool_calls`
-            maybe_serialize_tool_calls(request)
+            maybe_serialize_tool_calls(request)  # type: ignore[arg-type]
 
             # Add system message.
             # NOTE: In Chat Completion API, browsing is enabled by default
@@ -1913,7 +1995,7 @@ class OpenAIServingChat(OpenAIServing):
         # Add developer message.
         if request.tools:
             dev_msg = get_developer_message(
-                tools=request.tools if should_include_tools else None
+                tools=request.tools if should_include_tools else None  # type: ignore[arg-type]
             )
             messages.append(dev_msg)
 