Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -10,13 +10,19 @@ import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
ToolParser,
ToolParserManager,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@@ -25,37 +31,41 @@ logger = init_logger(__name__)
@ToolParserManager.register_module("hermes")
class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
logger.error(
"Detected Mistral tokenizer when using a Hermes model")
logger.error("Detected Mistral tokenizer when using a Hermes model")
self.model_tokenizer = self.model_tokenizer.tokenizer
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = [
] # map what has been streamed for each tool so far to a list
self.streamed_args_for_tool: list[
str
] = [] # map what has been streamed for each tool so far to a list
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_regex = re.compile(
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL
)
self.scratch_pad_regex = re.compile(
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
"constructor during construction."
)
self.tool_call_start_token_ids = self.model_tokenizer.encode(
self.tool_call_start_token, add_special_tokens=False)
self.tool_call_start_token, add_special_tokens=False
)
self.tool_call_end_token_ids = self.model_tokenizer.encode(
self.tool_call_end_token, add_special_tokens=False)
self.tool_call_end_token, add_special_tokens=False
)
self.tool_call_start_token_array = [
self.model_tokenizer.decode([token_id])
@@ -77,13 +87,17 @@ class Hermes2ProToolParser(ToolParser):
def tool_call_delta_buffer(self, delta_text: str):
# If the sequence of tool_call_start or tool_call_end tokens is not yet
# complete, fill the buffer with the token and return "".
if (delta_text in self.tool_call_start_token_array
or delta_text in self.tool_call_end_token_array):
if (
delta_text in self.tool_call_start_token_array
or delta_text in self.tool_call_end_token_array
):
# If delta_text is the last token of tool_call_start_token or
# tool_call_end_token, empty the buffer and return
# the buffered text + delta_text.
if (delta_text == self.tool_call_start_token_array[-1]
or delta_text == self.tool_call_end_token_array[-1]):
if (
delta_text == self.tool_call_start_token_array[-1]
or delta_text == self.tool_call_end_token_array[-1]
):
buffered_text = self.buffered_delta_text
self.buffered_delta_text = ""
return buffered_text + delta_text
@@ -98,9 +112,8 @@ class Hermes2ProToolParser(ToolParser):
else:
return delta_text
def adjust_request(
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if request.tools and request.tool_choice != 'none':
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
if request.tools and request.tool_choice != "none":
# do not skip special tokens because the tool_call tokens are
# marked "special" in some models. Since they are skipped
# prior to the call to the tool parser, it breaks tool calling.
@@ -112,22 +125,19 @@ class Hermes2ProToolParser(ToolParser):
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_call_start_token not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
else:
try:
# there are two possible captures - between tags, or between a
# tag and end-of-string so the result of
# findall is an array of tuples where one is a function call and
# the other is None
function_call_tuples = (
self.tool_call_regex.findall(model_output))
function_call_tuples = self.tool_call_regex.findall(model_output)
# load the JSON, and then use it to build the Function and
# Tool Call
@@ -141,24 +151,26 @@ class Hermes2ProToolParser(ToolParser):
function=FunctionCall(
name=function_call["name"],
# function call args are JSON but as a string
arguments=json.dumps(function_call["arguments"],
ensure_ascii=False)))
arguments=json.dumps(
function_call["arguments"], ensure_ascii=False
),
),
)
for function_call in raw_function_calls
]
content = model_output[:model_output.
find(self.tool_call_start_token)]
content = model_output[: model_output.find(self.tool_call_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None)
content=content if content else None,
)
except Exception:
logger.exception(
"Error in extracting tool call from response.")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_streaming(
self,
@@ -177,10 +189,12 @@ class Hermes2ProToolParser(ToolParser):
delta_text = self.tool_call_delta_buffer(delta_text)
# If the last characters of previous_text
# match self.buffered_delta_text, remove only the matching part.
if (len(previous_text) >= len(self.buffered_delta_text)
and previous_text[-len(self.buffered_delta_text):]
== self.buffered_delta_text):
previous_text = previous_text[:-len(self.buffered_delta_text)]
if (
len(previous_text) >= len(self.buffered_delta_text)
and previous_text[-len(self.buffered_delta_text) :]
== self.buffered_delta_text
):
previous_text = previous_text[: -len(self.buffered_delta_text)]
current_text = previous_text + delta_text
logger.debug("delta_text: %s", delta_text)
@@ -191,50 +205,51 @@ class Hermes2ProToolParser(ToolParser):
return DeltaMessage(content=delta_text)
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count = previous_text.count(
self.tool_call_start_token)
prev_tool_start_count = previous_text.count(self.tool_call_start_token)
prev_tool_end_count = previous_text.count(self.tool_call_end_token)
cur_tool_start_count = current_text.count(
self.tool_call_start_token)
cur_tool_start_count = current_text.count(self.tool_call_start_token)
cur_tool_end_count = current_text.count(self.tool_call_end_token)
tool_call_portion = None
text_portion = None
# case: if we're generating text, OR rounding out a tool call
if (cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text):
if (
cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text
):
logger.debug("Generating text content! skipping tool parsing.")
return DeltaMessage(content=delta_text)
if self.tool_call_end_token in delta_text:
logger.debug("tool_call_end_token in delta_text")
full_text = current_text + delta_text
tool_call_portion = full_text.split(
self.tool_call_start_token)[-1].split(
self.tool_call_end_token)[0].rstrip()
delta_text = delta_text.split(
self.tool_call_end_token)[0].rstrip()
text_portion = delta_text.split(
self.tool_call_end_token)[-1].lstrip()
tool_call_portion = (
full_text.split(self.tool_call_start_token)[-1]
.split(self.tool_call_end_token)[0]
.rstrip()
)
delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()
# case: if tool open & close tag counts don't match, we're doing
# imaginary "else" block here
# something with tools with this diff.
# flags for partial JSON parting. exported constants from
# "Allow" are handled via BIT MASK
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
# case -- we're starting a new tool call
if (cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count > prev_tool_start_count):
if (
cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count > prev_tool_start_count
):
if len(delta_token_ids) > 1:
tool_call_portion = current_text.split(
self.tool_call_start_token)[-1]
tool_call_portion = current_text.split(self.tool_call_start_token)[
-1
]
else:
tool_call_portion = None
delta = None
@@ -248,42 +263,49 @@ class Hermes2ProToolParser(ToolParser):
logger.debug("Starting on a new tool %s", self.current_tool_id)
# case -- we're updating an existing tool call
elif (cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count == prev_tool_start_count):
elif (
cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count == prev_tool_start_count
):
# get the portion of the text that's the tool call
tool_call_portion = current_text.split(
self.tool_call_start_token)[-1]
tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
text_portion = None
# case -- the current tool call is being closed.
elif (cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count >= prev_tool_end_count):
if (self.prev_tool_call_arr is None
or len(self.prev_tool_call_arr) == 0):
logger.debug(
"attempting to close tool call, but no tool call")
elif (
cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count >= prev_tool_end_count
):
if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
logger.debug("attempting to close tool call, but no tool call")
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments")
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
if diff:
diff = diff.encode('utf-8').decode(
'unicode_escape') if diff is str else diff
if ('"}' not in delta_text):
diff = (
diff.encode("utf-8").decode("unicode_escape")
if diff is str
else diff
)
if '"}' not in delta_text:
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s", diff)
self.streamed_args_for_tool[self.current_tool_id] \
+= diff
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff).model_dump(
exclude_none=True))
])
"been streamed yet: %s",
diff,
)
self.streamed_args_for_tool[self.current_tool_id] += diff
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=diff).model_dump(
exclude_none=True
),
)
]
)
# case -- otherwise we're just generating text
else:
@@ -293,13 +315,14 @@ class Hermes2ProToolParser(ToolParser):
return delta
try:
current_tool_call = partial_json_parser.loads(
tool_call_portion or "{}",
flags) if tool_call_portion else None
current_tool_call = (
partial_json_parser.loads(tool_call_portion or "{}", flags)
if tool_call_portion
else None
)
logger.debug("Parsed tool call %s", current_tool_call)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
logger.debug("not enough tokens to parse into JSON yet")
return None
except json.decoder.JSONDecodeError:
logger.debug("unable to parse JSON")
@@ -308,19 +331,23 @@ class Hermes2ProToolParser(ToolParser):
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if not self.current_tool_name_sent:
if (current_tool_call is None):
if current_tool_call is None:
return None
function_name: Union[str, None] = current_tool_call.get("name")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
)
else:
return None
# case -- otherwise, send the tool call delta
@@ -329,15 +356,19 @@ class Hermes2ProToolParser(ToolParser):
if tool_call_portion is None:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta = DeltaMessage(content=delta_text) \
if text_portion is not None else None
delta = (
DeltaMessage(content=delta_text)
if text_portion is not None
else None
)
return delta
# now, the nitty-gritty of tool calls
# now we have the portion to parse as tool call.
logger.debug("Trying to parse current tool call with ID %s",
self.current_tool_id)
logger.debug(
"Trying to parse current tool call with ID %s", self.current_tool_id
)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
@@ -346,8 +377,9 @@ class Hermes2ProToolParser(ToolParser):
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments = (
self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
@@ -361,8 +393,10 @@ class Hermes2ProToolParser(ToolParser):
# case -- prev arguments are defined, but non are now.
# probably impossible, but not a fatal error - just keep going
elif not cur_arguments and prev_arguments:
logger.error("should be impossible to have arguments reset "
"mid-call. skipping streaming anything.")
logger.error(
"should be impossible to have arguments reset "
"mid-call. skipping streaming anything."
)
delta = None
# case -- we now have the first info about arguments available from
@@ -378,38 +412,41 @@ class Hermes2ProToolParser(ToolParser):
# {"search_request": {}}
function_name = current_tool_call.get("name")
match = re.search(
r'\{"name":\s*"' +
re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
tool_call_portion.strip(), re.DOTALL)
r'\{"name":\s*"'
+ re.escape(function_name)
+ r'"\s*,\s*"arguments":\s*(.*)',
tool_call_portion.strip(),
re.DOTALL,
)
if match:
cur_arguments_json = match.group(1)
else:
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
logger.debug("finding %s in %s", delta_text,
cur_arguments_json)
logger.debug("finding %s in %s", delta_text, cur_arguments_json)
# get the location where previous args differ from current.
if (delta_text not in cur_arguments_json):
if delta_text not in cur_arguments_json:
return None
args_delta_start_loc = cur_arguments_json. \
rindex(delta_text) + \
len(delta_text)
args_delta_start_loc = cur_arguments_json.rindex(delta_text) + len(
delta_text
)
# use that to find the actual delta
arguments_delta = cur_arguments_json[:args_delta_start_loc]
logger.debug("First tokens in arguments received: %s",
arguments_delta)
logger.debug("First tokens in arguments received: %s", arguments_delta)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[self.current_tool_id] \
+= arguments_delta
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
@@ -423,28 +460,32 @@ class Hermes2ProToolParser(ToolParser):
# if the delta_text ends with a '}' and tool_call_portion is a
# complete JSON, then the last '}' does not belong to the
# arguments, so we should trim it off
if isinstance(delta_text, str) \
and len(delta_text.rstrip()) >= 1 \
and delta_text.rstrip()[-1] == '}' \
and is_complete_json:
if (
isinstance(delta_text, str)
and len(delta_text.rstrip()) >= 1
and delta_text.rstrip()[-1] == "}"
and is_complete_json
):
delta_text = delta_text.rstrip()[:-1]
logger.debug("got diff %s", delta_text)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=delta_text).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[self.current_tool_id] \
+= delta_text
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=delta_text).model_dump(
exclude_none=True
),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += delta_text
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[self.current_tool_id] = \
current_tool_call
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
else:
self.prev_tool_call_arr.append(current_tool_call)