[Voxtral Realtime] Introduce global log mel max (#33574)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2026-02-02 23:01:47 +01:00
committed by GitHub
parent 089cd4f002
commit 5019c59dd2
4 changed files with 19 additions and 11 deletions

View File

@@ -782,7 +782,19 @@ class VoxtralEncoderModel(nn.Module):
magnitudes = stft[..., :-1].abs() ** 2
mel_spec = self.mel_filters.T @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
if global_log_mel_max := self.config.global_log_mel_max:
if not isinstance(global_log_mel_max, float):
raise TypeError(f"{global_log_mel_max=} needs to be of type float.")
log_spec_max = torch.tensor(
global_log_mel_max,
device=log_spec.device,
dtype=log_spec.dtype,
)
else:
log_spec_max = log_spec.max()
log_spec = torch.maximum(log_spec, log_spec_max - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec.to(input_dtype)