[Voxtral Realtime] Introduce global log mel max (#33574)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
089cd4f002
commit
5019c59dd2
@@ -126,8 +126,8 @@ async def test_multi_chunk_streaming(
|
||||
assert event["type"] == "transcription.done"
|
||||
assert event["text"] == full_text
|
||||
assert full_text == (
|
||||
" He has first words I spoke in the original phonograph."
|
||||
" First words I spoke in the original phonograph."
|
||||
" A little piece of practical poetry. Mary had a little lamb,"
|
||||
" it squeaked with quite a flow, and everywhere that Mary went,"
|
||||
" it sleeps with quite a flow, and everywhere that Mary went,"
|
||||
" the lamb was sure to go"
|
||||
)
|
||||
|
||||
@@ -37,7 +37,7 @@ EXPECTED_TEXT = [
|
||||
(
|
||||
" First words I spoke in the original phonograph. "
|
||||
"A little piece of practical poetry. Mary had a little lamb,"
|
||||
" it sleeps with quite a snow, and everywhere that Mary went, "
|
||||
" its fleece was quite a slow, and everywhere that Mary went, "
|
||||
"the lamb was sure to go."
|
||||
),
|
||||
(
|
||||
@@ -246,13 +246,6 @@ async def test_voxtral_realtime_generator(audio_assets, tokenizer, async_engine)
|
||||
|
||||
texts = [tokenizer.decode(output_tokens) for output_tokens in output_tokens_list]
|
||||
|
||||
# 'true' streaming and 'offline' streaming differ a bit because log-mels are
|
||||
# differently normalized
|
||||
texts[0] = (
|
||||
texts[0]
|
||||
.replace("He has f", "F")
|
||||
.replace("its fleece was quite a slow", "it sleeps with quite a snow")
|
||||
)
|
||||
texts[1] = texts[1].replace("a base hit", "OBS").replace("oh my", "oh, my")
|
||||
|
||||
assert texts == EXPECTED_TEXT
|
||||
|
||||
@@ -782,7 +782,19 @@ class VoxtralEncoderModel(nn.Module):
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
mel_spec = self.mel_filters.T @ magnitudes
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
|
||||
if global_log_mel_max := self.config.global_log_mel_max:
|
||||
if not isinstance(global_log_mel_max, float):
|
||||
raise TypeError(f"{global_log_mel_max=} needs to be of type float.")
|
||||
log_spec_max = torch.tensor(
|
||||
global_log_mel_max,
|
||||
device=log_spec.device,
|
||||
dtype=log_spec.dtype,
|
||||
)
|
||||
else:
|
||||
log_spec_max = log_spec.max()
|
||||
|
||||
log_spec = torch.maximum(log_spec, log_spec_max - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec.to(input_dtype)
|
||||
|
||||
|
||||
@@ -248,6 +248,9 @@ def _remap_mistral_audio_args(config: dict) -> dict:
|
||||
sliding_window=encoder_args.get("sliding_window", None),
|
||||
block_pool_size=block_pool_size,
|
||||
pos_embed=encoder_args.get("pos_embed", "sinusoidal"),
|
||||
global_log_mel_max=encoder_args["audio_encoding_args"].get(
|
||||
"global_log_mel_max"
|
||||
),
|
||||
# only needed for RoPE
|
||||
max_position_embeddings=block_pool_size * config["max_position_embeddings"],
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user