[Voxtral] Fix speech transcription api (#31388)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: bk-201 <joy25810@foxmail.com>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: bk-201 <joy25810@foxmail.com>
Co-authored-by: prashanth058 <prashanth.dannamaneni@uipath.com>
Co-authored-by: Anexdeus <5142168@mail.ru>
Co-authored-by: Julien Denize <40604584+juliendenize@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
This commit is contained in:
committed by
GitHub
parent
2972a05473
commit
18d4e481d0
@@ -469,8 +469,10 @@ class WhisperEncoder(nn.Module):
|
||||
self.max_source_positions = config.max_source_positions
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
is_causal = getattr(config, "is_causal", False)
|
||||
Conv1d = WhisperCausalConv1d if is_causal else partial(nn.Conv1d, padding=1)
|
||||
self.is_causal = getattr(config, "is_causal", False)
|
||||
Conv1d = (
|
||||
WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1)
|
||||
)
|
||||
|
||||
self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
|
||||
self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
|
||||
@@ -485,7 +487,7 @@ class WhisperEncoder(nn.Module):
|
||||
)
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
||||
if is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
|
||||
if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
|
||||
raise ValueError(
|
||||
"Only NOPE position embeddings are supported "
|
||||
f"for causal models, but got {self.pos_embed_type}"
|
||||
@@ -536,8 +538,11 @@ class WhisperEncoder(nn.Module):
|
||||
hidden_states.append(embeds)
|
||||
input_is_batched = embeds.ndim > 2
|
||||
# Input to MHA must be B x T x D
|
||||
if input_is_batched:
|
||||
if input_is_batched or self.is_causal:
|
||||
# Models using WhisperEncoder may handle batching internally.
|
||||
# If WhisperEncoder is causal, sequences
|
||||
# are not padded to have identical seq length (T)
|
||||
# => concat along dim 0 (torch.cat default) rather than stacking
|
||||
hidden_states = torch.cat(hidden_states)
|
||||
else:
|
||||
hidden_states = torch.stack(hidden_states, dim=0)
|
||||
|
||||
Reference in New Issue
Block a user