[Bugfix] Fix loading Music Flamingo (#35535)

Signed-off-by: Nick Cao <ncao@redhat.com>
This commit is contained in:
Nick Cao
2026-03-17 01:24:40 -04:00
committed by GitHub
parent 17c1bdf371
commit 20b14095a4
2 changed files with 10 additions and 7 deletions

View File

@@ -128,12 +128,6 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
super().__init__(config)
self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
# self.layer_norm is already initialized in super().__init__
# Keep a dummy freqs parameter for MusicFlamingo checkpoints.
self.pos_emb = nn.Module()
freqs = torch.empty(getattr(config, "num_mel_bins", 128))
self.pos_emb.register_parameter(
"freqs", nn.Parameter(freqs, requires_grad=False)
)
def forward(
self,

View File

@@ -21,6 +21,7 @@ from vllm.multimodal.processing import BaseProcessingInfo
from .audioflamingo3 import (
AudioFlamingo3DummyInputsBuilder,
AudioFlamingo3ForConditionalGeneration,
AudioFlamingo3MultiModalDataParser,
AudioFlamingo3MultiModalProcessor,
)
@@ -53,8 +54,16 @@ class MusicFlamingoProcessingInfo(BaseProcessingInfo):
hf_processor = self.get_hf_processor(**kwargs)
return hf_processor.feature_extractor
def get_data_parser(self):
feature_extractor = self.get_feature_extractor()
return AudioFlamingo3MultiModalDataParser(
target_sr=feature_extractor.sampling_rate,
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
return {"audio": 1}
class MusicFlamingoDummyInputsBuilder(AudioFlamingo3DummyInputsBuilder):