fix: preserve native tool call ID in multi-turn tool calling (#32768)
Signed-off-by: wanglinian <wanglinian@stu.pku.edu.cn> Signed-off-by: wangln19 <96399074+wangln19@users.noreply.github.com> Signed-off-by: Roger Wang <hey@rogerw.io> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -42,6 +42,7 @@ class MockModelConfig:
|
||||
tokenizer_revision = None
|
||||
multimodal_config = MultiModalConfig()
|
||||
hf_config = MockHFConfig()
|
||||
hf_text_config = MockHFConfig()
|
||||
logits_processor_pattern = None
|
||||
logits_processors: list[str] | None = None
|
||||
diff_sampling_param: dict | None = None
|
||||
|
||||
@@ -518,6 +518,7 @@ class MockModelConfig:
|
||||
tokenizer_revision = None
|
||||
multimodal_config = MultiModalConfig()
|
||||
hf_config = MockHFConfig()
|
||||
hf_text_config = MockHFConfig()
|
||||
logits_processors: list[str] | None = None
|
||||
logits_processor_pattern = None
|
||||
diff_sampling_param: dict | None = None
|
||||
|
||||
@@ -44,6 +44,7 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ErrorResponse,
|
||||
FunctionCall,
|
||||
PromptTokenUsageInfo,
|
||||
RequestResponseMetadata,
|
||||
ToolCall,
|
||||
@@ -143,11 +144,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
||||
self.enable_force_include_usage = enable_force_include_usage
|
||||
self.default_sampling_params = self.model_config.get_diff_sampling_param()
|
||||
if self.model_config.hf_config.model_type == "kimi_k2":
|
||||
self.tool_call_id_type = "kimi_k2"
|
||||
else:
|
||||
self.tool_call_id_type = "random"
|
||||
|
||||
self.use_harmony = self.model_config.hf_config.model_type == "gpt_oss"
|
||||
if self.use_harmony:
|
||||
if "stop_token_ids" not in self.default_sampling_params:
|
||||
@@ -156,6 +152,16 @@ class OpenAIServingChat(OpenAIServing):
|
||||
get_stop_tokens_for_assistant_actions()
|
||||
)
|
||||
|
||||
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
|
||||
hf_overrides = getattr(self.model_config, "hf_overrides", None)
|
||||
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
|
||||
isinstance(hf_overrides, dict)
|
||||
and hf_overrides.get("model_type") == "kimi_k2"
|
||||
):
|
||||
self.tool_call_id_type = "kimi_k2"
|
||||
else:
|
||||
self.tool_call_id_type = "random"
|
||||
|
||||
# NOTE(woosuk): While OpenAI's chat completion API supports browsing
|
||||
# for some models, currently vLLM doesn't support it. Please use the
|
||||
# Responses API instead.
|
||||
@@ -247,8 +253,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# because of issues with pydantic we need to potentially
|
||||
# re-serialize the tool_calls field of the request
|
||||
# for more info: see comment in `maybe_serialize_tool_calls`
|
||||
maybe_serialize_tool_calls(request)
|
||||
truncate_tool_call_ids(request)
|
||||
maybe_serialize_tool_calls(request) # type: ignore[arg-type]
|
||||
truncate_tool_call_ids(request) # type: ignore[arg-type]
|
||||
validate_request_params(request)
|
||||
|
||||
# Check if tool parsing is unavailable (common condition)
|
||||
@@ -454,6 +460,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
# Streaming response
|
||||
tokenizer = self.renderer.tokenizer
|
||||
assert tokenizer is not None
|
||||
|
||||
if request.stream:
|
||||
return self.chat_completion_stream_generator(
|
||||
@@ -632,9 +639,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request_id: str,
|
||||
model_name: str,
|
||||
conversation: list[ConversationMessage],
|
||||
tokenizer: TokenizerLike | None,
|
||||
tokenizer: TokenizerLike,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
created_time = int(time.time())
|
||||
chunk_object_type: Final = "chat.completion.chunk"
|
||||
first_iteration = True
|
||||
@@ -698,7 +707,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
)
|
||||
reasoning_parser = self.reasoning_parser(
|
||||
tokenizer,
|
||||
chat_template_kwargs=chat_template_kwargs, # type: ignore[call-arg]
|
||||
chat_template_kwargs=chat_template_kwargs or {}, # type: ignore[call-arg]
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.exception("Error in reasoning parser creation.")
|
||||
@@ -955,8 +964,17 @@ class OpenAIServingChat(OpenAIServing):
|
||||
index=i,
|
||||
)
|
||||
else:
|
||||
# Generate ID based on tokenizer type
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
tool_call_id = MistralToolCall.generate_random_id()
|
||||
else:
|
||||
tool_call_id = make_tool_call_id(
|
||||
id_type=self.tool_call_id_type,
|
||||
func_name=tool_choice_function_name,
|
||||
idx=history_tool_call_cnt,
|
||||
)
|
||||
delta_tool_call = DeltaToolCall(
|
||||
id=make_tool_call_id(),
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=DeltaFunctionCall(
|
||||
name=tool_choice_function_name,
|
||||
@@ -1387,9 +1405,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
request_id: str,
|
||||
model_name: str,
|
||||
conversation: list[ConversationMessage],
|
||||
tokenizer: TokenizerLike | None,
|
||||
tokenizer: TokenizerLike,
|
||||
request_metadata: RequestResponseMetadata,
|
||||
) -> ErrorResponse | ChatCompletionResponse:
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
created_time = int(time.time())
|
||||
final_res: RequestOutput | None = None
|
||||
|
||||
@@ -1524,39 +1544,85 @@ class OpenAIServingChat(OpenAIServing):
|
||||
tool_call_class = (
|
||||
MistralToolCall if isinstance(tokenizer, MistralTokenizer) else ToolCall
|
||||
)
|
||||
if (not self.enable_auto_tools or not self.tool_parser) and (
|
||||
if self.use_harmony:
|
||||
# Harmony models already have parsed content and tool_calls
|
||||
# through parse_chat_output. Respect its output directly.
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
reasoning=reasoning,
|
||||
content=content,
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
)
|
||||
|
||||
elif (not self.enable_auto_tools or not self.tool_parser) and (
|
||||
not isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
|
||||
and request.tool_choice != "required"
|
||||
):
|
||||
message = ChatMessage(role=role, reasoning=reasoning, content=content)
|
||||
|
||||
# if the request uses tools and specified a tool choice
|
||||
elif (
|
||||
request.tool_choice
|
||||
and type(request.tool_choice) is ChatCompletionNamedToolChoiceParam
|
||||
):
|
||||
assert tool_calls is not None and len(tool_calls) > 0
|
||||
tool_call_class_items = []
|
||||
for idx, tc in enumerate(tool_calls):
|
||||
# Use native ID if available (e.g., Kimi K2),
|
||||
# otherwise generate ID with correct id_type
|
||||
if tc.id:
|
||||
tool_call_class_items.append(
|
||||
tool_call_class(id=tc.id, function=tc)
|
||||
)
|
||||
else:
|
||||
# Generate ID using the correct format (kimi_k2 or random),
|
||||
# but leave it to the class if it's Mistral to preserve
|
||||
# 9-char IDs
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
tool_call_class_items.append(tool_call_class(function=tc))
|
||||
else:
|
||||
generated_id = make_tool_call_id(
|
||||
id_type=self.tool_call_id_type,
|
||||
func_name=tc.name,
|
||||
idx=history_tool_call_cnt + idx,
|
||||
)
|
||||
tool_call_class_items.append(
|
||||
tool_call_class(id=generated_id, function=tc)
|
||||
)
|
||||
history_tool_call_cnt += 1
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
reasoning=reasoning,
|
||||
content="",
|
||||
tool_calls=[tool_call_class(function=tc) for tc in tool_calls],
|
||||
tool_calls=tool_call_class_items,
|
||||
)
|
||||
|
||||
elif request.tool_choice and request.tool_choice == "required":
|
||||
tool_call_class_items = []
|
||||
assert tool_calls is not None and len(tool_calls) > 0
|
||||
for tool_call in tool_calls:
|
||||
tool_call_class_items.append(
|
||||
tool_call_class(
|
||||
id=make_tool_call_id(
|
||||
for idx, tool_call in enumerate(tool_calls):
|
||||
# Use native ID if available,
|
||||
# otherwise generate ID with correct id_type
|
||||
if tool_call.id:
|
||||
tool_call_class_items.append(
|
||||
tool_call_class(id=tool_call.id, function=tool_call)
|
||||
)
|
||||
else:
|
||||
# Generate ID using the correct format (kimi_k2 or random),
|
||||
# but leave it to the class if it's Mistral to preserve
|
||||
# 9-char IDs
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
tool_call_class_items.append(
|
||||
tool_call_class(function=tool_call)
|
||||
)
|
||||
else:
|
||||
generated_id = make_tool_call_id(
|
||||
id_type=self.tool_call_id_type,
|
||||
func_name=tool_call.name,
|
||||
idx=history_tool_call_cnt,
|
||||
),
|
||||
function=tool_call,
|
||||
)
|
||||
)
|
||||
idx=history_tool_call_cnt + idx,
|
||||
)
|
||||
tool_call_class_items.append(
|
||||
tool_call_class(id=generated_id, function=tool_call)
|
||||
)
|
||||
history_tool_call_cnt += 1
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
@@ -1582,17 +1648,35 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# call. The same is not true for named function calls
|
||||
auto_tools_called = tool_calls is not None and len(tool_calls) > 0
|
||||
if tool_calls:
|
||||
tool_call_items = []
|
||||
for idx, tc in enumerate(tool_calls):
|
||||
# Use native ID if available (e.g., Kimi K2),
|
||||
# otherwise generate ID with correct id_type
|
||||
if tc.id:
|
||||
tool_call_items.append(
|
||||
tool_call_class(id=tc.id, function=tc)
|
||||
)
|
||||
else:
|
||||
# Generate ID using the correct format (kimi_k2 or random),
|
||||
# but leave it to the class if it's Mistral to preserve
|
||||
# 9-char IDs
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
tool_call_items.append(tool_call_class(function=tc))
|
||||
else:
|
||||
generated_id = make_tool_call_id(
|
||||
id_type=self.tool_call_id_type,
|
||||
func_name=tc.name,
|
||||
idx=history_tool_call_cnt + idx,
|
||||
)
|
||||
tool_call_items.append(
|
||||
tool_call_class(id=generated_id, function=tc)
|
||||
)
|
||||
history_tool_call_cnt += 1
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
reasoning=reasoning,
|
||||
content=content,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
function=tc,
|
||||
type="function",
|
||||
)
|
||||
for tc in tool_calls
|
||||
],
|
||||
tool_calls=tool_call_items,
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -1701,13 +1785,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
elif choice.message.tool_calls:
|
||||
# For tool calls, log the function name and arguments
|
||||
tool_call_descriptions = []
|
||||
for tc in choice.message.tool_calls:
|
||||
if hasattr(tc.function, "name") and hasattr(
|
||||
tc.function, "arguments"
|
||||
):
|
||||
tool_call_descriptions.append(
|
||||
f"{tc.function.name}({tc.function.arguments})"
|
||||
)
|
||||
for tc in choice.message.tool_calls: # type: ignore
|
||||
function_call: FunctionCall = tc.function # type: ignore
|
||||
tool_call_descriptions.append(
|
||||
f"{function_call.name}({function_call.arguments})"
|
||||
)
|
||||
tool_calls_str = ", ".join(tool_call_descriptions)
|
||||
output_text = f"[tool_calls: {tool_calls_str}]"
|
||||
|
||||
@@ -1895,7 +1977,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# because of issues with pydantic we need to potentially
|
||||
# re-serialize the tool_calls field of the request
|
||||
# for more info: see comment in `maybe_serialize_tool_calls`
|
||||
maybe_serialize_tool_calls(request)
|
||||
maybe_serialize_tool_calls(request) # type: ignore[arg-type]
|
||||
|
||||
# Add system message.
|
||||
# NOTE: In Chat Completion API, browsing is enabled by default
|
||||
@@ -1913,7 +1995,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# Add developer message.
|
||||
if request.tools:
|
||||
dev_msg = get_developer_message(
|
||||
tools=request.tools if should_include_tools else None
|
||||
tools=request.tools if should_include_tools else None # type: ignore[arg-type]
|
||||
)
|
||||
messages.append(dev_msg)
|
||||
|
||||
|
||||
@@ -218,6 +218,10 @@ def get_logits_processors(
|
||||
|
||||
|
||||
class FunctionCall(OpenAIBaseModel):
|
||||
# Internal field to preserve native tool call ID from tool parser.
|
||||
# Excluded from serialization to maintain OpenAI API compatibility
|
||||
# (function object should only contain 'name' and 'arguments').
|
||||
id: str | None = Field(default=None, exclude=True)
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
@@ -1525,6 +1525,7 @@ class OpenAIServing:
|
||||
# extract_tool_calls() returns a list of tool calls.
|
||||
function_calls.extend(
|
||||
FunctionCall(
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
)
|
||||
|
||||
@@ -63,6 +63,7 @@ from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
make_tool_call_id,
|
||||
)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.mcp.tool_server import ToolServer
|
||||
@@ -250,6 +251,17 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self.default_sampling_params["stop_token_ids"].extend(
|
||||
get_stop_tokens_for_assistant_actions()
|
||||
)
|
||||
|
||||
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
|
||||
hf_overrides = getattr(self.model_config, "hf_overrides", None)
|
||||
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
|
||||
isinstance(hf_overrides, dict)
|
||||
and hf_overrides.get("model_type") == "kimi_k2"
|
||||
):
|
||||
self.tool_call_id_type = "kimi_k2"
|
||||
else:
|
||||
self.tool_call_id_type = "random"
|
||||
|
||||
self.enable_auto_tools = enable_auto_tools
|
||||
# set up tool use
|
||||
self.tool_parser = self._get_tool_parser(
|
||||
@@ -954,25 +966,28 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
enable_auto_tools=self.enable_auto_tools,
|
||||
tool_parser_cls=self.tool_parser,
|
||||
)
|
||||
if content:
|
||||
output_text = ResponseOutputText(
|
||||
text=content,
|
||||
annotations=[], # TODO
|
||||
type="output_text",
|
||||
logprobs=(
|
||||
self._create_response_logprobs(
|
||||
token_ids=final_output.token_ids,
|
||||
logprobs=final_output.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
top_logprobs=request.top_logprobs,
|
||||
)
|
||||
if request.is_include_output_logprobs()
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
if content or (self.use_harmony and tool_calls):
|
||||
res_text_part = None
|
||||
if content:
|
||||
res_text_part = ResponseOutputText(
|
||||
text=content,
|
||||
annotations=[], # TODO
|
||||
type="output_text",
|
||||
logprobs=(
|
||||
self._create_response_logprobs(
|
||||
token_ids=final_output.token_ids,
|
||||
logprobs=final_output.logprobs,
|
||||
tokenizer=tokenizer,
|
||||
top_logprobs=request.top_logprobs,
|
||||
)
|
||||
if request.is_include_output_logprobs()
|
||||
else None
|
||||
),
|
||||
)
|
||||
message_item = ResponseOutputMessage(
|
||||
id=f"msg_{random_uuid()}",
|
||||
content=[output_text],
|
||||
content=[res_text_part] if res_text_part else [],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
@@ -984,17 +999,28 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
if message_item:
|
||||
outputs.append(message_item)
|
||||
if tool_calls:
|
||||
tool_call_items = [
|
||||
ResponseFunctionToolCall(
|
||||
id=f"fc_{random_uuid()}",
|
||||
call_id=f"call_{random_uuid()}",
|
||||
type="function_call",
|
||||
status="completed",
|
||||
name=tool_call.name,
|
||||
arguments=tool_call.arguments,
|
||||
# We use a simple counter for history_tool_call_count because
|
||||
# we don't track the history of tool calls in the Responses API yet.
|
||||
# This means that the tool call index will start from 0 for each
|
||||
# request.
|
||||
tool_call_items = []
|
||||
for history_tool_call_cnt, tool_call in enumerate(tool_calls):
|
||||
tool_call_items.append(
|
||||
ResponseFunctionToolCall(
|
||||
id=f"fc_{random_uuid()}",
|
||||
call_id=tool_call.id
|
||||
if tool_call.id
|
||||
else make_tool_call_id(
|
||||
id_type=self.tool_call_id_type,
|
||||
func_name=tool_call.name,
|
||||
idx=history_tool_call_cnt,
|
||||
),
|
||||
type="function_call",
|
||||
status="completed",
|
||||
name=tool_call.name,
|
||||
arguments=tool_call.arguments,
|
||||
)
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
outputs.extend(tool_call_items)
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -448,7 +448,7 @@ class KimiK2ToolParser(ToolParser):
|
||||
if current_tool_call_matches:
|
||||
tool_id, tool_args = current_tool_call_matches.groups()
|
||||
tool_name = tool_id.split(":")[0].split(".")[-1]
|
||||
current_tool_call["id"] = tool_id
|
||||
current_tool_call["id"] = tool_id.strip()
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = tool_args
|
||||
else:
|
||||
@@ -458,7 +458,7 @@ class KimiK2ToolParser(ToolParser):
|
||||
if current_tool_call_name_matches:
|
||||
(tool_id_str,) = current_tool_call_name_matches.groups()
|
||||
tool_name = tool_id_str.split(":")[0].split(".")[-1]
|
||||
current_tool_call["id"] = tool_id_str
|
||||
current_tool_call["id"] = tool_id_str.strip()
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = ""
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user