[Gemma3n] Fix audio batching (#24052)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi
2025-09-02 16:23:35 +02:00
committed by GitHub
parent 8bd5844989
commit 0a74e9d0f2
2 changed files with 63 additions and 7 deletions


@@ -5,6 +5,7 @@ from typing import Any, Literal, Optional, TypedDict, Union, cast
 import numpy as np
 import torch
+# yapf: disable
 from torch import nn
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (Gemma3nAudioConfig,
@@ -30,7 +31,6 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
-# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo,
                                         MultiModalPromptUpdates,
@@ -62,7 +62,8 @@ class Gemma3nImagePixelInputs(TypedDict):
 class Gemma3nAudioInputs(TypedDict):
-    input_features: torch.Tensor
+    input_features: Union[torch.Tensor, list[torch.Tensor]]
+    input_features_padded: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
     input_features_mask: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length)`"""
@@ -188,8 +189,13 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
             mm_kwargs,
             tok_kwargs,
         )
         if 'input_features' in processed_outputs:
-            # Avoid padding since we need the output of each item to be
+            # Padding enables audio_tower to run in batched mode
+            processed_outputs["input_features_padded"] = \
+                processed_outputs["input_features"]
+            # Unpad features here since we need the output of each item to be
             # independent of other items for the cache to work correctly
             unpadded_features = [
                 f[mask] for f, mask in zip(
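
A rough sketch of what the unpad step above does, reusing the illustrative shapes from the previous example (again, not the processor's actual code): boolean indexing with each item's mask trims the padded frames away, so the per-item features that get cached no longer depend on how long the other items in the batch were.

import torch

padded = torch.randn(2, 180, 128)    # (num_audio, seq_length, num_features)
mask = torch.zeros(2, 180, dtype=torch.bool)
mask[0, :180] = True                 # first item uses all 180 frames
mask[1, :95] = True                  # second item has 95 real frames

# Indexing a (seq_length, num_features) tensor with a (seq_length,) boolean
# mask keeps only the real frames, giving (true_length, num_features).
unpadded = [f[m] for f, m in zip(padded, mask)]
print([u.shape for u in unpadded])  # [torch.Size([180, 128]), torch.Size([95, 128])]
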
@@ -206,9 +212,11 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(pixel_values=MultiModalFieldConfig.batched("image"),
-                    input_features=MultiModalFieldConfig.batched("audio"),
-                    input_features_mask=MultiModalFieldConfig.batched("audio"))
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            input_features=MultiModalFieldConfig.batched("audio"),
+            input_features_padded=MultiModalFieldConfig.batched("audio"),
+            input_features_mask=MultiModalFieldConfig.batched("audio"))
 
     def _get_prompt_updates(
         self,
@@ -516,9 +524,14 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         if input_features_mask is None:
             return None
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
+            return None
         return Gemma3nAudioInputs(
             input_features=input_features,
             input_features_mask=input_features_mask,
+            input_features_padded=input_features_padded,
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -564,7 +577,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         audio_input: Gemma3nAudioInputs,
     ) -> list[torch.Tensor]:
         assert self.audio_tower is not None
-        input_features = audio_input["input_features"].squeeze(1)
+        # Run on padded features to enable batching
+        input_features = audio_input["input_features_padded"].squeeze(1)
         input_features_mask = audio_input["input_features_mask"].squeeze(1)
         audio_outputs, audio_mask = self.audio_tower(input_features,
                                                      ~input_features_mask)
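
To illustrate the batched call, here is a toy stand-in for the encoder (ToyAudioTower is invented for this sketch and is not the Gemma3n audio tower): the padded features go through a single forward pass as one batch, and the inverted mask tells the encoder which frames are padding so their outputs can be zeroed or ignored afterwards.

import torch
from torch import nn

class ToyAudioTower(nn.Module):
    """Stand-in encoder: one linear layer over the feature dimension."""

    def __init__(self, num_features: int = 128, hidden: int = 64):
        super().__init__()
        self.proj = nn.Linear(num_features, hidden)

    def forward(self, features: torch.Tensor, padding_mask: torch.Tensor):
        # features: (batch, seq_length, num_features)
        # padding_mask: (batch, seq_length), True where the frame is padding
        out = self.proj(features)
        out = out.masked_fill(padding_mask.unsqueeze(-1), 0.0)
        return out, padding_mask

tower = ToyAudioTower()
features = torch.randn(2, 180, 128)
mask = torch.zeros(2, 180, dtype=torch.bool)
mask[0, :180] = True
mask[1, :95] = True

# One batched forward over the padded inputs; ~mask marks the padded frames.
audio_outputs, audio_pad_mask = tower(features, ~mask)
print(audio_outputs.shape)  # torch.Size([2, 180, 64])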