[Core] Whisper Enable Encoder Batching (#29421)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-12-11 22:06:51 +01:00
committed by GitHub
parent 90d6cf921f
commit 0efd9f867c
5 changed files with 87 additions and 25 deletions

View File

@@ -522,6 +522,7 @@ class WhisperEncoder(nn.Module):
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
hidden_states = []
input_is_batched = False
for features in input_features:
embeds = nn.functional.gelu(self.conv1(features))
embeds = nn.functional.gelu(self.conv2(embeds))
@@ -530,7 +531,13 @@ class WhisperEncoder(nn.Module):
embeds.dtype
)
hidden_states.append(embeds)
hidden_states = torch.cat(hidden_states)
input_is_batched = embeds.ndim > 2
# Input to MHA must be B x T x D
if input_is_batched:
# Models using WhisperEncoder may handle batching internally.
hidden_states = torch.cat(hidden_states)
else:
hidden_states = torch.stack(hidden_states, dim=0)
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
@@ -603,8 +610,7 @@ class WhisperModel(nn.Module):
positions: torch.Tensor,
encoder_outputs: list[torch.Tensor],
) -> torch.Tensor:
assert len(encoder_outputs) in (0, 1)
enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None
enc_states = torch.cat(encoder_outputs, dim=0) if len(encoder_outputs) else None
decoder_outputs = self.decoder(
input_ids=input_ids,
positions=positions,
@@ -913,7 +919,10 @@ class WhisperForConditionalGeneration(
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
# Required as part of SupportsMultiModal interface.
audio_input = self._parse_and_validate_audio_input(**kwargs)
return [self.model.get_encoder_outputs(audio_input["input_features"])]
# Split concatenated encoder outputs into one tensor per audio input
enc_output = self.model.get_encoder_outputs(audio_input["input_features"])
# The assumption is we can only process whole mm items (audios)
return enc_output.unbind(dim=0)
def embed_input_ids(
self,