[Voxtral Realtime] Introduce global log mel max (#33574)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
commit 5019c59dd2 (parent 089cd4f002), committed by GitHub
```diff
@@ -782,7 +782,19 @@ class VoxtralEncoderModel(nn.Module):
         magnitudes = stft[..., :-1].abs() ** 2
         mel_spec = self.mel_filters.T @ magnitudes
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
-        log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+
+        if global_log_mel_max := self.config.global_log_mel_max:
+            if not isinstance(global_log_mel_max, float):
+                raise TypeError(f"{global_log_mel_max=} needs to be of type float.")
+            log_spec_max = torch.tensor(
+                global_log_mel_max,
+                device=log_spec.device,
+                dtype=log_spec.dtype,
+            )
+        else:
+            log_spec_max = log_spec.max()
+
+        log_spec = torch.maximum(log_spec, log_spec_max - 8.0)
         log_spec = (log_spec + 4.0) / 4.0
         return log_spec.to(input_dtype)
 
```
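For context on what the new branch does, below is a minimal standalone sketch (not the Voxtral module itself; `normalize_log_mel` is a hypothetical helper introduced only for illustration). When a float `global_log_mel_max` is configured, the 8-decade dynamic-range floor is anchored to that fixed value instead of each input's own maximum, so the same audio frames normalize identically regardless of which chunk they arrive in, which is presumably why the realtime path needs it.

```python
from typing import Optional

import torch


def normalize_log_mel(mel_spec: torch.Tensor, global_log_mel_max: Optional[float] = None) -> torch.Tensor:
    """Hypothetical helper mirroring the hunk above, for illustration only."""
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    if global_log_mel_max is not None:
        # New behavior: anchor the floor to a fixed, config-supplied maximum.
        log_spec_max = torch.tensor(global_log_mel_max, device=log_spec.device, dtype=log_spec.dtype)
    else:
        # Old behavior: the floor depends on this particular input's maximum.
        log_spec_max = log_spec.max()
    log_spec = torch.maximum(log_spec, log_spec_max - 8.0)  # clip to 8 decades below the max
    return (log_spec + 4.0) / 4.0


# Two overlapping chunks of the same toy spectrogram: with a fixed global max,
# the shared frames come out identical; with per-input maxima they can differ.
mel = torch.rand(80, 3000)
a = normalize_log_mel(mel[:, :2000], global_log_mel_max=1.5)
b = normalize_log_mel(mel[:, 1000:], global_log_mel_max=1.5)
assert torch.allclose(a[:, 1000:], b[:, :1000])
```

One caveat worth noting: the merged code gates on truthiness via the walrus operator (`if global_log_mel_max := self.config.global_log_mel_max:`), so a configured value of `0.0` is treated as unset and falls back to the per-input maximum; the sketch checks against `None` instead to keep the toy example unambiguous.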
||||