[Frontend] Implement Tool Calling with tool_choice='required' (#13483)

Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at>
Co-authored-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
Matthias Matt
2025-04-02 16:45:45 +02:00
committed by GitHub
parent 98d7367b61
commit cefb9e5a28
7 changed files with 868 additions and 93 deletions

View File

@@ -61,7 +61,7 @@ class OpenAIBaseModel(BaseModel):
field_names = set()
for field_name, field in cls.model_fields.items():
field_names.add(field_name)
if alias := getattr(field, 'alias', None):
if alias := getattr(field, "alias", None):
field_names.add(alias)
cls.field_names = field_names
@@ -70,7 +70,8 @@ class OpenAIBaseModel(BaseModel):
logger.warning(
"The following fields were present in the request "
"but ignored: %s",
data.keys() - field_names)
data.keys() - field_names,
)
return result
@@ -234,8 +235,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature: Optional[float] = None
top_p: Optional[float] = None
tools: Optional[list[ChatCompletionToolsParam]] = None
tool_choice: Optional[Union[Literal["none"], Literal["auto"],
ChatCompletionNamedToolChoiceParam]] = "none"
tool_choice: Optional[Union[
Literal["none"],
Literal["auto"],
Literal["required"],
ChatCompletionNamedToolChoiceParam,
]] = "none"
# NOTE this will be ignored by vLLM -- the model determines the behavior
parallel_tool_calls: Optional[bool] = False
@@ -340,24 +345,28 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"))
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
"for guided json decoding."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."))
"through out the inference process and return in response."),
)
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
@@ -415,13 +424,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = self.max_completion_tokens or self.max_tokens
@@ -475,7 +486,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
whitespace_pattern=self.guided_whitespace_pattern,
)
return SamplingParams.from_optional(
n=self.n,
@@ -522,6 +534,41 @@ class ChatCompletionRequest(OpenAIBaseModel):
tool = tools[tool_name]
return tool.parameters
if self.tool_choice == "required":
# Pydantic schema generation cannot be used since the JSON schema
# has to be constructed for a specific instantiation of a tool list
# so that parameters of a function are correctly generated
# based on the chosen function name
def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
return {
"properties": {
"name": {
"type": "string",
"enum": [tool.function.name]
},
# parameters are always generated as '{}' in the final
# output if they are missing from the request
# (i.e. are None or '{}') so the schema is
# updated to produce an empty object in that case
"parameters": tool.function.parameters
if tool.function.parameters else {
"type": "object",
"properties": {}
}
},
"required": ["name", "parameters"]
}
json_schema = {
"type": "array",
"minItems": 1,
"items": {
"type": "object",
"anyOf": [get_tool_schema(tool) for tool in self.tools]
}
}
return json_schema
return None
@model_validator(mode="before")
@@ -572,8 +619,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
# you can only either use guided decoding or tools, not both
if guide_count > 1 and data.get("tool_choice",
"none") not in ("none", "auto"):
if guide_count > 1 and data.get("tool_choice", "none") not in (
"none",
"auto",
"required",
):
raise ValueError(
"You can only either use guided decoding or tools, not both.")
return data
@@ -602,12 +652,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"When using `tool_choice`, `tools` must be set.")
# make sure that tool choice is either a named tool
# OR that it's set to "auto"
if data["tool_choice"] != "auto" and not isinstance(
data["tool_choice"], dict):
raise ValueError(
"`tool_choice` must either be a named tool, \"auto\", "
"or \"none\".")
# OR that it's set to "auto" or "required"
if data["tool_choice"] not in [
"auto", "required"
] and not isinstance(data["tool_choice"], dict):
raise NotImplementedError(
f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\
'Only named tools, "none", "auto" or "required" '\
'are supported.'
)
# ensure that if "tool_choice" is specified as an object,
# it matches a valid tool
@@ -722,18 +775,21 @@ class CompletionRequest(OpenAIBaseModel):
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"))
"'outlines' / 'lm-format-enforcer'"),
)
guided_whitespace_pattern: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default whitespace pattern "
"for guided json decoding."))
"for guided json decoding."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
logits_processors: Optional[LogitsProcessors] = Field(
default=None,
description=(
@@ -745,6 +801,7 @@ class CompletionRequest(OpenAIBaseModel):
"arguments. For example: {'qualname': "
"'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
"{'param': 'value'}}."))
return_tokens_as_token_ids: Optional[bool] = Field(
default=None,
description=(
@@ -789,13 +846,15 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
include_stop_str_in_output=self.include_stop_str_in_output)
include_stop_str_in_output=self.include_stop_str_in_output,
)
def to_sampling_params(
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None) -> SamplingParams:
self,
default_max_tokens: int,
logits_processor_pattern: Optional[str],
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
max_tokens = self.max_tokens
if default_sampling_params is None:
@@ -844,7 +903,8 @@ class CompletionRequest(OpenAIBaseModel):
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
whitespace_pattern=self.guided_whitespace_pattern,
)
return SamplingParams.from_optional(
n=self.n,
@@ -942,7 +1002,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-embedding-extra-params
@@ -995,7 +1056,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-chat-embedding-extra-params
@model_validator(mode="before")
@@ -1034,7 +1096,8 @@ class ScoreRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-score-extra-params
@@ -1059,7 +1122,8 @@ class RerankRequest(OpenAIBaseModel):
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
"if the served model does not use priority scheduling."),
)
# doc: end-rerank-extra-params