[Tool parsing] Improve / correct mistral tool parsing (#10333)

This commit is contained in:
Patrick von Platen
2024-11-15 01:42:49 +01:00
committed by GitHub
parent 554af9228d
commit 11cd1ae6ad
5 changed files with 172 additions and 59 deletions

View File

@@ -30,6 +30,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
from vllm.utils import iterate_with_cancellation
logger = init_logger(__name__)
@@ -127,41 +128,11 @@ class OpenAIServingChat(OpenAIServing):
return self.create_error_response(
"tool_choice = \"required\" is not supported!")
# NOTE: There is currently a bug in pydantic where attributes
# declared as iterables are replaced in in the instances by
# pydantic-core ValidatorIterator instance. In particular, this
# affects tool_calls defined in ChatCompletionAssistantMessageParam
# model:
# see:
# - https://github.com/pydantic/pydantic/issues/9467
# As a result, tool_calls from assistant messages are never
# deserialized in the request object if the tool_calls iterator is
# not consumed. This affect messages passed to the MistralTokenizer
# since no chat template is applied and therefore the tools_calls
# iterator is not directly consumed.
# Issue is tracked on Pydantic side, with resolution planned for
# v2.11 release. In the meantime, the official workaround is to
# consume the iterator so the tool_calls are correctly deserialized
# in the OpenAI ChatCompletionAssistantMessageParam object
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
# Official Pydantic Issues:
# - https://github.com/pydantic/pydantic/issues/9541
# TODO: remove when pydantic v2.11 is released
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
if isinstance(tokenizer, MistralTokenizer):
for i, message in enumerate(request.messages):
if message.get("role") == 'assistant':
tool_calls_validator = message.get(
"tool_calls", ().__iter__())
validated_tool_calls = []
while True:
try:
tool_call = next(
tool_calls_validator) # type: ignore
validated_tool_calls.append(tool_call)
except StopIteration:
break
request.messages[i][
"tool_calls"] = validated_tool_calls
maybe_serialize_tool_calls(request)
if (request.tool_choice == "auto" and
not (self.enable_auto_tools and tool_parser is not None)

View File

@@ -62,7 +62,7 @@ class MistralToolParser(ToolParser):
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if self.bot_token_id is None:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
@@ -84,16 +84,25 @@ class MistralToolParser(ToolParser):
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
# first remove the BOT token
tool_content = model_output.replace(self.bot_token, "").strip()
try:
# use a regex to find the tool call. remove the BOT token
# and make sure to replace single quotes with double quotes
raw_tool_call = self.tool_call_regex.findall(
model_output.replace(self.bot_token, ""))[0]
# we first try to directly load the json as parsing very nested
# jsons is difficult
try:
function_call_arr = json.loads(tool_content)
except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained
# correctly. It's a easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
function_call_arr = json.loads(raw_tool_call)
# load the JSON, and then use it to build the Function and
# Tool Call
function_call_arr = json.loads(raw_tool_call)
tool_calls: List[MistralToolCall] = [
MistralToolCall(
type="function",
@@ -116,7 +125,7 @@ class MistralToolParser(ToolParser):
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
content=tool_content)
def extract_tool_calls_streaming(
self,