[Frontend] Add max-completion-token option to transcription/translation endpoints (#30769)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-12-16 20:36:49 +01:00
committed by GitHub
parent 10ee1c64cf
commit ca702a14dc
4 changed files with 79 additions and 2 deletions

View File

@@ -2054,6 +2054,9 @@ class TranscriptionRequest(OpenAIBaseModel):
presence_penalty: float | None = 0.0
"""The presence penalty to use for sampling."""
max_completion_tokens: int | None = None
"""The maximum number of tokens to generate."""
# --8<-- [end:transcription-sampling-params]
# Default sampling parameters for transcription requests.
@@ -2300,6 +2303,9 @@ class TranslationRequest(OpenAIBaseModel):
# Flattened stream option to simplify form data.
stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False
max_completion_tokens: int | None = None
"""The maximum number of tokens to generate."""
# --8<-- [end:translation-extra-params]
# Default sampling parameters for translation requests.

View File

@@ -293,8 +293,14 @@ class OpenAISpeechToText(OpenAIServing):
try:
# Unlike most decoder-only models, whisper generation length is not
# constrained by the size of the input audio, which is mapped to a
# fixed-size log-mel-spectogram.
default_max_tokens = self.model_config.max_model_len
# fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
# generated by respecting the extra completion tokens arg.
if request.max_completion_tokens is None:
default_max_tokens = self.model_config.max_model_len
else:
default_max_tokens = min(
self.model_config.max_model_len, request.max_completion_tokens
)
sampling_params = request.to_sampling_params(
default_max_tokens, self.default_sampling_params
)