[Gemma3n] Fix audio batching (#24052)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import Any, Literal, Optional, TypedDict, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
# yapf: disable
|
||||
from torch import nn
|
||||
from transformers import AutoModel, BatchFeature
|
||||
from transformers.models.gemma3n import (Gemma3nAudioConfig,
|
||||
@@ -30,7 +31,6 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, MultiModalDataItems,
|
||||
MultiModalDataParser)
|
||||
# yapf: disable
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
MultiModalPromptUpdates,
|
||||
@@ -62,7 +62,8 @@ class Gemma3nImagePixelInputs(TypedDict):
|
||||
|
||||
|
||||
class Gemma3nAudioInputs(TypedDict):
|
||||
input_features: torch.Tensor
|
||||
input_features: Union[torch.Tensor, list[torch.Tensor]]
|
||||
input_features_padded: torch.Tensor
|
||||
"""Shape: `(batch_size * num_audio, seq_length, num_features)`"""
|
||||
input_features_mask: torch.Tensor
|
||||
"""Shape: `(batch_size * num_audio, seq_length)`"""
|
||||
@@ -188,8 +189,13 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
|
||||
mm_kwargs,
|
||||
tok_kwargs,
|
||||
)
|
||||
|
||||
if 'input_features' in processed_outputs:
|
||||
# Avoid padding since we need the output of each item to be
|
||||
# Padding enables audio_tower to run in batched mode
|
||||
processed_outputs["input_features_padded"] = \
|
||||
processed_outputs["input_features"]
|
||||
|
||||
# Unpad features here since we need the output of each item to be
|
||||
# independent of other items for the cache to work correctly
|
||||
unpadded_features = [
|
||||
f[mask] for f, mask in zip(
|
||||
@@ -206,9 +212,11 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
input_features=MultiModalFieldConfig.batched("audio"),
|
||||
input_features_mask=MultiModalFieldConfig.batched("audio"))
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
input_features=MultiModalFieldConfig.batched("audio"),
|
||||
input_features_padded=MultiModalFieldConfig.batched("audio"),
|
||||
input_features_mask=MultiModalFieldConfig.batched("audio"))
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@@ -516,9 +524,14 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
if input_features_mask is None:
|
||||
return None
|
||||
|
||||
input_features_padded = kwargs.pop("input_features_padded", None)
|
||||
if input_features_padded is None:
|
||||
return None
|
||||
|
||||
return Gemma3nAudioInputs(
|
||||
input_features=input_features,
|
||||
input_features_mask=input_features_mask,
|
||||
input_features_padded=input_features_padded,
|
||||
)
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
@@ -564,7 +577,8 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
audio_input: Gemma3nAudioInputs,
|
||||
) -> list[torch.Tensor]:
|
||||
assert self.audio_tower is not None
|
||||
input_features = audio_input["input_features"].squeeze(1)
|
||||
# Run on padded features to enable batching
|
||||
input_features = audio_input["input_features_padded"].squeeze(1)
|
||||
input_features_mask = audio_input["input_features_mask"].squeeze(1)
|
||||
audio_outputs, audio_mask = self.audio_tower(input_features,
|
||||
~input_features_mask)
|
||||
|
||||
Reference in New Issue
Block a user