[Frontend] Implement Tool Calling with tool_choice='required' (#13483)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at>
Co-authored-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Matthias Matt
2025-04-02 16:45:45 +02:00
committed by GitHub
parent 98d7367b61
commit cefb9e5a28
7 changed files with 868 additions and 93 deletions

View File

@@ -61,7 +61,7 @@ class OpenAIBaseModel(BaseModel):
field_names = set()
for field_name, field in cls.model_fields.items():
field_names.add(field_name)
if alias := getattr(field, 'alias', None):
if alias := getattr(field, "alias", None):
field_names.add(alias)
cls.field_names = field_names
@@ -70,7 +70,8 @@ class OpenAIBaseModel(BaseModel):
logger.warning(
"The following fields were present in the request "
"but ignored: %s",
data.keys() - field_names)
data.keys() - field_names,
)
return result
@@ -234,8 +235,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature: Optional[float] = None
top_p: Optional[float] = None
tools: Optional[list[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
ChatCompletionNamedToolChoiceParam]] = "none"
tool_choice: Optional[Union[
Literal["none"],
Literal["auto"],
Literal["required"],
ChatCompletionNamedToolChoiceParam,
]] = "none"
# NOTE this will be ignored by vLLM -- the model determines the behavior
parallel_tool_calls: Optional[bool] = False
@@ -340,24 +345,28 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"))
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
"for guided json decoding."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."))
"through out the inference process and return in response."),
)
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
@@ -415,13 +424,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
@@ -475,7 +486,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
whitespace_pattern=self.guided_whitespace_pattern,
)
return SamplingParams.from_optional(
n=self.n,
@@ -522,6 +534,41 @@ class ChatCompletionRequest(OpenAIBaseModel):
tool = tools[tool_name]
return tool.parameters
if self.tool_choice == "required":
# Pydantic schema generation cannot be used since the JSON schema
# has to be constructed for a specific instantiation of a tool list
# so that parameters of a function are correctly generated
# based on the chosen function name
def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
return {
"properties": {
"name": {
"type": "string",
"enum": [tool.function.name]
},
# parameters are always generated as '{}' in the final
# output if they are missing from the request
# (i.e. are None or '{}') so the schema is
# updated to produce an empty object in that case
"parameters": tool.function.parameters
if tool.function.parameters else {
"type": "object",
"properties": {}
}
},
"required": ["name", "parameters"]
}
json_schema = {
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": [get_tool_schema(tool) for tool in self.tools]
}
}
return json_schema
return None
@model_validator(mode="before")
@@ -572,8 +619,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and data.get("tool_choice",
"none") not in ("none", "auto"):
if guide_count > 1 and data.get("tool_choice", "none") not in (
"none",
"auto",
"required",
):
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data
@@ -602,12 +652,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"When using `tool_choice`, `tools` must be set.")
# make sure that tool choice is either a named tool
# OR that it's set to "auto"
if data["tool_choice"] != "auto" and not isinstance(
data["tool_choice"], dict):
raise ValueError(
"`tool_choice` must either be a named tool, \"auto\", "
"or \"none\".")
# OR that it's set to "auto" or "required"
if data["tool_choice"] not in [
"auto", "required"
] and not isinstance(data["tool_choice"], dict):
raise NotImplementedError(
f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\
'Only named tools, "none", "auto" or "required" '\
'are supported.'
)
# ensure that if "tool_choice" is specified as an object,
# it matches a valid tool
@@ -722,18 +775,21 @@ class CompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"))
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
"for guided json decoding."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
@@ -745,6 +801,7 @@ class CompletionRequest(OpenAIBaseModel):
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."))
return_tokens_as_token_ids: Optional[bool] = Field(
default=None,
description=(
@@ -789,13 +846,15 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
max_tokens = self.max_tokens
if default_sampling_params is None:
@@ -844,7 +903,8 @@ class CompletionRequest(OpenAIBaseModel):
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
whitespace_pattern=self.guided_whitespace_pattern,
)
return SamplingParams.from_optional(
n=self.n,
@@ -942,7 +1002,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-embedding-extra-params
@@ -995,7 +1056,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-chat-embedding-extra-params
@model_validator(mode="before")
@@ -1034,7 +1096,8 @@ class ScoreRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-score-extra-params
@@ -1059,7 +1122,8 @@ class RerankRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-rerank-extra-params

View File

@@ -2,13 +2,16 @@
import asyncio
import json
import re
import time
from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence
from typing import Callable, Final, Optional, Union
import jinja2
import partial_json_parser
from fastapi import Request
from pydantic import TypeAdapter
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
@@ -21,8 +24,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo)
DeltaToolCall, ErrorResponse, FunctionCall, FunctionDefinition,
PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@@ -150,12 +153,6 @@ class OpenAIServingChat(OpenAIServing):
tool_parser = self.tool_parser
# validation for OpenAI tools
# tool_choice = "required" is not supported
if request.tool_choice == "required":
return self.create_error_response(
"tool_choice = \"required\" is not supported!")
if isinstance(tokenizer, MistralTokenizer):
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
@@ -277,6 +274,122 @@ class OpenAIServingChat(OpenAIServing):
return self.response_role
return request.messages[-1]["role"]
@staticmethod
def _bracket_level(s: str, opening='{', closing='}') -> int:
"""
Calculate the current level of nested brackets in a given string.
"""
level = 0
for char in s:
if char == opening:
level += 1
elif char == closing:
level -= 1
return level
@staticmethod
def _filter_delta_text(delta_text: str,
previous_text: str) -> tuple[str, bool]:
# remove last '},' of the tool definition stemming from the
# "name"/"parameters" outer object or closing ']' of the tool list
# count occurrences of opening and closing curly braces and
# once level 0 is reached stop outputting text
# if 0 is reached while parsing the delta_text we know the current
# tool will finish in this current iteration
bracket_level = OpenAIServingChat._bracket_level(previous_text)
updated_delta, passed_zero = "", False
for c in delta_text:
if c == '{':
bracket_level += 1
passed_zero = bracket_level == 0
elif c == '}':
bracket_level -= 1
passed_zero = bracket_level == 0
if bracket_level != 0:
updated_delta += c
else:
# if a comma is reached at level 0 we can stop
if c == ',':
break
return updated_delta, passed_zero
def extract_tool_call_required_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
function_name_returned: bool,
) -> tuple[Optional[DeltaMessage], bool]:
try:
obj = partial_json_parser.loads(current_text)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug('not enough tokens to parse into JSON yet')
obj = None
# check if the current text is a valid array
# containing a partial tool calling object
# if not repeat
if obj is None or not isinstance(obj, list) or not len(obj) > 0:
function_name_returned = False
delta_message = None
else:
_, finishes_previous_tool = OpenAIServingChat._filter_delta_text(
delta_text, previous_text)
# take the last tool call from the generated list
current_tool_call = obj[-1]
# once parameters have been generated the name is complete as well
if not finishes_previous_tool and ("name" not in current_tool_call
or "parameters"
not in current_tool_call):
function_name_returned = False
delta_message = None
else:
if not function_name_returned:
# get partly generated arguments from the latest tool call
param_match = re.search(r'.*"parameters":\s*(.*)',
current_text)
arguments = param_match.group(1) if param_match else ""
arguments, _ = OpenAIServingChat._filter_delta_text(
arguments, previous_text)
# if this iteration finishes a previous tool call but a
# new incomplete tool is already generated, take the
# previous from the list
if (finishes_previous_tool
and "parameters" not in current_tool_call):
current_tool_call = obj[-2]
function_name_returned = True
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(function=DeltaFunctionCall(
name=current_tool_call["name"],
arguments=arguments),
index=len(obj) - 1,
type="function")
])
else:
delta_text, _ = OpenAIServingChat._filter_delta_text(
delta_text, previous_text)
if delta_text != "":
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(
function=DeltaFunctionCall(
# OpenAI API returns None
# instead of name every time
name=None,
arguments=delta_text),
index=len(obj) - 1,
type="function")
])
else:
delta_message = None
return delta_message, function_name_returned
async def chat_completion_stream_generator(
self,
request: ChatCompletionRequest,
@@ -312,6 +425,7 @@ class OpenAIServingChat(OpenAIServing):
self._should_stream_with_reasoning_parsing(request))
all_previous_token_ids: Optional[list[list[int]]]
function_name_returned: Optional[list[bool]] = None
# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
@@ -322,6 +436,10 @@ class OpenAIServingChat(OpenAIServing):
# For reasoning parser and tool call all enabled
added_content_delta_arr = [False] * num_choices
reasoning_end_arr = [False] * num_choices
elif request.tool_choice == "required":
previous_texts = [""] * num_choices
function_name_returned = [False] * num_choices
all_previous_token_ids = None
else:
previous_texts, all_previous_token_ids = None, None
@@ -521,6 +639,23 @@ class OpenAIServingChat(OpenAIServing):
index=i)
])
elif request.tool_choice == "required":
assert previous_texts is not None
assert function_name_returned is not None
previous_text = previous_texts[i]
current_text = previous_text + delta_text
fn_name_returned = function_name_returned[i]
delta_message, function_name_returned[i] = (
self.extract_tool_call_required_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
function_name_returned=fn_name_returned))
# update the previous values for the next iteration
previous_texts[i] = current_text
# handle streaming deltas for tools with "auto" tool choice
# and reasoning parser
elif tool_choice_auto and self.enable_reasoning:
@@ -821,10 +956,10 @@ class OpenAIServingChat(OpenAIServing):
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
if (not self.enable_auto_tools
or not self.tool_parser) and not isinstance(
request.tool_choice,
ChatCompletionNamedToolChoiceParam):
if (not self.enable_auto_tools or not self.tool_parser) and \
(not isinstance(request.tool_choice,
ChatCompletionNamedToolChoiceParam
) and request.tool_choice != "required"):
message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=content)
@@ -845,6 +980,24 @@ class OpenAIServingChat(OpenAIServing):
arguments=content))
])
elif request.tool_choice and request.tool_choice == "required":
tool_call_class = MistralToolCall if isinstance(
tokenizer, MistralTokenizer) else ToolCall
# the fields of FunctionDefinition are a superset of the
# tool call outputs and can be used for parsing
tool_calls = TypeAdapter(
list[FunctionDefinition]).validate_json(output.text)
message = ChatMessage(
role=role,
content="",
tool_calls=[
tool_call_class(function=FunctionCall(
name=tool_call.name,
arguments=json.dumps(tool_call.parameters)))
for tool_call in tool_calls
])
# if the request doesn't use tool choice
# OR specifies to not use a tool
elif not request.tool_choice or request.tool_choice == "none":