[Core] Align whisper closer to other multimodal models (#27292)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant
2025-11-21 07:01:54 -05:00
committed by GitHub
parent aab0102a26
commit cca2d2cdbe
2 changed files with 21 additions and 41 deletions

View File

@@ -599,15 +599,16 @@ class WhisperModel(nn.Module):
# NOTE(review): this block is a rendered git-diff hunk with the leading
# '+'/'-' markers stripped, not runnable source. Both the pre-change and
# post-change versions of at least one line are present (see the duplicate
# `encoder_hidden_states=` keyword argument below, which would be a
# SyntaxError in plain Python). Recover the real file before editing.
def forward(
self,
input_features: torch.Tensor | list[torch.Tensor] | None,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
encoder_outputs: list[torch.Tensor],
) -> torch.Tensor:
# Derives encoder outputs from the raw audio features here (shadowing the
# `encoder_outputs` parameter above — presumably the parameter belongs to
# the diff's removed side; TODO confirm against the upstream commit).
encoder_outputs = self.get_encoder_outputs(input_features)
# At most one encoder output is expected: none (decode-only step) or one.
assert len(encoder_outputs) in (0, 1)
enc_states = encoder_outputs[0] if len(encoder_outputs) == 1 else None
decoder_outputs = self.decoder(
input_ids=input_ids,
positions=positions,
# NOTE(review): the next two lines are the removed (old) and added (new)
# sides of the same diff line; only the `enc_states` version should
# survive in the post-commit source.
encoder_hidden_states=encoder_outputs,
encoder_hidden_states=enc_states,
)
return decoder_outputs
@@ -894,13 +895,15 @@ class WhisperForConditionalGeneration(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
encoder_outputs: list[torch.Tensor] | None = None,
**kwargs,
) -> torch.Tensor:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if encoder_outputs is None:
encoder_outputs = []
decoder_outputs = self.model(
input_features=audio_input["input_features"],
input_ids=input_ids,
positions=positions,
encoder_outputs=encoder_outputs,
)
return decoder_outputs