[Bugfix] Mistral tool parser streaming update (#19425)

Signed-off-by: avigny <47987522+avigny@users.noreply.github.com>
Signed-off-by: Chauncey <chaunceyjiang@gmail.com>
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Co-authored-by: Jeff Cook <jeff@jeffcook.io>
Co-authored-by: sfbemerk <benjaminmerkel@mail.de>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
avigny
2025-12-03 18:45:31 +01:00
committed by GitHub
parent d1f7392c5f
commit dd5d1ef780
4 changed files with 1277 additions and 207 deletions

View File

@@ -3,12 +3,12 @@
import json
from collections.abc import Sequence
from enum import Enum, auto
from random import choices
from string import ascii_letters, digits
import partial_json_parser
import ijson
import regex as re
from partial_json_parser.core.options import Allow
from pydantic import Field
from vllm.entrypoints.openai.protocol import (
@@ -23,7 +23,6 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.tokenizers import MistralTokenizer, TokenizerLike
@@ -32,6 +31,22 @@ logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits
class StreamingState(Enum):
    """Position of the streaming parser inside the tool-call output."""

    # Nothing belonging to a tool call has been parsed yet.
    WAITING_FOR_TOOL_START = auto()
    # waiting for the "name" or "arguments" key to be complete
    WAITING_FOR_TOOL_KEY = auto()
    # The tool name is being read.
    PARSING_NAME = auto()
    # The tool name has been read completely.
    PARSING_NAME_COMPLETED = auto()
    # The "arguments" key was seen; its value has not started yet.
    WAITING_FOR_ARGUMENTS_START = auto()
    # The arguments of the current tool are being read.
    PARSING_ARGUMENTS = auto()
    # The arguments of the current tool have been read completely.
    PARSING_ARGUMENTS_COMPLETED = auto()
    # The current tool call is finished.
    TOOL_COMPLETE = auto()
    # Every tool call has been parsed.
    ALL_TOOLS_COMPLETE = auto()
class MistralToolCall(ToolCall):
id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id())
@@ -46,8 +61,8 @@ class MistralToolCall(ToolCall):
return id.isalnum() and len(id) == 9
def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool:
return (
def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
return not (
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
)
@@ -69,16 +84,22 @@ class MistralToolParser(ToolParser):
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[
str
] = [] # map what has been streamed for each tool so far to a list
self.streaming_state: StreamingState = StreamingState.WAITING_FOR_TOOL_START
# For streaming pre v11 tokenizer tool calls
self.current_tool_name: str | None = None
self.current_tool_mistral_id: str | None = None
self.starting_new_tool = False
if _is_pre_v11_tokeniser(self.model_tokenizer):
self.parse_coro = ijson.parse_coro(
self.update_stream_state_pre_v11_tokenizer()
)
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if _is_fn_name_regex_support(self.model_tokenizer):
if not _is_pre_v11_tokeniser(self.model_tokenizer):
self.fn_name_regex = re.compile(
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
)
@@ -131,18 +152,19 @@ class MistralToolParser(ToolParser):
# jsons is difficult
try:
if self.fn_name_regex:
matches = self.fn_name_regex.findall(tool_content)
function_call_arr = []
for match in matches:
fn_name = match[0]
args = match[1]
for single_tool_content in model_output.split(self.bot_token):
matches = self.fn_name_regex.findall(single_tool_content)
# fn_name is encoded outside serialized json dump
# only arguments are serialized
function_call_arr.append(
{"name": fn_name, "arguments": json.loads(args)}
)
for match in matches:
fn_name = match[0]
args = match[1]
# fn_name is encoded outside serialized json dump
# only arguments are serialized
function_call_arr.append(
{"name": fn_name, "arguments": json.loads(args)}
)
else:
function_call_arr = json.loads(tool_content)
except json.JSONDecodeError:
@@ -193,198 +215,372 @@ class MistralToolParser(ToolParser):
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# if the tool call token is not in the tokens generated so far, append
# output to contents since it's not a tool
if self.bot_token not in current_text:
if self.bot_token_id not in current_token_ids:
# if the tool call token is not in the tokens generated so far,
# append output to contents since it's not a tool
return DeltaMessage(content=delta_text)
# if the tool call token ID IS in the tokens generated so far, that
# if the tool call token IS in the tokens generated so far, that
# means we're parsing as tool calls now
# handle if we detected the BOT token which means the start of tool
# calling
if self.bot_token_id in delta_token_ids and len(delta_token_ids) == 1:
# if it's the only token, return None, so we don't send a chat
# completion and don't send a control token
return None
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try:
# replace BOT token with empty string, and convert single quotes
# to double to allow parsing as JSON since mistral uses single
# quotes instead of double for tool calls
parsable_arr = current_text.split(self.bot_token)[-1]
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try:
tool_call_arr: list[dict] = partial_json_parser.loads(
parsable_arr, flags
if _is_pre_v11_tokeniser(self.model_tokenizer):
return self._extract_tool_calls_streaming_pre_v11_tokenizer(
delta_text=delta_text,
delta_token_ids=delta_token_ids,
)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug("not enough tokens to parse into JSON yet")
return None
# select as the current tool call the one we're on the state at
current_tool_call: dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
)
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
diff: str | None = current_tool_call.get("arguments")
if diff:
diff = json.dumps(diff, ensure_ascii=False).replace(
self.streamed_args_for_tool[self.current_tool_id], ""
)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# case: update an existing tool - this is handled below
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=MistralToolCall.generate_random_id(),
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
return self._extract_tool_calls_streaming(
delta_text=delta_text, delta_token_ids=delta_token_ids
)
cur_arguments = current_tool_call.get("arguments")
new_text = delta_text.replace("'", '"')
if '"}' in new_text:
new_text = new_text[: new_text.rindex('"}')]
if not cur_arguments and not prev_arguments:
delta = None
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset mid-arguments"
)
delta = None
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)[
:-2
]
logger.debug("finding %s in %s", new_text, cur_arguments_json)
if new_text not in cur_arguments_json:
return None
arguments_delta = cur_arguments_json[
: cur_arguments_json.rindex(new_text) + len(new_text)
]
logger.debug(
"First tokens in arguments received: %s", arguments_delta
)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
logger.debug(
"Searching for diff between \n%s\n%s",
cur_args_json,
prev_args_json,
)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json
)
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += argument_diff
else:
# try parsing it with regular JSON - if it works we're
# at the end, and we need to send the difference between
# tokens streamed so far and the valid JSON
delta = None
# check to see if the name is defined and has been sent. if so,
# stream the name - otherwise keep waiting
# finish by setting old and returning None as base case
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
)
return None
def _extract_tool_calls_streaming(
    self,
    delta_text: str,
    delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
    """
    Extracts tool calls for Mistral models
    doing tool calls of the following format:
    `[TOOL_CALLS]add{"a": 3.5, "b": 4}`

    Args:
        delta_text: Text generated since the previous call.
        delta_token_ids: Token ids generated since the previous call.

    Returns:
        A ``DeltaMessage`` carrying any plain content found before the
        first bot token and/or the tool-call deltas produced from
        ``delta_text``; an empty ``DeltaMessage`` when nothing new was
        produced but arguments are being parsed or tools are complete;
        ``None`` when nothing can be emitted yet (for example while the
        tool name is still incomplete).
    """
    additional_content: str = ""
    if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
        # this is the first tool call
        assert self.bot_token_id in delta_token_ids
        if not delta_text.startswith(self.bot_token):
            # Whatever precedes the first bot token is regular content.
            # Only that first marker is consumed here: later markers in the
            # same delta separate further tool calls and must reach
            # _generate_delta_tool_call untouched.
            leading_text, _, remainder = delta_text.partition(self.bot_token)
            additional_content += leading_text
            delta_text = self.bot_token + remainder

    delta_tool_calls = self._generate_delta_tool_call(delta_text)

    if not additional_content and not delta_tool_calls:
        if self.streaming_state in [
            StreamingState.PARSING_ARGUMENTS,
            StreamingState.PARSING_ARGUMENTS_COMPLETED,
            StreamingState.TOOL_COMPLETE,
            StreamingState.ALL_TOOLS_COMPLETE,
        ]:
            # Return an empty DeltaMessage once the tool calls are all done
            # so that finish_reason gets set.
            return DeltaMessage()
        # return None when the tool is not likely to be finished
        # This can occur when the name is being parsed for example
        # and we wait for the name to be complete
        # before sending the function name
        return None

    delta = DeltaMessage()
    if additional_content:
        delta.content = additional_content
    if delta_tool_calls:
        delta.tool_calls = delta_tool_calls

    # HACK: serving_chat.py inspects the internal state of tool parsers
    # when determining its final streaming delta, automatically
    # adding autocompleted JSON.
    # These two lines avoid that nonsense while ensuring finish_reason
    # is set to tool_calls when at least one tool is called.
    if delta_tool_calls and not self.prev_tool_call_arr:
        self.prev_tool_call_arr = [{"arguments": {}}]
    return delta
def _generate_delta_tool_call(self, delta_text: str) -> list[DeltaToolCall]:
    """Turn a piece of `[TOOL_CALLS]name{...}` text into tool-call deltas.

    Updates ``self.streaming_state``, ``self.current_tool_id`` and
    ``self.current_tool_name`` while consuming ``delta_text``. The tool name
    is buffered until the opening ``{`` of the arguments is seen, so it is
    emitted in one piece together with the first chunk of arguments. When a
    bot token appears while arguments are being parsed, the text from that
    token onwards is processed recursively as the next tool call.

    Returns:
        The tool-call deltas produced by this text; an empty list when the
        text is empty or ``None``, when the name is still incomplete, or
        when the parser is in a state that does not expect this text.
    """
    if not delta_text:
        return []

    tool_id = None

    # A bot token seen outside of name/arguments parsing opens a new tool.
    outside_of_tool = self.streaming_state not in [
        StreamingState.PARSING_NAME,
        StreamingState.PARSING_ARGUMENTS,
    ]
    if outside_of_tool and delta_text.startswith(self.bot_token):
        self.current_tool_id += 1
        self.streaming_state = StreamingState.PARSING_NAME
        delta_text = delta_text[len(self.bot_token) :]

    if self.streaming_state == StreamingState.PARSING_NAME:
        # The name stops where the arguments start
        # And the arguments start with the `{` char
        name_part, brace, _ = delta_text.partition("{")
        self.current_tool_name = (self.current_tool_name or "") + name_part
        if not brace:
            # we want to send the tool name once it's complete
            return []
        tool_id = MistralToolCall.generate_random_id()
        delta_text = delta_text[len(name_part) :]
        self.streaming_state = StreamingState.PARSING_ARGUMENTS

    if self.streaming_state != StreamingState.PARSING_ARGUMENTS:
        # Should not happen
        return []

    delta_arguments, marker, _ = delta_text.partition(self.bot_token)
    next_function_text = None
    if marker:
        # current tool call is over
        next_function_text = delta_text[len(delta_arguments) :]
        self.streaming_state = StreamingState.TOOL_COMPLETE

    tool_calls = []
    if self.current_tool_name or delta_arguments:
        function_payload = DeltaFunctionCall(
            name=self.current_tool_name, arguments=delta_arguments
        ).model_dump(exclude_none=True)
        tool_calls.append(
            DeltaToolCall(
                index=self.current_tool_id,
                type="function",
                id=tool_id,
                function=function_payload,
            )
        )
        self.current_tool_name = None
    if next_function_text:
        tool_calls.extend(self._generate_delta_tool_call(next_function_text))
    return tool_calls
@ijson.coroutine
def update_stream_state_pre_v11_tokenizer(self):
    """Coroutine that maps received parse events onto ``self.streaming_state``.

    It receives ``(prefix, event, value)`` triples and moves the state
    machine forward for the events it knows about; every other event is
    ignored. When the string value of a tool's ``name`` key arrives it is
    also stored in ``self.current_tool_name``.
    """
    # Transitions that depend only on where the event occurred.
    structural_transitions = {
        ("item", "start_map"): StreamingState.WAITING_FOR_TOOL_KEY,
        ("item.arguments", "start_map"): StreamingState.PARSING_ARGUMENTS,
        ("item.arguments", "end_map"): StreamingState.PARSING_ARGUMENTS_COMPLETED,
        ("item", "end_map"): StreamingState.TOOL_COMPLETE,
        ("", "end_array"): StreamingState.ALL_TOOLS_COMPLETE,
    }
    # Transitions triggered by a key of the tool object itself.
    key_transitions = {
        "name": StreamingState.PARSING_NAME,
        "arguments": StreamingState.WAITING_FOR_ARGUMENTS_START,
    }
    while True:
        prefix, event, value = yield
        if prefix == "item" and event == "map_key":
            if value in key_transitions:
                self.streaming_state = key_transitions[value]
        elif prefix == "item.name" and event == "string":
            self.current_tool_name = value
            self.streaming_state = StreamingState.PARSING_NAME_COMPLETED
        elif (prefix, event) in structural_transitions:
            self.streaming_state = structural_transitions[(prefix, event)]
def _extract_tool_calls_streaming_pre_v11_tokenizer(
self,
delta_text: str,
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
"""
Extracts tool calls for Mistral models
doing tool calls of the following format:
`[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}`
"""
# Overview: the delta is cut into small pieces with _split_delta and each
# piece is sent to self.parse_coro. Comparing self.streaming_state before
# and after a piece is parsed tells which part of the tool call the piece
# belongs to (start of a new tool, its name, or its arguments).
# Returns a DeltaMessage carrying content and/or tool-call deltas, an empty
# DeltaMessage when nothing is emitted and the state is ALL_TOOLS_COMPLETE,
# and None otherwise.
assert self.parse_coro is not None
content = None
delta_tool_calls: list[DeltaToolCall] = []
# Tool call that the pieces of this delta are accumulated into.
current_tool_call: DeltaToolCall = DeltaToolCall(
index=self.current_tool_id, type="function"
)
# Becomes True once current_tool_call holds a name or arguments to emit.
current_tool_call_modified = False
if self.bot_token_id in delta_token_ids:
# this is the first tool call
# Text in front of the bot token is kept as content; the bot token
# itself is removed from the text that gets parsed.
if not delta_text.startswith(self.bot_token):
content = delta_text.split(self.bot_token)[0]
delta_text = "".join(delta_text.split(self.bot_token)[1:])
# Cut smartly the delta text to catch the ijson events
# as ijson does not give us the index in the text at each event.
# We need to cut so that we know
# where in the text the events are emitted from.
while len(delta_text) > 0:
streaming_state_before_parse = self.streaming_state
if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_opening_curly_braces=1,
)
elif self.streaming_state == StreamingState.WAITING_FOR_TOOL_KEY:
# Wait until another key is sent
# or the current tool is completed
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_colon=1,
stop_after_opening_curly_braces=1,
# if the tool ends, we want to separate
# at the start of the next tool
)
elif self.streaming_state == StreamingState.PARSING_NAME:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_comma=1,
stop_after_closing_brackets=1,
)
elif self.streaming_state == StreamingState.WAITING_FOR_ARGUMENTS_START:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_opening_curly_braces=1,
)
elif self.streaming_state == StreamingState.PARSING_ARGUMENTS:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_closing_curly_braces=1,
# we could be more clever
# by listening to item.arguments.* start_map events
# and know how many curly braces we can allow
)
elif self.streaming_state in [
StreamingState.PARSING_ARGUMENTS_COMPLETED,
StreamingState.PARSING_NAME_COMPLETED,
]:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_closing_curly_braces=1,
stop_after_closing_brackets=1,
)
elif self.streaming_state == StreamingState.TOOL_COMPLETE:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_opening_curly_braces=1,
stop_after_closing_brackets=1,
)
elif self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE:
# The remaining text is returned as content and is not parsed.
# NOTE(review): this assignment replaces any content captured earlier
# in this call - confirm that is intended.
content = delta_text
delta_text = ""
else:
delta_to_be_parsed = delta_text
delta_text = ""
if self.streaming_state != StreamingState.ALL_TOOLS_COMPLETE:
self.parse_coro.send(delta_to_be_parsed.encode("utf-8"))
# Given the parsed text and the possible streaming state change,
# let's add to the tool delta
if (
(streaming_state_before_parse != self.streaming_state)
and streaming_state_before_parse
in [StreamingState.WAITING_FOR_TOOL_START, StreamingState.TOOL_COMPLETE]
and self.streaming_state
not in [
StreamingState.ALL_TOOLS_COMPLETE,
StreamingState.TOOL_COMPLETE,
StreamingState.WAITING_FOR_TOOL_START,
]
):
# starting a new tool call
# Emit what was accumulated for the previous tool first; the stored
# id is attached to that delta and then cleared.
if current_tool_call_modified:
if self.current_tool_mistral_id is not None:
current_tool_call.id = self.current_tool_mistral_id
self.current_tool_mistral_id = None
delta_tool_calls.append(current_tool_call)
current_tool_call_modified = False
self.current_tool_id += 1
self.current_tool_mistral_id = MistralToolCall.generate_random_id()
current_tool_call = DeltaToolCall(
index=self.current_tool_id,
type="function",
)
# Make sure there is a function payload to fill in below.
if current_tool_call.function is None:
current_tool_call.function = DeltaFunctionCall()
if self.current_tool_name is not None:
# we have the complete tool name
current_tool_call_modified = True
current_tool_call.function.name = self.current_tool_name
self.current_tool_name = None
# The name has been consumed; go back to waiting for the next key.
if self.streaming_state == StreamingState.PARSING_NAME_COMPLETED:
self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
if self.streaming_state in [
StreamingState.PARSING_ARGUMENTS,
StreamingState.PARSING_ARGUMENTS_COMPLETED,
]:
if self.streaming_state == StreamingState.PARSING_ARGUMENTS_COMPLETED:
self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
# the delta_to_be_parsed is part of arguments.
current_tool_call_modified = True
if current_tool_call.function.arguments is None:
current_tool_call.function.arguments = delta_to_be_parsed
else:
current_tool_call.function.arguments += delta_to_be_parsed
if streaming_state_before_parse != StreamingState.PARSING_ARGUMENTS:
# It's the first chunk of arg. let's lstrip it
current_tool_call.function.arguments = (
current_tool_call.function.arguments.lstrip()
)
# Emit the tool call that was still being accumulated when the delta
# text ran out.
if current_tool_call_modified:
if self.current_tool_mistral_id is not None:
current_tool_call.id = self.current_tool_mistral_id
self.current_tool_mistral_id = None
delta_tool_calls.append(current_tool_call)
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining its final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if delta_tool_calls and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
if content or len(delta_tool_calls) > 0:
delta_message = DeltaMessage()
if content:
delta_message.content = content
if len(delta_tool_calls) > 0:
delta_message.tool_calls = delta_tool_calls
return delta_message
else:
# Nothing to emit: an empty message is returned only when the state is
# ALL_TOOLS_COMPLETE, otherwise None.
if self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE:
return DeltaMessage()
else:
return None
def _split_delta(
self,
delta_text: str,
stop_after_quotes: int = -1,
stop_after_opening_curly_braces: int = -1,
stop_after_closing_curly_braces: int = -1,
stop_after_closing_brackets: int = -1,
stop_after_colon: int = -1,
stop_after_comma=-1,
) -> tuple[str, str]:
delta_to_be_parsed = ""
for i, c in enumerate(delta_text):
if c in ['"', "'"]:
delta_to_be_parsed += c
stop_after_quotes -= 1
if stop_after_quotes == 0:
return (delta_to_be_parsed, delta_text[i + 1 :])
elif c == "{":
delta_to_be_parsed += c
stop_after_opening_curly_braces -= 1
if stop_after_opening_curly_braces == 0:
return (delta_to_be_parsed, delta_text[i + 1 :])
elif c == "}":
delta_to_be_parsed += c
stop_after_closing_curly_braces -= 1
if stop_after_closing_curly_braces == 0:
return (delta_to_be_parsed, delta_text[i + 1 :])
elif c == "]":
delta_to_be_parsed += c
stop_after_closing_brackets -= 1
if stop_after_closing_brackets == 0:
return (delta_to_be_parsed, delta_text[i + 1 :])
elif c == ":":
delta_to_be_parsed += c
stop_after_colon -= 1
if stop_after_colon == 0:
return (delta_to_be_parsed, delta_text[i + 1 :])
elif c == ",":
delta_to_be_parsed += c
stop_after_comma -= 1
if stop_after_comma == 0:
return (delta_to_be_parsed, delta_text[i + 1 :])
else:
delta_to_be_parsed += c
return (delta_to_be_parsed, "")