[Frontend] Add sampling params to v1/audio/transcriptions endpoint (#16591)

Signed-off-by: Jannis Schönleber <joennlae@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Jannis Schönleber <joennlae@gmail.com>
This commit is contained in:
Nicolò Lucchesi
2025-04-19 09:03:54 +02:00
committed by GitHub
parent 1d4680fad2
commit 2ef0dc53b8
4 changed files with 122 additions and 11 deletions

View File

@@ -1577,14 +1577,6 @@ class TranscriptionRequest(OpenAIBaseModel):
"""
# TODO(varun): Support the temperature=0 behavior, i.e. automatically raising the temperature until certain thresholds are met.
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
Higher values like 0.8 will make the output more random, while lower values
like 0.2 will make it more focused / deterministic. If set to 0, the model
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
to automatically increase the temperature until certain thresholds are hit.
"""
timestamp_granularities: list[Literal["word", "segment"]] = Field(
alias="timestamp_granularities[]", default=[])
@@ -1596,6 +1588,7 @@ class TranscriptionRequest(OpenAIBaseModel):
timestamps incurs additional latency.
"""
# doc: begin-transcription-extra-params
stream: Optional[bool] = False
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
@@ -1604,10 +1597,51 @@ class TranscriptionRequest(OpenAIBaseModel):
# Flattened stream option to simplify form data.
stream_include_usage: Optional[bool] = False
stream_continuous_usage_stats: Optional[bool] = False
# doc: end-transcription-extra-params
# doc: begin-transcription-sampling-params
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
Higher values like 0.8 will make the output more random, while lower values
like 0.2 will make it more focused / deterministic. If set to 0, the model
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
to automatically increase the temperature until certain thresholds are hit.
"""
top_p: Optional[float] = None
"""Enables nucleus (top-p) sampling, where tokens are selected from the
smallest possible set whose cumulative probability exceeds `p`.
"""
top_k: Optional[int] = None
"""Limits sampling to the `k` most probable tokens at each step."""
min_p: Optional[float] = None
"""Filters out tokens with a probability lower than `min_p`, ensuring a
minimum likelihood threshold during sampling.
"""
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
"""The seed to use for sampling."""
frequency_penalty: Optional[float] = 0.0
"""The frequency penalty to use for sampling."""
repetition_penalty: Optional[float] = None
"""The repetition penalty to use for sampling."""
presence_penalty: Optional[float] = 0.0
"""The presence penalty to use for sampling."""
# doc: end-transcription-sampling-params
# Default sampling parameters for transcription requests.
# Each entry is the last-resort fallback read by `to_sampling_params` when the
# request leaves the field unset (None) and the caller-supplied
# `default_sampling_params` has no entry for that key.
# NOTE: "temperature" was previously listed twice (0, then 1.0). In a dict
# literal only the last occurrence of a key takes effect, so the shadowed
# `0` entry was dead and has been removed; the effective values are unchanged.
_DEFAULT_SAMPLING_PARAMS: dict = {
    "repetition_penalty": 1.0,
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": -1,
    "min_p": 0.0,
}
def to_sampling_params(
@@ -1619,13 +1653,35 @@ class TranscriptionRequest(OpenAIBaseModel):
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
if (min_p := self.min_p) is None:
min_p = default_sampling_params.get(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
if (repetition_penalty := self.repetition_penalty) is None:
repetition_penalty = default_sampling_params.get(
"repetition_penalty",
self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"])
return SamplingParams.from_optional(temperature=temperature,
max_tokens=max_tokens,
seed=self.seed,
top_p=top_p,
top_k=top_k,
min_p=min_p,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
presence_penalty=self.presence_penalty,
output_kind=RequestOutputKind.DELTA
if self.stream \
else RequestOutputKind.FINAL_ONLY)