[Frontend] Implement Tool Calling with tool_choice='required' (#13483)
Signed-off-by: Liangfu Chen <liangfc@amazon.com> Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at> Co-authored-by: Liangfu Chen <liangfc@amazon.com> Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
@@ -2,13 +2,16 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Callable, Final, Optional, Union
|
||||
|
||||
import jinja2
|
||||
import partial_json_parser
|
||||
from fastapi import Request
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
@@ -21,8 +24,8 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
|
||||
RequestResponseMetadata, ToolCall, UsageInfo)
|
||||
DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition,
|
||||
PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
clamp_prompt_logprobs)
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
@@ -150,12 +153,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
tool_parser = self.tool_parser
|
||||
|
||||
# validation for OpenAI tools
|
||||
# tool_choice = "required" is not supported
|
||||
if request.tool_choice == "required":
|
||||
return self.create_error_response(
|
||||
"tool_choice = \"required\" is not supported!")
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# because of issues with pydantic we need to potentially
|
||||
# re-serialize the tool_calls field of the request
|
||||
@@ -277,6 +274,122 @@ class OpenAIServingChat(OpenAIServing):
|
||||
return self.response_role
|
||||
return request.messages[-1]["role"]
|
||||
|
||||
@staticmethod
|
||||
def _bracket_level(s: str, opening='{', closing='}') -> int:
|
||||
"""
|
||||
Calculate the current level of nested brackets in a given string.
|
||||
"""
|
||||
level = 0
|
||||
for char in s:
|
||||
if char == opening:
|
||||
level += 1
|
||||
elif char == closing:
|
||||
level -= 1
|
||||
return level
|
||||
|
||||
@staticmethod
|
||||
def _filter_delta_text(delta_text: str,
|
||||
previous_text: str) -> tuple[str, bool]:
|
||||
# remove last '},' of the tool definition stemming from the
|
||||
# "name"/"parameters" outer object or closing ']' of the tool list
|
||||
# count occurrences of opening and closing curly braces and
|
||||
# once level 0 is reached stop outputting text
|
||||
# if 0 is reached while parsing the delta_text we know the current
|
||||
# tool will finish in this current iteration
|
||||
bracket_level = OpenAIServingChat._bracket_level(previous_text)
|
||||
updated_delta, passed_zero = "", False
|
||||
for c in delta_text:
|
||||
if c == '{':
|
||||
bracket_level += 1
|
||||
passed_zero = bracket_level == 0
|
||||
elif c == '}':
|
||||
bracket_level -= 1
|
||||
passed_zero = bracket_level == 0
|
||||
|
||||
if bracket_level != 0:
|
||||
updated_delta += c
|
||||
else:
|
||||
# if a comma is reached at level 0 we can stop
|
||||
if c == ',':
|
||||
break
|
||||
return updated_delta, passed_zero
|
||||
|
||||
def extract_tool_call_required_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
function_name_returned: bool,
|
||||
) -> tuple[Optional[DeltaMessage], bool]:
|
||||
try:
|
||||
obj = partial_json_parser.loads(current_text)
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
logger.debug('not enough tokens to parse into JSON yet')
|
||||
obj = None
|
||||
|
||||
# check if the current text is a valid array
|
||||
# containing a partial tool calling object
|
||||
# if not repeat
|
||||
if obj is None or not isinstance(obj, list) or not len(obj) > 0:
|
||||
function_name_returned = False
|
||||
delta_message = None
|
||||
else:
|
||||
_, finishes_previous_tool = OpenAIServingChat._filter_delta_text(
|
||||
delta_text, previous_text)
|
||||
# take the last tool call from the generated list
|
||||
current_tool_call = obj[-1]
|
||||
|
||||
# once parameters have been generated the name is complete as well
|
||||
if not finishes_previous_tool and ("name" not in current_tool_call
|
||||
or "parameters"
|
||||
not in current_tool_call):
|
||||
function_name_returned = False
|
||||
delta_message = None
|
||||
else:
|
||||
if not function_name_returned:
|
||||
# get partly generated arguments from the latest tool call
|
||||
param_match = re.search(r'.*"parameters":\s*(.*)',
|
||||
current_text)
|
||||
arguments = param_match.group(1) if param_match else ""
|
||||
arguments, _ = OpenAIServingChat._filter_delta_text(
|
||||
arguments, previous_text)
|
||||
|
||||
# if this iteration finishes a previous tool call but a
|
||||
# new incomplete tool is already generated, take the
|
||||
# previous from the list
|
||||
if (finishes_previous_tool
|
||||
and "parameters" not in current_tool_call):
|
||||
current_tool_call = obj[-2]
|
||||
|
||||
function_name_returned = True
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(function=DeltaFunctionCall(
|
||||
name=current_tool_call["name"],
|
||||
arguments=arguments),
|
||||
index=len(obj) - 1,
|
||||
type="function")
|
||||
])
|
||||
|
||||
else:
|
||||
delta_text, _ = OpenAIServingChat._filter_delta_text(
|
||||
delta_text, previous_text)
|
||||
|
||||
if delta_text != "":
|
||||
delta_message = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
function=DeltaFunctionCall(
|
||||
# OpenAI API returns None
|
||||
# instead of name every time
|
||||
name=None,
|
||||
arguments=delta_text),
|
||||
index=len(obj) - 1,
|
||||
type="function")
|
||||
])
|
||||
else:
|
||||
delta_message = None
|
||||
|
||||
return delta_message, function_name_returned
|
||||
|
||||
async def chat_completion_stream_generator(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
@@ -312,6 +425,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self._should_stream_with_reasoning_parsing(request))
|
||||
|
||||
all_previous_token_ids: Optional[list[list[int]]]
|
||||
function_name_returned: Optional[list[bool]] = None
|
||||
|
||||
# Only one of these will be used, thus previous_texts and
|
||||
# all_previous_token_ids will not be used twice in the same iteration.
|
||||
@@ -322,6 +436,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# For reasoning parser and tool call all enabled
|
||||
added_content_delta_arr = [False] * num_choices
|
||||
reasoning_end_arr = [False] * num_choices
|
||||
elif request.tool_choice == "required":
|
||||
previous_texts = [""] * num_choices
|
||||
function_name_returned = [False] * num_choices
|
||||
all_previous_token_ids = None
|
||||
else:
|
||||
previous_texts, all_previous_token_ids = None, None
|
||||
|
||||
@@ -521,6 +639,23 @@ class OpenAIServingChat(OpenAIServing):
|
||||
index=i)
|
||||
])
|
||||
|
||||
elif request.tool_choice == "required":
|
||||
assert previous_texts is not None
|
||||
assert function_name_returned is not None
|
||||
previous_text = previous_texts[i]
|
||||
current_text = previous_text + delta_text
|
||||
fn_name_returned = function_name_returned[i]
|
||||
|
||||
delta_message, function_name_returned[i] = (
|
||||
self.extract_tool_call_required_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
function_name_returned=fn_name_returned))
|
||||
|
||||
# update the previous values for the next iteration
|
||||
previous_texts[i] = current_text
|
||||
|
||||
# handle streaming deltas for tools with "auto" tool choice
|
||||
# and reasoning parser
|
||||
elif tool_choice_auto and self.enable_reasoning:
|
||||
@@ -821,10 +956,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
# if auto tools are not enabled, and a named tool choice using
|
||||
# outlines is not being used
|
||||
if (not self.enable_auto_tools
|
||||
or not self.tool_parser) and not isinstance(
|
||||
request.tool_choice,
|
||||
ChatCompletionNamedToolChoiceParam):
|
||||
if (not self.enable_auto_tools or not self.tool_parser) and \
|
||||
(not isinstance(request.tool_choice,
|
||||
ChatCompletionNamedToolChoiceParam
|
||||
) and request.tool_choice != "required"):
|
||||
message = ChatMessage(role=role,
|
||||
reasoning_content=reasoning_content,
|
||||
content=content)
|
||||
@@ -845,6 +980,24 @@ class OpenAIServingChat(OpenAIServing):
|
||||
arguments=content))
|
||||
])
|
||||
|
||||
elif request.tool_choice and request.tool_choice == "required":
|
||||
tool_call_class = MistralToolCall if isinstance(
|
||||
tokenizer, MistralTokenizer) else ToolCall
|
||||
|
||||
# the fields of FunctionDefinition are a superset of the
|
||||
# tool call outputs and can be used for parsing
|
||||
tool_calls = TypeAdapter(
|
||||
list[FunctionDefinition]).validate_json(output.text)
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
content="",
|
||||
tool_calls=[
|
||||
tool_call_class(function=FunctionCall(
|
||||
name=tool_call.name,
|
||||
arguments=json.dumps(tool_call.parameters)))
|
||||
for tool_call in tool_calls
|
||||
])
|
||||
|
||||
# if the request doesn't use tool choice
|
||||
# OR specifies to not use a tool
|
||||
elif not request.tool_choice or request.tool_choice == "none":
|
||||
|
||||
Reference in New Issue
Block a user