[Voxtral] Fix speech transcription api (#31388)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: bk-201 <joy25810@foxmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: bk-201 <joy25810@foxmail.com>
Co-authored-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: Anexdeus <5142168@mail.ru>
Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
Patrick von Platen
2026-01-08 12:34:19 +02:00
committed by GitHub
parent 2972a05473
commit 18d4e481d0
5 changed files with 114 additions and 27 deletions

View File

@@ -469,8 +469,10 @@ class WhisperEncoder(nn.Module):
self.max_source_positions = config.max_source_positions
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
is_causal = getattr(config, "is_causal", False)
Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, padding=1)
self.is_causal = getattr(config, "is_causal", False)
Conv1d = (
WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1)
)
self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
@@ -485,7 +487,7 @@ class WhisperEncoder(nn.Module):
)
self.layer_norm = nn.LayerNorm(config.d_model)
if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
raise ValueError(
"Only NOPE position embeddings are supported "
f"for causal models, but got {self.pos_embed_type}"
@@ -536,8 +538,11 @@ class WhisperEncoder(nn.Module):
hidden_states.append(embeds)
input_is_batched = embeds.ndim > 2
# Input to MHA must be B x T x D
if input_is_batched:
if input_is_batched or self.is_causal:
# Models using WhisperEncoder may handle batching internally.
# If WhisperEncoder is causal, sequences
# are not padded to have identical seq length (T)
# => concat over feature dim
hidden_states = torch.cat(hidden_states)
else:
hidden_states = torch.stack(hidden_states, dim=0)