[Core] Whisper Enable Encoder Batching (#29421)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -522,6 +522,7 @@ class WhisperEncoder(nn.Module):
|
||||
|
||||
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
|
||||
hidden_states = []
|
||||
input_is_batched = False
|
||||
for features in input_features:
|
||||
embeds = nn.functional.gelu(self.conv1(features))
|
||||
embeds = nn.functional.gelu(self.conv2(embeds))
|
||||
@@ -530,7 +531,13 @@ class WhisperEncoder(nn.Module):
|
||||
embeds.dtype
|
||||
)
|
||||
hidden_states.append(embeds)
|
||||
hidden_states = torch.cat(hidden_states)
|
||||
input_is_batched = embeds.ndim > 2
|
||||
# Input to MHA must be B x T x D
|
||||
if input_is_batched:
|
||||
# Models using WhisperEncoder may handle batching internally.
|
||||
hidden_states = torch.cat(hidden_states)
|
||||
else:
|
||||
hidden_states = torch.stack(hidden_states, dim=0)
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(hidden_states)
|
||||
@@ -603,8 +610,7 @@ class WhisperModel(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
encoder_outputs: list[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
assert len(encoder_outputs) in (0, 1)
|
||||
enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None
|
||||
enc_states = torch.cat(encoder_outputs, dim=0) if len(encoder_outputs) else None
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
@@ -913,7 +919,10 @@ class WhisperForConditionalGeneration(
|
||||
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
# Required as part of SupportsMultiModal interface.
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
return [self.model.get_encoder_outputs(audio_input["input_features"])]
|
||||
# Split concatenated encoder outputs into one tensor per audio input
|
||||
enc_output = self.model.get_encoder_outputs(audio_input["input_features"])
|
||||
# The assumption is we can only process whole mm items (audios)
|
||||
return enc_output.unbind(dim=0)
|
||||
|
||||
def embed_input_ids(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user