[Misc][OpenAI] deprecate max_tokens in favor of new max_completion_tokens field for chat completion endpoint (#9837)

Author: Guillaume Calmettes
Date: 2024-10-31 02:15:56 +01:00
Committed by: GitHub
Parent: 64384bbcdf
Commit: abbfb6134d
14 changed files with 140 additions and 118 deletions


@@ -159,7 +159,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = 0
-    max_tokens: Optional[int] = None
+    # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+    max_tokens: Optional[int] = Field(
+        default=None,
+        deprecated=
+        'max_tokens is deprecated in favor of the max_completion_tokens field')
+    max_completion_tokens: Optional[int] = None
     n: Optional[int] = 1
     presence_penalty: Optional[float] = 0.0
     response_format: Optional[ResponseFormat] = None
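
For reference, a minimal sketch of how the `deprecated` argument to Pydantic's `Field` behaves, assuming Pydantic v2.7+ (where `deprecated` accepts a warning message); the model below is a stripped-down stand-in for the real `ChatCompletionRequest`:

```python
from typing import Optional

from pydantic import BaseModel, Field


class ChatCompletionRequest(BaseModel):
    # Deprecated field: Pydantic attaches the message below and emits a
    # DeprecationWarning whenever the attribute is read on an instance.
    max_tokens: Optional[int] = Field(
        default=None,
        deprecated='max_tokens is deprecated in favor of the '
                   'max_completion_tokens field')
    max_completion_tokens: Optional[int] = None


req = ChatCompletionRequest(max_tokens=128)
# Reading the deprecated field still returns the value, so old clients keep
# working during the transition, but a DeprecationWarning is raised.
print(req.max_tokens)  # -> 128 (plus a DeprecationWarning)
```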
@@ -295,7 +300,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
     def to_beam_search_params(self,
                               default_max_tokens: int) -> BeamSearchParams:
-        max_tokens = self.max_tokens
+        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+        max_tokens = self.max_completion_tokens or self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
@@ -311,7 +317,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
             include_stop_str_in_output=self.include_stop_str_in_output)

     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
-        max_tokens = self.max_tokens
+        # TODO(#9845): remove max_tokens when field is removed from OpenAI API
+        max_tokens = self.max_completion_tokens or self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
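
Both `to_beam_search_params` and `to_sampling_params` now resolve the effective token budget the same way. A hypothetical standalone helper illustrating that resolution order (`resolve_max_tokens` is not part of the commit, just a restatement of the logic above):

```python
from typing import Optional


def resolve_max_tokens(max_completion_tokens: Optional[int],
                       max_tokens: Optional[int],
                       default_max_tokens: int) -> int:
    # Precedence: the new max_completion_tokens field wins, then the
    # deprecated max_tokens field, then the server-side default derived
    # from the model's context length.
    return max_completion_tokens or max_tokens or default_max_tokens


assert resolve_max_tokens(256, 128, 4096) == 256    # new field wins
assert resolve_max_tokens(None, 128, 4096) == 128   # deprecated fallback
assert resolve_max_tokens(None, None, 4096) == 4096 # server default
```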


@@ -263,20 +263,26 @@ class OpenAIServing:
             return TextTokensPrompt(prompt=input_text,
                                     prompt_token_ids=input_ids)

-        if request.max_tokens is None:
+        # chat completion endpoint supports max_completion_tokens
+        if isinstance(request, ChatCompletionRequest):
+            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
+            max_tokens = request.max_completion_tokens or request.max_tokens
+        else:
+            max_tokens = request.max_tokens
+        if max_tokens is None:
             if token_num >= self.max_model_len:
                 raise ValueError(
                     f"This model's maximum context length is "
                     f"{self.max_model_len} tokens. However, you requested "
                     f"{token_num} tokens in the messages, "
                     f"Please reduce the length of the messages.")
-        elif token_num + request.max_tokens > self.max_model_len:
+        elif token_num + max_tokens > self.max_model_len:
             raise ValueError(
                 f"This model's maximum context length is "
                 f"{self.max_model_len} tokens. However, you requested "
-                f"{request.max_tokens + token_num} tokens "
+                f"{max_tokens + token_num} tokens "
                 f"({token_num} in the messages, "
-                f"{request.max_tokens} in the completion). "
+                f"{max_tokens} in the completion). "
                 f"Please reduce the length of the messages or completion.")

         return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
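
From the client side, the migration is just a parameter rename. A sketch against a local vLLM OpenAI-compatible server, assuming an `openai` SDK recent enough to expose `max_completion_tokens`; the base URL, API key, and model name are illustrative:

```python
from openai import OpenAI

# vLLM's OpenAI-compatible server ignores the API key by default.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Say hello."}],
    max_completion_tokens=64,  # preferred over the deprecated max_tokens
)
print(response.choices[0].message.content)
```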