[Frontend] Implement Tool Calling with tool_choice='required' (#13483)
Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at>
Co-authored-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -61,7 +61,7 @@ class OpenAIBaseModel(BaseModel):
             field_names = set()
             for field_name, field in cls.model_fields.items():
                 field_names.add(field_name)
-                if alias := getattr(field, 'alias', None):
+                if alias := getattr(field, "alias", None):
                     field_names.add(alias)
             cls.field_names = field_names
 
@@ -70,7 +70,8 @@ class OpenAIBaseModel(BaseModel):
             logger.warning(
                 "The following fields were present in the request "
                 "but ignored: %s",
-                data.keys() - field_names)
+                data.keys() - field_names,
+            )
         return result
 
 
@@ -234,8 +235,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     tools: Optional[list[ChatCompletionToolsParam]] = None
-    tool_choice: Optional[Union[Literal["none"], Literal["auto"],
-                                ChatCompletionNamedToolChoiceParam]] = "none"
+    tool_choice: Optional[Union[
+        Literal["none"],
+        Literal["auto"],
+        Literal["required"],
+        ChatCompletionNamedToolChoiceParam,
+    ]] = "none"
 
     # NOTE this will be ignored by vLLM -- the model determines the behavior
     parallel_tool_calls: Optional[bool] = False
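For context (not part of the diff): with `"required"` accepted here, a client can force the model to emit at least one tool call through the OpenAI-compatible API. A minimal sketch, assuming a vLLM server on localhost; the base URL, model name, and tool definition are illustrative assumptions.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical served model
    messages=[{"role": "user", "content": "What is the weather in Vienna?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",  # invented example tool
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
    tool_choice="required",  # the model must call at least one tool
)
print(response.choices[0].message.tool_calls)
```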
@@ -340,24 +345,28 @@ class ChatCompletionRequest(OpenAIBaseModel):
         description=(
             "If specified, will override the default guided decoding backend "
             "of the server for this specific request. If set, must be either "
-            "'outlines' / 'lm-format-enforcer'"))
+            "'outlines' / 'lm-format-enforcer'"),
+    )
     guided_whitespace_pattern: Optional[str] = Field(
         default=None,
         description=(
             "If specified, will override the default whitespace pattern "
-            "for guided json decoding."))
+            "for guided json decoding."),
+    )
     priority: int = Field(
         default=0,
         description=(
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
+            "if the served model does not use priority scheduling."),
+    )
     request_id: str = Field(
         default_factory=lambda: f"{random_uuid()}",
         description=(
             "The request_id related to this request. If the caller does "
             "not set it, a random_uuid will be generated. This id is used "
-            "through out the inference process and return in response."))
+            "through out the inference process and return in response."),
+    )
     logits_processors: Optional[LogitsProcessors] = Field(
         default=None,
         description=(
@@ -415,13 +424,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
-            include_stop_str_in_output=self.include_stop_str_in_output)
+            include_stop_str_in_output=self.include_stop_str_in_output,
+        )
 
     def to_sampling_params(
-            self,
-            default_max_tokens: int,
-            logits_processor_pattern: Optional[str],
-            default_sampling_params: Optional[dict] = None) -> SamplingParams:
+        self,
+        default_max_tokens: int,
+        logits_processor_pattern: Optional[str],
+        default_sampling_params: Optional[dict] = None,
+    ) -> SamplingParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
 
@@ -475,7 +486,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
             grammar=self.guided_grammar,
             json_object=guided_json_object,
             backend=self.guided_decoding_backend,
-            whitespace_pattern=self.guided_whitespace_pattern)
+            whitespace_pattern=self.guided_whitespace_pattern,
+        )
 
         return SamplingParams.from_optional(
             n=self.n,
@@ -522,6 +534,41 @@ class ChatCompletionRequest(OpenAIBaseModel):
             tool = tools[tool_name]
             return tool.parameters
 
+        if self.tool_choice == "required":
+            # Pydantic schema generation cannot be used since the JSON schema
+            # has to be constructed for a specific instantiation of a tool list
+            # so that parameters of a function are correctly generated
+            # based on the chosen function name
+            def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
+                return {
+                    "properties": {
+                        "name": {
+                            "type": "string",
+                            "enum": [tool.function.name]
+                        },
+                        # parameters are always generated as '{}' in the final
+                        # output if they are missing from the request
+                        # (i.e. are None or '{}') so the schema is
+                        # updated to produce an empty object in that case
+                        "parameters": tool.function.parameters
+                        if tool.function.parameters else {
+                            "type": "object",
+                            "properties": {}
+                        }
+                    },
+                    "required": ["name", "parameters"]
+                }
+
+            json_schema = {
+                "type": "array",
+                "minItems": 1,
+                "items": {
+                    "type": "object",
+                    "anyOf": [get_tool_schema(tool) for tool in self.tools]
+                }
+            }
+            return json_schema
+
         return None
 
     @model_validator(mode="before")
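For illustration (not part of the diff), here is the guided-decoding schema this branch would build for a two-tool request; the tool names and parameter schemas are invented. Generation is constrained to a non-empty JSON array whose items each name one of the declared functions and carry the matching parameters.

```python
# Mirrors get_tool_schema() above for tools "get_weather" and "get_time".
json_schema = {
    "type": "array",
    "minItems": 1,
    "items": {
        "type": "object",
        "anyOf": [
            {
                "properties": {
                    "name": {"type": "string", "enum": ["get_weather"]},
                    # the user-supplied parameter schema is embedded as-is
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
                "required": ["name", "parameters"],
            },
            {
                "properties": {
                    "name": {"type": "string", "enum": ["get_time"]},
                    # a tool declared without parameters falls back to an
                    # empty object schema, so '{}' is always produced
                    "parameters": {"type": "object", "properties": {}},
                },
                "required": ["name", "parameters"],
            },
        ],
    },
}
```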
@@ -572,8 +619,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
                 "You can only use one kind of guided decoding "
                 "('guided_json', 'guided_regex' or 'guided_choice').")
         # you can only either use guided decoding or tools, not both
-        if guide_count > 1 and data.get("tool_choice",
-                                        "none") not in ("none", "auto"):
+        if guide_count > 1 and data.get("tool_choice", "none") not in (
+                "none",
+                "auto",
+                "required",
+        ):
             raise ValueError(
                 "You can only either use guided decoding or tools, not both.")
         return data
@@ -602,12 +652,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
                 "When using `tool_choice`, `tools` must be set.")
 
         # make sure that tool choice is either a named tool
-        # OR that it's set to "auto"
-        if data["tool_choice"] != "auto" and not isinstance(
-                data["tool_choice"], dict):
-            raise ValueError(
-                "`tool_choice` must either be a named tool, \"auto\", "
-                "or \"none\".")
+        # OR that it's set to "auto" or "required"
+        if data["tool_choice"] not in [
+                "auto", "required"
+        ] and not isinstance(data["tool_choice"], dict):
+            raise NotImplementedError(
+                f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\
+                'Only named tools, "none", "auto" or "required" '\
+                'are supported.'
+            )
 
         # ensure that if "tool_choice" is specified as an object,
         # it matches a valid tool
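A hedged sketch of the new validation semantics (model name, messages, and tool are placeholders), assuming vLLM at this commit is importable. Because the check runs in a `mode="before"` validator, an unrecognized string surfaces as `NotImplementedError` rather than a pydantic `ValidationError`.

```python
import pytest

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

tools = [{"type": "function", "function": {"name": "get_weather"}}]
messages = [{"role": "user", "content": "What is the weather in Vienna?"}]

# "required" is now accepted alongside "none", "auto", and named-tool dicts
ChatCompletionRequest(model="test-model",
                      messages=messages,
                      tools=tools,
                      tool_choice="required")

# any other string is rejected before field coercion runs
with pytest.raises(NotImplementedError):
    ChatCompletionRequest(model="test-model",
                          messages=messages,
                          tools=tools,
                          tool_choice="sometimes")
```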
@@ -722,18 +775,21 @@ class CompletionRequest(OpenAIBaseModel):
         description=(
             "If specified, will override the default guided decoding backend "
             "of the server for this specific request. If set, must be one of "
-            "'outlines' / 'lm-format-enforcer'"))
+            "'outlines' / 'lm-format-enforcer'"),
+    )
     guided_whitespace_pattern: Optional[str] = Field(
         default=None,
         description=(
             "If specified, will override the default whitespace pattern "
-            "for guided json decoding."))
+            "for guided json decoding."),
+    )
     priority: int = Field(
         default=0,
         description=(
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
+            "if the served model does not use priority scheduling."),
+    )
     logits_processors: Optional[LogitsProcessors] = Field(
         default=None,
         description=(
@@ -745,6 +801,7 @@ class CompletionRequest(OpenAIBaseModel):
             "arguments. For example: {'qualname': "
             "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
             "{'param': 'value'}}."))
+
     return_tokens_as_token_ids: Optional[bool] = Field(
         default=None,
         description=(
@@ -789,13 +846,15 @@ class CompletionRequest(OpenAIBaseModel):
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
-            include_stop_str_in_output=self.include_stop_str_in_output)
+            include_stop_str_in_output=self.include_stop_str_in_output,
+        )
 
     def to_sampling_params(
-            self,
-            default_max_tokens: int,
-            logits_processor_pattern: Optional[str],
-            default_sampling_params: Optional[dict] = None) -> SamplingParams:
+        self,
+        default_max_tokens: int,
+        logits_processor_pattern: Optional[str],
+        default_sampling_params: Optional[dict] = None,
+    ) -> SamplingParams:
         max_tokens = self.max_tokens
 
         if default_sampling_params is None:
@@ -844,7 +903,8 @@ class CompletionRequest(OpenAIBaseModel):
             grammar=self.guided_grammar,
             json_object=guided_json_object,
             backend=self.guided_decoding_backend,
-            whitespace_pattern=self.guided_whitespace_pattern)
+            whitespace_pattern=self.guided_whitespace_pattern,
+        )
 
         return SamplingParams.from_optional(
             n=self.n,
@@ -942,7 +1002,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
         description=(
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
+            "if the served model does not use priority scheduling."),
+    )
 
     # doc: end-embedding-extra-params
 
@@ -995,7 +1056,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
         description=(
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
+            "if the served model does not use priority scheduling."),
+    )
     # doc: end-chat-embedding-extra-params
 
     @model_validator(mode="before")
@@ -1034,7 +1096,8 @@ class ScoreRequest(OpenAIBaseModel):
         description=(
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
+            "if the served model does not use priority scheduling."),
+    )
 
     # doc: end-score-extra-params
 
@@ -1059,7 +1122,8 @@ class RerankRequest(OpenAIBaseModel):
         description=(
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
+            "if the served model does not use priority scheduling."),
+    )
 
     # doc: end-rerank-extra-params
 
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -2,13 +2,16 @@
 
 import asyncio
 import json
+import re
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
 from collections.abc import Sequence as GenericSequence
 from typing import Callable, Final, Optional, Union
 
 import jinja2
+import partial_json_parser
 from fastapi import Request
+from pydantic import TypeAdapter
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
@@ -21,8 +24,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
-    DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
-    RequestResponseMetadata, ToolCall, UsageInfo)
+    DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition,
+    PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                     clamp_prompt_logprobs)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@@ -150,12 +153,6 @@ class OpenAIServingChat(OpenAIServing):
 
             tool_parser = self.tool_parser
 
-            # validation for OpenAI tools
-            # tool_choice = "required" is not supported
-            if request.tool_choice == "required":
-                return self.create_error_response(
-                    "tool_choice = \"required\" is not supported!")
-
             if isinstance(tokenizer, MistralTokenizer):
                 # because of issues with pydantic we need to potentially
                 # re-serialize the tool_calls field of the request
@@ -277,6 +274,122 @@ class OpenAIServingChat(OpenAIServing):
             return self.response_role
         return request.messages[-1]["role"]
 
+    @staticmethod
+    def _bracket_level(s: str, opening='{', closing='}') -> int:
+        """
+        Calculate the current level of nested brackets in a given string.
+        """
+        level = 0
+        for char in s:
+            if char == opening:
+                level += 1
+            elif char == closing:
+                level -= 1
+        return level
+
+    @staticmethod
+    def _filter_delta_text(delta_text: str,
+                           previous_text: str) -> tuple[str, bool]:
+        # remove last '},' of the tool definition stemming from the
+        # "name"/"parameters" outer object or closing ']' of the tool list
+        # count occurrences of opening and closing curly braces and
+        # once level 0 is reached stop outputting text
+        # if 0 is reached while parsing the delta_text we know the current
+        # tool will finish in this current iteration
+        bracket_level = OpenAIServingChat._bracket_level(previous_text)
+        updated_delta, passed_zero = "", False
+        for c in delta_text:
+            if c == '{':
+                bracket_level += 1
+                passed_zero = bracket_level == 0
+            elif c == '}':
+                bracket_level -= 1
+                passed_zero = bracket_level == 0
+
+            if bracket_level != 0:
+                updated_delta += c
+            else:
+                # if a comma is reached at level 0 we can stop
+                if c == ',':
+                    break
+        return updated_delta, passed_zero
+
+    def extract_tool_call_required_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        function_name_returned: bool,
+    ) -> tuple[Optional[DeltaMessage], bool]:
+        try:
+            obj = partial_json_parser.loads(current_text)
+        except partial_json_parser.core.exceptions.MalformedJSON:
+            logger.debug('not enough tokens to parse into JSON yet')
+            obj = None
+
+        # check if the current text is a valid array
+        # containing a partial tool calling object
+        # if not repeat
+        if obj is None or not isinstance(obj, list) or not len(obj) > 0:
+            function_name_returned = False
+            delta_message = None
+        else:
+            _, finishes_previous_tool = OpenAIServingChat._filter_delta_text(
+                delta_text, previous_text)
+            # take the last tool call from the generated list
+            current_tool_call = obj[-1]
+
+            # once parameters have been generated the name is complete as well
+            if not finishes_previous_tool and ("name" not in current_tool_call
+                                               or "parameters"
+                                               not in current_tool_call):
+                function_name_returned = False
+                delta_message = None
+            else:
+                if not function_name_returned:
+                    # get partly generated arguments from the latest tool call
+                    param_match = re.search(r'.*"parameters":\s*(.*)',
+                                            current_text)
+                    arguments = param_match.group(1) if param_match else ""
+                    arguments, _ = OpenAIServingChat._filter_delta_text(
+                        arguments, previous_text)
+
+                    # if this iteration finishes a previous tool call but a
+                    # new incomplete tool is already generated, take the
+                    # previous from the list
+                    if (finishes_previous_tool
+                            and "parameters" not in current_tool_call):
+                        current_tool_call = obj[-2]
+
+                    function_name_returned = True
+                    delta_message = DeltaMessage(tool_calls=[
+                        DeltaToolCall(function=DeltaFunctionCall(
+                            name=current_tool_call["name"],
+                            arguments=arguments),
+                                      index=len(obj) - 1,
+                                      type="function")
+                    ])
+
+                else:
+                    delta_text, _ = OpenAIServingChat._filter_delta_text(
+                        delta_text, previous_text)
+
+                    if delta_text != "":
+                        delta_message = DeltaMessage(tool_calls=[
+                            DeltaToolCall(
+                                function=DeltaFunctionCall(
+                                    # OpenAI API returns None
+                                    # instead of name every time
+                                    name=None,
+                                    arguments=delta_text),
+                                index=len(obj) - 1,
+                                type="function")
+                        ])
+                    else:
+                        delta_message = None
+
+        return delta_message, function_name_returned
+
     async def chat_completion_stream_generator(
         self,
         request: ChatCompletionRequest,
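To make the brace bookkeeping concrete, here is a worked example (inputs invented, not part of the diff) of the filter above. The guided JSON for `tool_choice="required"` is one array of `{"name": ..., "parameters": ...}` objects, and the filter strips the structural braces and separators of that wrapper so only argument text reaches the client.

```python
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

previous = '[{"name": "get_weather", "parameters": {"city": "Vienna"'
delta = '}}, {"na'

# previous is 2 levels deep: inside the tool object and inside "parameters"
assert OpenAIServingChat._bracket_level(previous) == 2

updated, finished = OpenAIServingChat._filter_delta_text(delta, previous)
# the first '}' closes "parameters" and is kept (level 2 -> 1); the second
# '}' closes the tool object (level 1 -> 0) and is dropped; the ',' at
# level 0 stops the scan before the next tool's text starts
assert updated == '}'
assert finished is True
```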
@@ -312,6 +425,7 @@ class OpenAIServingChat(OpenAIServing):
             self._should_stream_with_reasoning_parsing(request))
 
         all_previous_token_ids: Optional[list[list[int]]]
+        function_name_returned: Optional[list[bool]] = None
 
         # Only one of these will be used, thus previous_texts and
         # all_previous_token_ids will not be used twice in the same iteration.
@@ -322,6 +436,10 @@ class OpenAIServingChat(OpenAIServing):
             # For reasoning parser and tool call all enabled
             added_content_delta_arr = [False] * num_choices
             reasoning_end_arr = [False] * num_choices
+        elif request.tool_choice == "required":
+            previous_texts = [""] * num_choices
+            function_name_returned = [False] * num_choices
+            all_previous_token_ids = None
         else:
             previous_texts, all_previous_token_ids = None, None
 
@@ -521,6 +639,23 @@ class OpenAIServingChat(OpenAIServing):
                                 index=i)
                         ])
 
+                    elif request.tool_choice == "required":
+                        assert previous_texts is not None
+                        assert function_name_returned is not None
+                        previous_text = previous_texts[i]
+                        current_text = previous_text + delta_text
+                        fn_name_returned = function_name_returned[i]
+
+                        delta_message, function_name_returned[i] = (
+                            self.extract_tool_call_required_streaming(
+                                previous_text=previous_text,
+                                current_text=current_text,
+                                delta_text=delta_text,
+                                function_name_returned=fn_name_returned))
+
+                        # update the previous values for the next iteration
+                        previous_texts[i] = current_text
+
                     # handle streaming deltas for tools with "auto" tool choice
                     # and reasoning parser
                     elif tool_choice_auto and self.enable_reasoning:
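An illustration (values invented) of the delta sequence one choice might produce under this wiring: the first chunk names the function and carries the argument text seen so far, and later chunks send `name=None` with argument fragments only, matching OpenAI's streaming format.

```python
# Hypothetical DeltaToolCall payloads as a client would observe them.
chunks = [
    {"index": 0, "type": "function",
     "function": {"name": "get_weather", "arguments": '{"city": "Vie'}},
    {"index": 0, "type": "function",
     "function": {"name": None, "arguments": 'nna"}'}},
]
```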
@@ -821,10 +956,10 @@ class OpenAIServingChat(OpenAIServing):
 
         # if auto tools are not enabled, and a named tool choice using
         # outlines is not being used
-        if (not self.enable_auto_tools
-                or not self.tool_parser) and not isinstance(
-                    request.tool_choice,
-                    ChatCompletionNamedToolChoiceParam):
+        if (not self.enable_auto_tools or not self.tool_parser) and \
+            (not isinstance(request.tool_choice,
+                            ChatCompletionNamedToolChoiceParam
+                            ) and request.tool_choice != "required"):
             message = ChatMessage(role=role,
                                   reasoning_content=reasoning_content,
                                   content=content)
@@ -845,6 +980,24 @@ class OpenAIServingChat(OpenAIServing):
                                 arguments=content))
                     ])
 
+            elif request.tool_choice and request.tool_choice == "required":
+                tool_call_class = MistralToolCall if isinstance(
+                    tokenizer, MistralTokenizer) else ToolCall
+
+                # the fields of FunctionDefinition are a superset of the
+                # tool call outputs and can be used for parsing
+                tool_calls = TypeAdapter(
+                    list[FunctionDefinition]).validate_json(output.text)
+                message = ChatMessage(
+                    role=role,
+                    content="",
+                    tool_calls=[
+                        tool_call_class(function=FunctionCall(
+                            name=tool_call.name,
+                            arguments=json.dumps(tool_call.parameters)))
+                        for tool_call in tool_calls
+                    ])
+
             # if the request doesn't use tool choice
             # OR specifies to not use a tool
             elif not request.tool_choice or request.tool_choice == "none":
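A sketch (output text invented) of how this non-streaming path turns the guided JSON into OpenAI-style tool calls: each array item validates as a `FunctionDefinition`, whose fields cover the `{name, parameters}` objects the schema produces, and the parameters are re-serialized into the `arguments` string.

```python
from pydantic import TypeAdapter

from vllm.entrypoints.openai.protocol import FunctionDefinition

output_text = '[{"name": "get_weather", "parameters": {"city": "Vienna"}}]'
tool_calls = TypeAdapter(list[FunctionDefinition]).validate_json(output_text)
assert tool_calls[0].name == "get_weather"
assert tool_calls[0].parameters == {"city": "Vienna"}
```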