[Voxtral Realtime] Introduce global log mel max (#33574)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2026-02-02 23:01:47 +01:00
committed by GitHub
parent 089cd4f002
commit 5019c59dd2
4 changed files with 19 additions and 11 deletions

View File

@@ -126,8 +126,8 @@ async def test_multi_chunk_streaming(
assert event["type"] == "transcription.done"
assert event["text"] == full_text
assert full_text == (
" He has first words I spoke in the original phonograph."
" First words I spoke in the original phonograph."
" A little piece of practical poetry. Mary had a little lamb,"
" it squeaked with quite a flow, and everywhere that Mary went,"
" it sleeps with quite a flow, and everywhere that Mary went,"
" the lamb was sure to go"
)

View File

@@ -37,7 +37,7 @@ EXPECTED_TEXT = [
(
" First words I spoke in the original phonograph. "
"A little piece of practical poetry. Mary had a little lamb,"
" it sleeps with quite a snow, and everywhere that Mary went, "
" its fleece was quite a slow, and everywhere that Mary went, "
"the lamb was sure to go."
),
(
@@ -246,13 +246,6 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list]
# 'true' streaming and 'offline' streaming differ a bit because log-mels are
# differently normalized
texts[0] = (
texts[0]
.replace("He has f", "F")
.replace("its fleece was quite a slow", "it sleeps with quite a snow")
)
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
assert texts == EXPECTED_TEXT

View File

@@ -782,7 +782,19 @@ class VoxtralEncoderModel(nn.Module):
magnitudes = stft[..., :-1].abs() ** 2
mel_spec = self.mel_filters.T @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
if global_log_mel_max := self.config.global_log_mel_max:
if not isinstance(global_log_mel_max, float):
raise TypeError(f"{global_log_mel_max=} needs to be of type float.")
log_spec_max = torch.tensor(
global_log_mel_max,
device=log_spec.device,
dtype=log_spec.dtype,
)
else:
log_spec_max = log_spec.max()
log_spec = torch.maximum(log_spec, log_spec_max - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec.to(input_dtype)

View File

@@ -248,6 +248,9 @@ def _remap_mistral_audio_args(config: dict) -> dict:
sliding_window=encoder_args.get("sliding_window", None),
block_pool_size=block_pool_size,
pos_embed=encoder_args.get("pos_embed", "sinusoidal"),
global_log_mel_max=encoder_args["audio_encoding_args"].get(
"global_log_mel_max"
),
# only needed for RoPE
max_position_embeddings=block_pool_size * config["max_position_embeddings"],
),