[Frontend] Generate valid tool call IDs when using tokenizer-mode=mistral (#12332)

Rafael Vasquez
2025-02-12 11:29:56 -05:00
committed by GitHub
parent 985b4a2b19
commit 314cfade02
8 changed files with 149 additions and 8 deletions

vllm/entrypoints/openai/serving_chat.py

@@ -28,12 +28,15 @@ from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
+    MistralToolCall)
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
+from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
+                                                truncate_tool_call_ids)
 
 logger = init_logger(__name__)
@@ -150,11 +153,12 @@ class OpenAIServingChat(OpenAIServing):
                 return self.create_error_response(
                     "tool_choice = \"required\" is not supported!")
 
-            # because of issues with pydantic we need to potentially
-            # re-serialize the tool_calls field of the request
-            # for more info: see comment in `maybe_serialize_tool_calls`
             if isinstance(tokenizer, MistralTokenizer):
+                # because of issues with pydantic we need to potentially
+                # re-serialize the tool_calls field of the request
+                # for more info: see comment in `maybe_serialize_tool_calls`
                 maybe_serialize_tool_calls(request)
+                truncate_tool_call_ids(request)
 
             if (request.tool_choice == "auto" and
                     not (self.enable_auto_tools and tool_parser is not None)
@@ -745,11 +749,13 @@ class OpenAIServingChat(OpenAIServing):
             elif request.tool_choice and type(
                     request.tool_choice) is ChatCompletionNamedToolChoiceParam:
+                tool_call_class = MistralToolCall if isinstance(
+                    tokenizer, MistralTokenizer) else ToolCall
                 message = ChatMessage(
                     role=role,
                     content="",
                     tool_calls=[
-                        ToolCall(function=FunctionCall(
+                        tool_call_class(function=FunctionCall(
                             name=request.tool_choice.function.name,
                             arguments=output.text))
                     ])
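In this named-tool-choice path, the id on the emitted tool call comes from the tool-call class's default factory, so the class selection above is what makes the id pass mistral-common's validation. A minimal sketch of the difference, assuming vLLM at this commit (FunctionCall and ToolCall are the protocol types imported in this file; the function name and arguments are placeholders):

from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
    MistralToolCall)

fn = FunctionCall(name="get_weather", arguments="{}")

# ToolCall's default id is an OpenAI-style "chatcmpl-tool-..." string, far
# longer than mistral-common's 9-character limit; MistralToolCall's default
# id is exactly 9 alphanumeric characters (see generate_random_id below).
assert len(ToolCall(function=fn).id) > 9
assert len(MistralToolCall(function=fn).id) == 9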

vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py

@@ -33,7 +33,7 @@ class MistralToolCall(ToolCall):
     @staticmethod
     def generate_random_id():
-        # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
+        # Mistral Tool Call Ids must be alphanumeric with a length of 9.
         # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
         return "".join(choices(ALPHANUMERIC, k=9))

vllm/transformers_utils/tokenizers/__init__.py

@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from .mistral import MistralTokenizer, maybe_serialize_tool_calls
+from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
+                      truncate_tool_call_ids)
 
-__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
+__all__ = [
+    "MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids"
+]

vllm/transformers_utils/tokenizers/mistral.py

@@ -68,6 +68,36 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
         request.messages[i]["tool_calls"] = validated_tool_calls
 
 
+def truncate_tool_call_ids(request: "ChatCompletionRequest"):
+    """Truncates tool call IDs for Mistral's ID requirements."""
+    for i, message in enumerate(request.messages):
+        if message.get("role") == 'assistant':
+            tool_calls = message.get("tool_calls", [])
+            for tool_call in tool_calls:
+                if len(tool_call["id"]) > 9:
+                    logger.warning(
+                        "Truncating tool call ID: %s to %s",
+                        tool_call["id"],
+                        tool_call["id"][-9:],
+                    )
+                    tool_call["id"] = tool_call["id"][-9:]
+            request.messages[i]["tool_calls"] = tool_calls
+
+        elif message.get("role") in {"tool_results", "tool"}:
+            if "tool_call_id" in message:
+                tool_call_id = message["tool_call_id"]
+
+                if len(tool_call_id) > 9:
+                    logger.warning(
+                        "Truncating tool_call_id: %s to %s",
+                        tool_call_id,
+                        tool_call_id[-9:],
+                    )
+                    tool_call_id = tool_call_id[-9:]
+
+                request.messages[i]["tool_call_id"] = tool_call_id
+
+
 def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
     repo_cache = os.path.join(
         huggingface_hub.constants.HF_HUB_CACHE,
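A usage sketch for the new helper, assuming vLLM at this commit. Since only the messages attribute is touched, a SimpleNamespace stands in for the full ChatCompletionRequest, and the ids are made-up examples:

from types import SimpleNamespace

from vllm.transformers_utils.tokenizers import truncate_tool_call_ids

request = SimpleNamespace(messages=[
    {"role": "assistant",
     "tool_calls": [{"id": "chatcmpl-tool-9f3a2b1c8d7e"}]},
    {"role": "tool", "tool_call_id": "chatcmpl-tool-9f3a2b1c8d7e"},
])

truncate_tool_call_ids(request)

# Both ids were cut to their last nine characters, so the assistant turn and
# its tool result still reference the same id.
assert request.messages[0]["tool_calls"][0]["id"] == "a2b1c8d7e"
assert request.messages[1]["tool_call_id"] == "a2b1c8d7e"

Keeping the tail of the id rather than the head preserves the random suffix of OpenAI-style ids, so distinct calls remain distinct after truncation.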