Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -10,13 +10,19 @@ import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
ToolParser,
|
||||
ToolParserManager,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
|
||||
@@ -25,37 +31,41 @@ logger = init_logger(__name__)
|
||||
|
||||
@ToolParserManager.register_module("hermes")
|
||||
class Hermes2ProToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
logger.error(
|
||||
"Detected Mistral tokenizer when using a Hermes model")
|
||||
logger.error("Detected Mistral tokenizer when using a Hermes model")
|
||||
self.model_tokenizer = self.model_tokenizer.tokenizer
|
||||
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool: list[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.streamed_args_for_tool: list[
|
||||
str
|
||||
] = [] # map what has been streamed for each tool so far to a list
|
||||
|
||||
self.tool_call_start_token: str = "<tool_call>"
|
||||
self.tool_call_end_token: str = "</tool_call>"
|
||||
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)
|
||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL
|
||||
)
|
||||
self.scratch_pad_regex = re.compile(
|
||||
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL)
|
||||
r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL
|
||||
)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
"constructor during construction."
|
||||
)
|
||||
self.tool_call_start_token_ids = self.model_tokenizer.encode(
|
||||
self.tool_call_start_token, add_special_tokens=False)
|
||||
self.tool_call_start_token, add_special_tokens=False
|
||||
)
|
||||
self.tool_call_end_token_ids = self.model_tokenizer.encode(
|
||||
self.tool_call_end_token, add_special_tokens=False)
|
||||
self.tool_call_end_token, add_special_tokens=False
|
||||
)
|
||||
|
||||
self.tool_call_start_token_array = [
|
||||
self.model_tokenizer.decode([token_id])
|
||||
@@ -77,13 +87,17 @@ class Hermes2ProToolParser(ToolParser):
|
||||
def tool_call_delta_buffer(self, delta_text: str):
|
||||
# If the sequence of tool_call_start or tool_call_end tokens is not yet
|
||||
# complete, fill the buffer with the token and return "".
|
||||
if (delta_text in self.tool_call_start_token_array
|
||||
or delta_text in self.tool_call_end_token_array):
|
||||
if (
|
||||
delta_text in self.tool_call_start_token_array
|
||||
or delta_text in self.tool_call_end_token_array
|
||||
):
|
||||
# If delta_text is the last token of tool_call_start_token or
|
||||
# tool_call_end_token, empty the buffer and return
|
||||
# the buffered text + delta_text.
|
||||
if (delta_text == self.tool_call_start_token_array[-1]
|
||||
or delta_text == self.tool_call_end_token_array[-1]):
|
||||
if (
|
||||
delta_text == self.tool_call_start_token_array[-1]
|
||||
or delta_text == self.tool_call_end_token_array[-1]
|
||||
):
|
||||
buffered_text = self.buffered_delta_text
|
||||
self.buffered_delta_text = ""
|
||||
return buffered_text + delta_text
|
||||
@@ -98,9 +112,8 @@ class Hermes2ProToolParser(ToolParser):
|
||||
else:
|
||||
return delta_text
|
||||
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if request.tools and request.tool_choice != 'none':
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
if request.tools and request.tool_choice != "none":
|
||||
# do not skip special tokens because the tool_call tokens are
|
||||
# marked "special" in some models. Since they are skipped
|
||||
# prior to the call to the tool parser, it breaks tool calling.
|
||||
@@ -112,22 +125,19 @@ class Hermes2ProToolParser(ToolParser):
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_call_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
try:
|
||||
# there are two possible captures - between tags, or between a
|
||||
# tag and end-of-string so the result of
|
||||
# findall is an array of tuples where one is a function call and
|
||||
# the other is None
|
||||
function_call_tuples = (
|
||||
self.tool_call_regex.findall(model_output))
|
||||
function_call_tuples = self.tool_call_regex.findall(model_output)
|
||||
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
@@ -141,24 +151,26 @@ class Hermes2ProToolParser(ToolParser):
|
||||
function=FunctionCall(
|
||||
name=function_call["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(function_call["arguments"],
|
||||
ensure_ascii=False)))
|
||||
arguments=json.dumps(
|
||||
function_call["arguments"], ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
for function_call in raw_function_calls
|
||||
]
|
||||
|
||||
content = model_output[:model_output.
|
||||
find(self.tool_call_start_token)]
|
||||
content = model_output[: model_output.find(self.tool_call_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content if content else None)
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
@@ -177,10 +189,12 @@ class Hermes2ProToolParser(ToolParser):
|
||||
delta_text = self.tool_call_delta_buffer(delta_text)
|
||||
# If the last characters of previous_text
|
||||
# match self.buffered_delta_text, remove only the matching part.
|
||||
if (len(previous_text) >= len(self.buffered_delta_text)
|
||||
and previous_text[-len(self.buffered_delta_text):]
|
||||
== self.buffered_delta_text):
|
||||
previous_text = previous_text[:-len(self.buffered_delta_text)]
|
||||
if (
|
||||
len(previous_text) >= len(self.buffered_delta_text)
|
||||
and previous_text[-len(self.buffered_delta_text) :]
|
||||
== self.buffered_delta_text
|
||||
):
|
||||
previous_text = previous_text[: -len(self.buffered_delta_text)]
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
@@ -191,50 +205,51 @@ class Hermes2ProToolParser(ToolParser):
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_text.count(
|
||||
self.tool_call_start_token)
|
||||
prev_tool_start_count = previous_text.count(self.tool_call_start_token)
|
||||
prev_tool_end_count = previous_text.count(self.tool_call_end_token)
|
||||
cur_tool_start_count = current_text.count(
|
||||
self.tool_call_start_token)
|
||||
cur_tool_start_count = current_text.count(self.tool_call_start_token)
|
||||
cur_tool_end_count = current_text.count(self.tool_call_end_token)
|
||||
tool_call_portion = None
|
||||
text_portion = None
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text):
|
||||
if (
|
||||
cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text
|
||||
):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self.tool_call_end_token in delta_text:
|
||||
logger.debug("tool_call_end_token in delta_text")
|
||||
full_text = current_text + delta_text
|
||||
tool_call_portion = full_text.split(
|
||||
self.tool_call_start_token)[-1].split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
delta_text = delta_text.split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
text_portion = delta_text.split(
|
||||
self.tool_call_end_token)[-1].lstrip()
|
||||
tool_call_portion = (
|
||||
full_text.split(self.tool_call_start_token)[-1]
|
||||
.split(self.tool_call_end_token)[0]
|
||||
.rstrip()
|
||||
)
|
||||
delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
|
||||
text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()
|
||||
|
||||
# case: if tool open & close tag counts don't match, we're doing
|
||||
# imaginary "else" block here
|
||||
# something with tools with this diff.
|
||||
# flags for partial JSON parting. exported constants from
|
||||
# "Allow" are handled via BIT MASK
|
||||
flags = Allow.ALL if self.current_tool_name_sent \
|
||||
else Allow.ALL & ~Allow.STR
|
||||
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if (
|
||||
cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count
|
||||
):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
tool_call_portion = current_text.split(self.tool_call_start_token)[
|
||||
-1
|
||||
]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
@@ -248,42 +263,49 @@ class Hermes2ProToolParser(ToolParser):
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
elif (
|
||||
cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count
|
||||
):
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count >= prev_tool_end_count):
|
||||
if (self.prev_tool_call_arr is None
|
||||
or len(self.prev_tool_call_arr) == 0):
|
||||
logger.debug(
|
||||
"attempting to close tool call, but no tool call")
|
||||
elif (
|
||||
cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count >= prev_tool_end_count
|
||||
):
|
||||
if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
|
||||
logger.debug("attempting to close tool call, but no tool call")
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
|
||||
if diff:
|
||||
diff = diff.encode('utf-8').decode(
|
||||
'unicode_escape') if diff is str else diff
|
||||
if ('"}' not in delta_text):
|
||||
diff = (
|
||||
diff.encode("utf-8").decode("unicode_escape")
|
||||
if diff is str
|
||||
else diff
|
||||
)
|
||||
if '"}' not in delta_text:
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s", diff)
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
"been streamed yet: %s",
|
||||
diff,
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=diff).model_dump(
|
||||
exclude_none=True
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
@@ -293,13 +315,14 @@ class Hermes2ProToolParser(ToolParser):
|
||||
return delta
|
||||
|
||||
try:
|
||||
|
||||
current_tool_call = partial_json_parser.loads(
|
||||
tool_call_portion or "{}",
|
||||
flags) if tool_call_portion else None
|
||||
current_tool_call = (
|
||||
partial_json_parser.loads(tool_call_portion or "{}", flags)
|
||||
if tool_call_portion
|
||||
else None
|
||||
)
|
||||
logger.debug("Parsed tool call %s", current_tool_call)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
logger.debug("not enough tokens to parse into JSON yet")
|
||||
return None
|
||||
except json.decoder.JSONDecodeError:
|
||||
logger.debug("unable to parse JSON")
|
||||
@@ -308,19 +331,23 @@ class Hermes2ProToolParser(ToolParser):
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
if not self.current_tool_name_sent:
|
||||
if (current_tool_call is None):
|
||||
if current_tool_call is None:
|
||||
return None
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=make_tool_call_id(),
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
return None
|
||||
# case -- otherwise, send the tool call delta
|
||||
@@ -329,15 +356,19 @@ class Hermes2ProToolParser(ToolParser):
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = DeltaMessage(content=delta_text) \
|
||||
if text_portion is not None else None
|
||||
delta = (
|
||||
DeltaMessage(content=delta_text)
|
||||
if text_portion is not None
|
||||
else None
|
||||
)
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
logger.debug(
|
||||
"Trying to parse current tool call with ID %s", self.current_tool_id
|
||||
)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
@@ -346,8 +377,9 @@ class Hermes2ProToolParser(ToolParser):
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = (
|
||||
self.prev_tool_call_arr[self.current_tool_id].get("arguments"))
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments"
|
||||
)
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
@@ -361,8 +393,10 @@ class Hermes2ProToolParser(ToolParser):
|
||||
# case -- prev arguments are defined, but non are now.
|
||||
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
logger.error(
|
||||
"should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything."
|
||||
)
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
@@ -378,38 +412,41 @@ class Hermes2ProToolParser(ToolParser):
|
||||
# {"search_request": {}}
|
||||
function_name = current_tool_call.get("name")
|
||||
match = re.search(
|
||||
r'\{"name":\s*"' +
|
||||
re.escape(function_name) + r'"\s*,\s*"arguments":\s*(.*)',
|
||||
tool_call_portion.strip(), re.DOTALL)
|
||||
r'\{"name":\s*"'
|
||||
+ re.escape(function_name)
|
||||
+ r'"\s*,\s*"arguments":\s*(.*)',
|
||||
tool_call_portion.strip(),
|
||||
re.DOTALL,
|
||||
)
|
||||
if match:
|
||||
cur_arguments_json = match.group(1)
|
||||
else:
|
||||
cur_arguments_json = json.dumps(cur_arguments,
|
||||
ensure_ascii=False)
|
||||
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)
|
||||
|
||||
logger.debug("finding %s in %s", delta_text,
|
||||
cur_arguments_json)
|
||||
logger.debug("finding %s in %s", delta_text, cur_arguments_json)
|
||||
|
||||
# get the location where previous args differ from current.
|
||||
if (delta_text not in cur_arguments_json):
|
||||
if delta_text not in cur_arguments_json:
|
||||
return None
|
||||
args_delta_start_loc = cur_arguments_json. \
|
||||
rindex(delta_text) + \
|
||||
len(delta_text)
|
||||
args_delta_start_loc = cur_arguments_json.rindex(delta_text) + len(
|
||||
delta_text
|
||||
)
|
||||
|
||||
# use that to find the actual delta
|
||||
arguments_delta = cur_arguments_json[:args_delta_start_loc]
|
||||
logger.debug("First tokens in arguments received: %s",
|
||||
arguments_delta)
|
||||
logger.debug("First tokens in arguments received: %s", arguments_delta)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= arguments_delta
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=arguments_delta
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
@@ -423,28 +460,32 @@ class Hermes2ProToolParser(ToolParser):
|
||||
# if the delta_text ends with a '}' and tool_call_portion is a
|
||||
# complete JSON, then the last '}' does not belong to the
|
||||
# arguments, so we should trim it off
|
||||
if isinstance(delta_text, str) \
|
||||
and len(delta_text.rstrip()) >= 1 \
|
||||
and delta_text.rstrip()[-1] == '}' \
|
||||
and is_complete_json:
|
||||
if (
|
||||
isinstance(delta_text, str)
|
||||
and len(delta_text.rstrip()) >= 1
|
||||
and delta_text.rstrip()[-1] == "}"
|
||||
and is_complete_json
|
||||
):
|
||||
delta_text = delta_text.rstrip()[:-1]
|
||||
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_text).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= delta_text
|
||||
delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(arguments=delta_text).model_dump(
|
||||
exclude_none=True
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += delta_text
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[self.current_tool_id] = \
|
||||
current_tool_call
|
||||
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user