Fix Gemma3n audio encoder for Transformers v5 (#33673)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -621,10 +621,15 @@ class Gemma3nForConditionalGeneration(
|
|||||||
# Run on padded features to enable batching
|
# Run on padded features to enable batching
|
||||||
input_features = audio_input["input_features_padded"].squeeze(1)
|
input_features = audio_input["input_features_padded"].squeeze(1)
|
||||||
input_features_mask = audio_input["input_features_mask"].squeeze(1)
|
input_features_mask = audio_input["input_features_mask"].squeeze(1)
|
||||||
audio_outputs, audio_mask = self.audio_tower(
|
audio_outputs = self.audio_tower(input_features, ~input_features_mask)
|
||||||
input_features, ~input_features_mask
|
if isinstance(audio_outputs, tuple):
|
||||||
)
|
# Transformers v4
|
||||||
audio_features = self.embed_audio(inputs_embeds=audio_outputs)
|
audio_encodings, audio_mask = audio_outputs
|
||||||
|
else:
|
||||||
|
# Transformers v5
|
||||||
|
audio_encodings = audio_outputs.last_hidden_state
|
||||||
|
audio_mask = audio_outputs.audio_mel_mask
|
||||||
|
audio_features = self.embed_audio(inputs_embeds=audio_encodings)
|
||||||
|
|
||||||
# The Gemma3nProcessor expects all audio will be 30s in length and
|
# The Gemma3nProcessor expects all audio will be 30s in length and
|
||||||
# inserts 188 audio soft tokens into the text to account for this.
|
# inserts 188 audio soft tokens into the text to account for this.
|
||||||
|
|||||||
Reference in New Issue
Block a user