[responsesAPI]add extra body parameters (#30532)
Signed-off-by: Ri0S <aa248424@gmail.com>
This commit is contained in:
@@ -320,6 +320,7 @@ class ResponsesRequest(OpenAIBaseModel):
|
|||||||
max_tool_calls: int | None = None
|
max_tool_calls: int | None = None
|
||||||
metadata: Metadata | None = None
|
metadata: Metadata | None = None
|
||||||
model: str | None = None
|
model: str | None = None
|
||||||
|
logit_bias: dict[str, float] | None = None
|
||||||
parallel_tool_calls: bool | None = True
|
parallel_tool_calls: bool | None = True
|
||||||
previous_response_id: str | None = None
|
previous_response_id: str | None = None
|
||||||
prompt: ResponsePrompt | None = None
|
prompt: ResponsePrompt | None = None
|
||||||
@@ -333,6 +334,7 @@ class ResponsesRequest(OpenAIBaseModel):
|
|||||||
tools: list[Tool] = Field(default_factory=list)
|
tools: list[Tool] = Field(default_factory=list)
|
||||||
top_logprobs: int | None = 0
|
top_logprobs: int | None = 0
|
||||||
top_p: float | None = None
|
top_p: float | None = None
|
||||||
|
top_k: int | None = None
|
||||||
truncation: Literal["auto", "disabled"] | None = "disabled"
|
truncation: Literal["auto", "disabled"] | None = "disabled"
|
||||||
user: str | None = None
|
user: str | None = None
|
||||||
|
|
||||||
@@ -387,6 +389,7 @@ class ResponsesRequest(OpenAIBaseModel):
|
|||||||
_DEFAULT_SAMPLING_PARAMS = {
|
_DEFAULT_SAMPLING_PARAMS = {
|
||||||
"temperature": 1.0,
|
"temperature": 1.0,
|
||||||
"top_p": 1.0,
|
"top_p": 1.0,
|
||||||
|
"top_k": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
def to_sampling_params(
|
def to_sampling_params(
|
||||||
@@ -408,6 +411,10 @@ class ResponsesRequest(OpenAIBaseModel):
|
|||||||
top_p = default_sampling_params.get(
|
top_p = default_sampling_params.get(
|
||||||
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
|
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
|
||||||
)
|
)
|
||||||
|
if (top_k := self.top_k) is None:
|
||||||
|
top_k = default_sampling_params.get(
|
||||||
|
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
|
||||||
|
)
|
||||||
stop_token_ids = default_sampling_params.get("stop_token_ids")
|
stop_token_ids = default_sampling_params.get("stop_token_ids")
|
||||||
|
|
||||||
# Structured output
|
# Structured output
|
||||||
@@ -428,6 +435,7 @@ class ResponsesRequest(OpenAIBaseModel):
|
|||||||
return SamplingParams.from_optional(
|
return SamplingParams.from_optional(
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
|
logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
@@ -435,6 +443,7 @@ class ResponsesRequest(OpenAIBaseModel):
|
|||||||
RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
|
RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
|
||||||
),
|
),
|
||||||
structured_outputs=structured_outputs,
|
structured_outputs=structured_outputs,
|
||||||
|
logit_bias=self.logit_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_include_output_logprobs(self) -> bool:
|
def is_include_output_logprobs(self) -> bool:
|
||||||
|
|||||||
Reference in New Issue
Block a user