[Tool parsing] Improve / correct mistral tool parsing (#10333)

This commit is contained in:
Patrick von Platen
2024-11-15 01:42:49 +01:00
committed by GitHub
parent 554af9228d
commit 11cd1ae6ad
5 changed files with 172 additions and 59 deletions

View File

@@ -1,3 +1,3 @@
from .mistral import MistralTokenizer
from .mistral import MistralTokenizer, maybe_serialize_tool_calls
__all__ = ["MistralTokenizer"]
__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import huggingface_hub
from huggingface_hub import HfApi, hf_hub_download
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.base import SpecialTokens
# yapf: disable
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer)
@@ -29,6 +30,43 @@ class Encoding:
input_ids: List[int]
def maybe_serialize_tool_calls(request: ChatCompletionRequest):
# SEE: https://github.com/vllm-project/vllm/pull/9951
# Credits go to: @gcalmettes
# NOTE: There is currently a bug in pydantic where attributes
# declared as iterables are replaced in in the instances by
# pydantic-core ValidatorIterator instance. In particular, this
# affects tool_calls defined in ChatCompletionAssistantMessageParam
# model:
# see:
# - https://github.com/pydantic/pydantic/issues/9467
# As a result, tool_calls from assistant messages are never
# deserialized in the request object if the tool_calls iterator is
# not consumed. This affect messages passed to the MistralTokenizer
# since no chat template is applied and therefore the tools_calls
# iterator is not directly consumed.
# Issue is tracked on Pydantic side, with resolution planned for
# v2.11 release. In the meantime, the official workaround is to
# consume the iterator so the tool_calls are correctly deserialized
# in the OpenAI ChatCompletionAssistantMessageParam object
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
# Official Pydantic Issues:
# - https://github.com/pydantic/pydantic/issues/9541
# TODO: remove when pydantic v2.11 is released
for i, message in enumerate(request.messages):
if message.get("role") == 'assistant':
tool_calls_validator = message.get("tool_calls", ().__iter__())
validated_tool_calls = []
while True:
try:
tool_call = next(tool_calls_validator) # type: ignore
validated_tool_calls.append(tool_call)
except StopIteration:
break
request.messages[i]["tool_calls"] = validated_tool_calls
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
repo_cache = os.path.join(
huggingface_hub.constants.HF_HUB_CACHE,
@@ -222,7 +260,8 @@ class MistralTokenizer:
if self.is_tekken:
tokens = [
t for t in tokens
if t not in self.tokenizer._all_special_tokens
if (t is SpecialTokens.tool_calls
or t not in self.tokenizer._all_special_tokens)
]
if any(isinstance(t, bytes) for t in tokens):
@@ -246,7 +285,27 @@ class MistralTokenizer:
else:
decoded = "".join(tokens)
else:
decoded = self.tokenizer.decode(tokens) # type: ignore[arg-type]
# make sure certain special tokens like Tool calls are
# not decoded
special_tokens = {SpecialTokens.tool_calls}
regular_tokens: List[str] = []
decoded_list = []
for token in tokens:
if token in special_tokens:
if regular_tokens:
decoded_list.append(
self.tokenizer.decode(regular_tokens))
regular_tokens = []
decoded_list.append(token)
else:
regular_tokens.append(token)
if regular_tokens:
decoded_list.append(
self.decode(regular_tokens)) # type: ignore
decoded = ''.join(decoded_list)
return decoded
@@ -274,8 +333,11 @@ class MistralTokenizer:
assert self.is_tekken or self.is_spm, type(self.tokenizer)
if self.is_tekken:
# skip special tokens
ids = [i for i in ids if i > self.tokenizer.num_special_tokens]
# skip special tokens except tool call
ids = [
i for i in ids if i > self.tokenizer.num_special_tokens or i ==
self.tokenizer.get_control_token(SpecialTokens.tool_calls)
]
tokens = [self.tokenizer.id_to_piece(id) for id in ids]