[Frontend] Generate valid tool call IDs when using tokenizer-mode=mistral (#12332)
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .mistral import MistralTokenizer, maybe_serialize_tool_calls
|
||||
from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
|
||||
truncate_tool_call_ids)
|
||||
|
||||
__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
|
||||
__all__ = [
|
||||
"MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids"
|
||||
]
|
||||
|
||||
@@ -68,6 +68,36 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
|
||||
request.messages[i]["tool_calls"] = validated_tool_calls
|
||||
|
||||
|
||||
def truncate_tool_call_ids(request: "ChatCompletionRequest"):
|
||||
"""Truncates tool call IDs for Mistral's ID requirements."""
|
||||
for i, message in enumerate(request.messages):
|
||||
if message.get("role") == 'assistant':
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
for tool_call in tool_calls:
|
||||
if len(tool_call["id"]) > 9:
|
||||
logger.warning(
|
||||
"Truncating tool call ID: %s to %s",
|
||||
tool_call["id"],
|
||||
tool_call["id"][-9:],
|
||||
)
|
||||
tool_call["id"] = tool_call["id"][-9:]
|
||||
|
||||
request.messages[i]["tool_calls"] = tool_calls
|
||||
|
||||
elif message.get("role") in {"tool_results", "tool"}:
|
||||
if "tool_call_id" in message:
|
||||
tool_call_id = message["tool_call_id"]
|
||||
|
||||
if len(tool_call_id) > 9:
|
||||
logger.warning(
|
||||
"Truncating tool_call_id: %s to %s",
|
||||
tool_call_id,
|
||||
tool_call_id[-9:],
|
||||
)
|
||||
tool_call_id = tool_call_id[-9:]
|
||||
request.messages[i]["tool_call_id"] = tool_call_id
|
||||
|
||||
|
||||
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
|
||||
repo_cache = os.path.join(
|
||||
huggingface_hub.constants.HF_HUB_CACHE,
|
||||
|
||||
Reference in New Issue
Block a user