[Bugfix] properly handle nested json with llama3 tool parser (#27701)
Signed-off-by: Aydin Abiar <aydin@anyscale.com> Signed-off-by: Aydin Abiar <62435714+Aydin-ab@users.noreply.github.com> Co-authored-by: Aydin Abiar <aydin@anyscale.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -9,6 +9,7 @@ import regex as re
|
||||
from partial_json_parser.core.options import Allow
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
@@ -56,12 +57,10 @@ class Llama3JsonToolParser(ToolParser):
|
||||
self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[
|
||||
0
|
||||
]
|
||||
# Updated regex to match multiple JSONs separated by semicolons
|
||||
# This pattern is more robust and can handle nested JSON objects
|
||||
self.tool_call_regex = re.compile(
|
||||
r"{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*",
|
||||
re.DOTALL,
|
||||
)
|
||||
# Simple regex to find opening braces - we'll use JSON decoder for parsing
|
||||
# This handles arbitrary nesting depth correctly
|
||||
self.tool_call_start_regex = re.compile(r"\{")
|
||||
self.json_decoder = json.JSONDecoder()
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
@@ -77,49 +76,84 @@ class Llama3JsonToolParser(ToolParser):
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
# Find JSON object(s) in the text using regex
|
||||
match = self.tool_call_regex.search(model_output)
|
||||
if not match:
|
||||
# Keep track of the end index of the last parsed JSON object
|
||||
# so we don't parse inner brackets
|
||||
end_index = -1
|
||||
tool_calls: list[ToolCall] = []
|
||||
|
||||
try:
|
||||
for match in self.tool_call_start_regex.finditer(
|
||||
model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
|
||||
):
|
||||
start_index = match.start()
|
||||
# Skip if this brace is inside a previously parsed JSON object
|
||||
if start_index <= end_index:
|
||||
continue
|
||||
|
||||
try:
|
||||
obj, json_end_index = self.json_decoder.raw_decode(
|
||||
model_output[start_index:]
|
||||
)
|
||||
end_index = start_index + json_end_index
|
||||
|
||||
# raise KeyError if missing
|
||||
name = obj["name"]
|
||||
arguments_or_params = (
|
||||
obj["arguments"] if "arguments" in obj else obj["parameters"]
|
||||
)
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=name,
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(
|
||||
arguments_or_params, ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
except KeyError as e:
|
||||
# Missing required key
|
||||
missing_key = str(e).strip("'\"")
|
||||
logger.exception(
|
||||
"Couldn't extract tool call from JSON response. "
|
||||
"Required key '%s' not present. "
|
||||
"Returning output in content with empty tool calls.",
|
||||
missing_key,
|
||||
)
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
except Exception:
|
||||
# Any other error during parsing
|
||||
logger.exception(
|
||||
"Error in extracting tool call from response. "
|
||||
"Returning output in content with empty tool calls"
|
||||
)
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
except TimeoutError:
|
||||
logger.warning("Regex timeout occurred when matching tool call pattern.")
|
||||
logger.debug(
|
||||
"Regex timeout occurred when matching user input: %s", model_output
|
||||
)
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
try:
|
||||
json_str = match.group(0)
|
||||
# Split by semicolon and strip whitespace
|
||||
json_objects = [obj.strip() for obj in json_str.split(";")]
|
||||
|
||||
tool_calls: list[ToolCall] = []
|
||||
for json_obj in json_objects:
|
||||
if not json_obj: # Skip empty strings
|
||||
continue
|
||||
obj = json.loads(json_obj)
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=obj["name"],
|
||||
# function call args are JSON but as a string
|
||||
arguments=json.dumps(
|
||||
obj["arguments"]
|
||||
if "arguments" in obj
|
||||
else obj["parameters"],
|
||||
ensure_ascii=False,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# If we have valid tool calls, return them normally
|
||||
if tool_calls:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True, tool_calls=tool_calls, content=None
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
# return information to just treat the tool call as regular JSON
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
# No valid tool calls found
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user