[Tool parsing] Improve / correct mistral tool parsing (#10333)
This commit is contained in:
committed by
GitHub
parent
554af9228d
commit
11cd1ae6ad
@@ -1,3 +1,3 @@
|
||||
from .mistral import MistralTokenizer
|
||||
from .mistral import MistralTokenizer, maybe_serialize_tool_calls
|
||||
|
||||
__all__ = ["MistralTokenizer"]
|
||||
__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||
import huggingface_hub
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||
# yapf: disable
|
||||
from mistral_common.tokens.tokenizers.mistral import (
|
||||
MistralTokenizer as PublicMistralTokenizer)
|
||||
@@ -29,6 +30,43 @@ class Encoding:
|
||||
input_ids: List[int]
|
||||
|
||||
|
||||
def maybe_serialize_tool_calls(request: ChatCompletionRequest):
|
||||
# SEE: https://github.com/vllm-project/vllm/pull/9951
|
||||
# Credits go to: @gcalmettes
|
||||
# NOTE: There is currently a bug in pydantic where attributes
|
||||
# declared as iterables are replaced in in the instances by
|
||||
# pydantic-core ValidatorIterator instance. In particular, this
|
||||
# affects tool_calls defined in ChatCompletionAssistantMessageParam
|
||||
# model:
|
||||
# see:
|
||||
# - https://github.com/pydantic/pydantic/issues/9467
|
||||
# As a result, tool_calls from assistant messages are never
|
||||
# deserialized in the request object if the tool_calls iterator is
|
||||
# not consumed. This affect messages passed to the MistralTokenizer
|
||||
# since no chat template is applied and therefore the tools_calls
|
||||
# iterator is not directly consumed.
|
||||
# Issue is tracked on Pydantic side, with resolution planned for
|
||||
# v2.11 release. In the meantime, the official workaround is to
|
||||
# consume the iterator so the tool_calls are correctly deserialized
|
||||
# in the OpenAI ChatCompletionAssistantMessageParam object
|
||||
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
|
||||
# Official Pydantic Issues:
|
||||
# - https://github.com/pydantic/pydantic/issues/9541
|
||||
# TODO: remove when pydantic v2.11 is released
|
||||
for i, message in enumerate(request.messages):
|
||||
if message.get("role") == 'assistant':
|
||||
tool_calls_validator = message.get("tool_calls", ().__iter__())
|
||||
validated_tool_calls = []
|
||||
while True:
|
||||
try:
|
||||
tool_call = next(tool_calls_validator) # type: ignore
|
||||
validated_tool_calls.append(tool_call)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
request.messages[i]["tool_calls"] = validated_tool_calls
|
||||
|
||||
|
||||
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
|
||||
repo_cache = os.path.join(
|
||||
huggingface_hub.constants.HF_HUB_CACHE,
|
||||
@@ -222,7 +260,8 @@ class MistralTokenizer:
|
||||
if self.is_tekken:
|
||||
tokens = [
|
||||
t for t in tokens
|
||||
if t not in self.tokenizer._all_special_tokens
|
||||
if (t is SpecialTokens.tool_calls
|
||||
or t not in self.tokenizer._all_special_tokens)
|
||||
]
|
||||
|
||||
if any(isinstance(t, bytes) for t in tokens):
|
||||
@@ -246,7 +285,27 @@ class MistralTokenizer:
|
||||
else:
|
||||
decoded = "".join(tokens)
|
||||
else:
|
||||
decoded = self.tokenizer.decode(tokens) # type: ignore[arg-type]
|
||||
# make sure certain special tokens like Tool calls are
|
||||
# not decoded
|
||||
special_tokens = {SpecialTokens.tool_calls}
|
||||
regular_tokens: List[str] = []
|
||||
decoded_list = []
|
||||
|
||||
for token in tokens:
|
||||
if token in special_tokens:
|
||||
if regular_tokens:
|
||||
decoded_list.append(
|
||||
self.tokenizer.decode(regular_tokens))
|
||||
regular_tokens = []
|
||||
decoded_list.append(token)
|
||||
else:
|
||||
regular_tokens.append(token)
|
||||
|
||||
if regular_tokens:
|
||||
decoded_list.append(
|
||||
self.decode(regular_tokens)) # type: ignore
|
||||
|
||||
decoded = ''.join(decoded_list)
|
||||
|
||||
return decoded
|
||||
|
||||
@@ -274,8 +333,11 @@ class MistralTokenizer:
|
||||
assert self.is_tekken or self.is_spm, type(self.tokenizer)
|
||||
|
||||
if self.is_tekken:
|
||||
# skip special tokens
|
||||
ids = [i for i in ids if i > self.tokenizer.num_special_tokens]
|
||||
# skip special tokens except tool call
|
||||
ids = [
|
||||
i for i in ids if i > self.tokenizer.num_special_tokens or i ==
|
||||
self.tokenizer.get_control_token(SpecialTokens.tool_calls)
|
||||
]
|
||||
|
||||
tokens = [self.tokenizer.id_to_piece(id) for id in ids]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user