[Frontend] Implement Tool Calling with tool_choice='required' (#13483)
Signed-off-by: Liangfu Chen <liangfc@amazon.com>
Signed-off-by: Matt, Matthias <matthias.matt@tuwien.ac.at>
Co-authored-by: Liangfu Chen <liangfc@amazon.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
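For context, a minimal sketch of how the new option is exercised from a client. The base URL, API key, model name, and tool definition below are placeholders for illustration, not part of this change; the `openai` client package is assumed to be installed.

# Hypothetical usage sketch: tool_choice="required" forces the model to
# call at least one of the supplied tools.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # placeholder tool
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

resp = client.chat.completions.create(
    model="my-model",  # placeholder
    messages=[{"role": "user", "content": "What's the weather in Vienna?"}],
    tools=tools,
    tool_choice="required",  # the value this PR adds support for
)
print(resp.choices[0].message.tool_calls)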
@@ -61,7 +61,7 @@ class OpenAIBaseModel(BaseModel):
             field_names = set()
             for field_name, field in cls.model_fields.items():
                 field_names.add(field_name)
-                if alias := getattr(field, 'alias', None):
+                if alias := getattr(field, "alias", None):
                     field_names.add(alias)
             cls.field_names = field_names
 
@@ -70,7 +70,8 @@ class OpenAIBaseModel(BaseModel):
             logger.warning(
                 "The following fields were present in the request "
                 "but ignored: %s",
-                data.keys() - field_names)
+                data.keys() - field_names,
+            )
         return result
 
 
@@ -234,8 +235,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     tools: Optional[list[ChatCompletionToolsParam]] = None
-    tool_choice: Optional[Union[Literal["none"], Literal["auto"],
-                                ChatCompletionNamedToolChoiceParam]] = "none"
+    tool_choice: Optional[Union[
+        Literal["none"],
+        Literal["auto"],
+        Literal["required"],
+        ChatCompletionNamedToolChoiceParam,
+    ]] = "none"
 
     # NOTE this will be ignored by vLLM -- the model determines the behavior
     parallel_tool_calls: Optional[bool] = False
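With the widened union, the literal "required" now passes field validation. A quick check (a sketch, assuming the class is importable from vllm.entrypoints.openai.protocol as in the current tree, and a placeholder model name):

# Sketch: "required" now parses as a valid tool_choice value.
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

req = ChatCompletionRequest(
    model="my-model",  # placeholder
    messages=[{"role": "user", "content": "hi"}],
    tools=[{"type": "function", "function": {"name": "noop"}}],
    tool_choice="required",
)
print(req.tool_choice)  # -> "required"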
@@ -340,24 +345,28 @@ class ChatCompletionRequest(OpenAIBaseModel):
         description=(
             "If specified, will override the default guided decoding backend "
             "of the server for this specific request. If set, must be either "
-            "'outlines' / 'lm-format-enforcer'"))
+            "'outlines' / 'lm-format-enforcer'"),
+    )
     guided_whitespace_pattern: Optional[str] = Field(
         default=None,
         description=(
             "If specified, will override the default whitespace pattern "
-            "for guided json decoding."))
+            "for guided json decoding."),
+    )
     priority: int = Field(
         default=0,
         description=(
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
+            "if the served model does not use priority scheduling."),
+    )
     request_id: str = Field(
         default_factory=lambda: f"{random_uuid()}",
         description=(
             "The request_id related to this request. If the caller does "
             "not set it, a random_uuid will be generated. This id is used "
-            "through out the inference process and return in response."))
+            "through out the inference process and return in response."),
+    )
     logits_processors: Optional[LogitsProcessors] = Field(
         default=None,
         description=(
@@ -415,13 +424,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
-            include_stop_str_in_output=self.include_stop_str_in_output)
+            include_stop_str_in_output=self.include_stop_str_in_output,
+        )
 
     def to_sampling_params(
-            self,
-            default_max_tokens: int,
-            logits_processor_pattern: Optional[str],
-            default_sampling_params: Optional[dict] = None) -> SamplingParams:
+        self,
+        default_max_tokens: int,
+        logits_processor_pattern: Optional[str],
+        default_sampling_params: Optional[dict] = None,
+    ) -> SamplingParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
 
@@ -475,7 +486,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
             grammar=self.guided_grammar,
             json_object=guided_json_object,
             backend=self.guided_decoding_backend,
-            whitespace_pattern=self.guided_whitespace_pattern)
+            whitespace_pattern=self.guided_whitespace_pattern,
+        )
 
         return SamplingParams.from_optional(
             n=self.n,
@@ -522,6 +534,41 @@ class ChatCompletionRequest(OpenAIBaseModel):
             tool = tools[tool_name]
             return tool.parameters
 
+        if self.tool_choice == "required":
+            # Pydantic schema generation cannot be used since the JSON schema
+            # has to be constructed for a specific instantiation of a tool list
+            # so that parameters of a function are correctly generated
+            # based on the chosen function name
+            def get_tool_schema(tool: ChatCompletionToolsParam) -> dict:
+                return {
+                    "properties": {
+                        "name": {
+                            "type": "string",
+                            "enum": [tool.function.name]
+                        },
+                        # parameters are always generated as '{}' in the final
+                        # output if they are missing from the request
+                        # (i.e. are None or '{}') so the schema is
+                        # updated to produce an empty object in that case
+                        "parameters": tool.function.parameters
+                        if tool.function.parameters else {
+                            "type": "object",
+                            "properties": {}
+                        }
+                    },
+                    "required": ["name", "parameters"]
+                }
+
+            json_schema = {
+                "type": "array",
+                "minItems": 1,
+                "items": {
+                    "type": "object",
+                    "anyOf": [get_tool_schema(tool) for tool in self.tools]
+                }
+            }
+            return json_schema
+
         return None
 
     @model_validator(mode="before")
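To make the construction above concrete: for a single get_weather tool with one string parameter (example values, not from the diff), the emitted guided-decoding schema constrains the model output to a non-empty JSON array of {name, parameters} objects. The jsonschema package is used here only to illustrate what conforms; vLLM hands the schema to its guided-decoding backend instead of validating after the fact.

# Sketch of what get_tool_schema/json_schema produce for one example tool.
import jsonschema

json_schema = {
    "type": "array",
    "minItems": 1,
    "items": {
        "type": "object",
        "anyOf": [{
            "properties": {
                "name": {"type": "string", "enum": ["get_weather"]},
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
            "required": ["name", "parameters"],
        }],
    },
}

# A conforming completion: at least one call naming a listed tool.
jsonschema.validate(
    [{"name": "get_weather", "parameters": {"city": "Vienna"}}],
    json_schema,
)
# An empty array would violate "minItems": 1 and raise ValidationError.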
@@ -572,8 +619,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
                 "You can only use one kind of guided decoding "
                 "('guided_json', 'guided_regex' or 'guided_choice').")
         # you can only either use guided decoding or tools, not both
-        if guide_count > 1 and data.get("tool_choice",
-                                        "none") not in ("none", "auto"):
+        if guide_count > 1 and data.get("tool_choice", "none") not in (
+                "none",
+                "auto",
+                "required",
+        ):
             raise ValueError(
                 "You can only either use guided decoding or tools, not both.")
         return data
@@ -602,12 +652,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
                 "When using `tool_choice`, `tools` must be set.")
 
         # make sure that tool choice is either a named tool
-        # OR that it's set to "auto"
-        if data["tool_choice"] != "auto" and not isinstance(
-                data["tool_choice"], dict):
-            raise ValueError(
-                "`tool_choice` must either be a named tool, \"auto\", "
-                "or \"none\".")
+        # OR that it's set to "auto" or "required"
+        if data["tool_choice"] not in [
+                "auto", "required"
+        ] and not isinstance(data["tool_choice"], dict):
+            raise NotImplementedError(
+                f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\
+                'Only named tools, "none", "auto" or "required" '\
+                'are supported.'
+            )
 
         # ensure that if "tool_choice" is specified as an object,
         # it matches a valid tool
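A sketch of the rejection path this validator adds, again assuming the vllm.entrypoints.openai.protocol import path and placeholder values. Since the validator runs in mode="before", the exact exception type seen by the caller depends on how pydantic surfaces non-ValueError exceptions, so the example catches broadly:

# Sketch: unsupported tool_choice values are rejected up front.
from vllm.entrypoints.openai.protocol import ChatCompletionRequest

try:
    ChatCompletionRequest(
        model="my-model",  # placeholder
        messages=[{"role": "user", "content": "hi"}],
        tools=[{"type": "function", "function": {"name": "noop"}}],
        tool_choice="sometimes",  # not a supported value
    )
except Exception as exc:
    # Expect the validator's NotImplementedError (or a pydantic
    # ValidationError, depending on the pydantic version's wrapping).
    print(type(exc).__name__)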
@@ -722,18 +775,21 @@ class CompletionRequest(OpenAIBaseModel):
         description=(
             "If specified, will override the default guided decoding backend "
             "of the server for this specific request. If set, must be one of "
-            "'outlines' / 'lm-format-enforcer'"))
+            "'outlines' / 'lm-format-enforcer'"),
+    )
     guided_whitespace_pattern: Optional[str] = Field(
         default=None,
         description=(
             "If specified, will override the default whitespace pattern "
-            "for guided json decoding."))
+            "for guided json decoding."),
+    )
     priority: int = Field(
         default=0,
         description=(
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
+            "if the served model does not use priority scheduling."),
+    )
     logits_processors: Optional[LogitsProcessors] = Field(
         default=None,
         description=(
@@ -745,6 +801,7 @@ class CompletionRequest(OpenAIBaseModel):
             "arguments. For example: {'qualname': "
             "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': "
             "{'param': 'value'}}."))
+
     return_tokens_as_token_ids: Optional[bool] = Field(
         default=None,
         description=(
@@ -789,13 +846,15 @@ class CompletionRequest(OpenAIBaseModel):
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
-            include_stop_str_in_output=self.include_stop_str_in_output)
+            include_stop_str_in_output=self.include_stop_str_in_output,
+        )
 
     def to_sampling_params(
-            self,
-            default_max_tokens: int,
-            logits_processor_pattern: Optional[str],
-            default_sampling_params: Optional[dict] = None) -> SamplingParams:
+        self,
+        default_max_tokens: int,
+        logits_processor_pattern: Optional[str],
+        default_sampling_params: Optional[dict] = None,
+    ) -> SamplingParams:
         max_tokens = self.max_tokens
 
         if default_sampling_params is None:
@@ -844,7 +903,8 @@ class CompletionRequest(OpenAIBaseModel):
             grammar=self.guided_grammar,
             json_object=guided_json_object,
             backend=self.guided_decoding_backend,
-            whitespace_pattern=self.guided_whitespace_pattern)
+            whitespace_pattern=self.guided_whitespace_pattern,
+        )
 
         return SamplingParams.from_optional(
             n=self.n,
@@ -942,7 +1002,8 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
         description=(
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
+            "if the served model does not use priority scheduling."),
+    )
 
     # doc: end-embedding-extra-params
 
@@ -995,7 +1056,8 @@ class EmbeddingChatRequest(OpenAIBaseModel):
         description=(
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
+            "if the served model does not use priority scheduling."),
+    )
     # doc: end-chat-embedding-extra-params
 
     @model_validator(mode="before")
@@ -1034,7 +1096,8 @@ class ScoreRequest(OpenAIBaseModel):
         description=(
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
+            "if the served model does not use priority scheduling."),
+    )
 
     # doc: end-score-extra-params
 
@@ -1059,7 +1122,8 @@ class RerankRequest(OpenAIBaseModel):
         description=(
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
-            "if the served model does not use priority scheduling."))
+            "if the served model does not use priority scheduling."),
+    )
 
     # doc: end-rerank-extra-params