[Core] Align whisper closer to other multimodal models (#27292)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
@@ -599,15 +599,16 @@ class WhisperModel(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_features: torch.Tensor | list[torch.Tensor] | None,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
encoder_outputs: list[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
encoder_outputs = self.get_encoder_outputs(input_features)
|
||||
assert len(encoder_outputs) in (0, 1)
|
||||
enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
encoder_hidden_states=encoder_outputs,
|
||||
encoder_hidden_states=enc_states,
|
||||
)
|
||||
return decoder_outputs
|
||||
|
||||
@@ -894,13 +895,15 @@ class WhisperForConditionalGeneration(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
encoder_outputs: list[torch.Tensor] | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
if encoder_outputs is None:
|
||||
encoder_outputs = []
|
||||
decoder_outputs = self.model(
|
||||
input_features=audio_input["input_features"],
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
encoder_outputs=encoder_outputs,
|
||||
)
|
||||
return decoder_outputs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user