[Model][Multimodal] Add explicit MusicFlamingo adapter (#32696)
Signed-off-by: WangHaoyuuu <mailwhaoyu@gmail.com>
This commit is contained in:
@@ -128,6 +128,12 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
|
||||
super().__init__(config)
|
||||
self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
|
||||
# self.layer_norm is already initialized in super().__init__
|
||||
# Keep a dummy freqs parameter for MusicFlamingo checkpoints.
|
||||
self.pos_emb = nn.Module()
|
||||
freqs = torch.empty(getattr(config, "num_mel_bins", 128))
|
||||
self.pos_emb.register_parameter(
|
||||
"freqs", nn.Parameter(freqs, requires_grad=False)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -146,7 +152,8 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
|
||||
).to(hidden_states.dtype)
|
||||
|
||||
for layer in self.layers:
|
||||
layer_outputs = layer(hidden_states, attention_mask)
|
||||
# Qwen2AudioEncoderLayer expects layer_head_mask as third arg.
|
||||
layer_outputs = layer(hidden_states, attention_mask, None)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
# AvgPool (time/2) + LayerNorm
|
||||
|
||||
70
vllm/model_executor/models/musicflamingo.py
Normal file
70
vllm/model_executor/models/musicflamingo.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""MusicFlamingo model adapter.
|
||||
|
||||
MusicFlamingo shares the AudioFlamingo3 architecture, so we reuse the same
|
||||
implementation and multimodal processor, while accepting MusicFlamingo config
|
||||
and processor classes when available.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
|
||||
from transformers.models.audioflamingo3 import (
|
||||
AudioFlamingo3Config,
|
||||
AudioFlamingo3Processor,
|
||||
)
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.processing import BaseProcessingInfo
|
||||
|
||||
from .audioflamingo3 import (
|
||||
AudioFlamingo3DummyInputsBuilder,
|
||||
AudioFlamingo3ForConditionalGeneration,
|
||||
AudioFlamingo3MultiModalProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
# Optional dependency: use MusicFlamingo classes when transformers provides them.
|
||||
from transformers.models.musicflamingo import (
|
||||
MusicFlamingoConfig,
|
||||
MusicFlamingoProcessor,
|
||||
)
|
||||
except Exception: # pragma: no cover - optional dependency
|
||||
MusicFlamingoConfig = None
|
||||
MusicFlamingoProcessor = None
|
||||
|
||||
|
||||
class MusicFlamingoProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_config(self):
|
||||
if MusicFlamingoConfig is None:
|
||||
return self.ctx.get_hf_config(AudioFlamingo3Config)
|
||||
return self.ctx.get_hf_config((MusicFlamingoConfig, AudioFlamingo3Config))
|
||||
|
||||
def get_hf_processor(self, **kwargs: object):
|
||||
if MusicFlamingoProcessor is None:
|
||||
return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs)
|
||||
# Tuple triggers AutoProcessor path and accepts either processor class.
|
||||
return self.ctx.get_hf_processor(
|
||||
(MusicFlamingoProcessor, AudioFlamingo3Processor), **kwargs
|
||||
)
|
||||
|
||||
def get_feature_extractor(self, **kwargs: object):
|
||||
hf_processor = self.get_hf_processor(**kwargs)
|
||||
return hf_processor.feature_extractor
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"audio": None}
|
||||
|
||||
|
||||
class MusicFlamingoDummyInputsBuilder(AudioFlamingo3DummyInputsBuilder):
|
||||
pass
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
AudioFlamingo3MultiModalProcessor,
|
||||
info=MusicFlamingoProcessingInfo,
|
||||
dummy_inputs=MusicFlamingoDummyInputsBuilder,
|
||||
)
|
||||
class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
|
||||
"""MusicFlamingo model for conditional generation."""
|
||||
@@ -286,6 +286,10 @@ _MULTIMODAL_MODELS = {
|
||||
"audioflamingo3",
|
||||
"AudioFlamingo3ForConditionalGeneration",
|
||||
),
|
||||
"MusicFlamingoForConditionalGeneration": (
|
||||
"musicflamingo",
|
||||
"MusicFlamingoForConditionalGeneration",
|
||||
),
|
||||
"AyaVisionForConditionalGeneration": (
|
||||
"aya_vision",
|
||||
"AyaVisionForConditionalGeneration",
|
||||
|
||||
Reference in New Issue
Block a user