[Bugfix] Fix loading Music Flamingo (#35535)
Signed-off-by: Nick Cao <ncao@redhat.com>
This commit is contained in:
@@ -128,12 +128,6 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
|
||||
super().__init__(config)
|
||||
self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
|
||||
# self.layer_norm is already initialized in super().__init__
|
||||
- # Keep a dummy freqs parameter for MusicFlamingo checkpoints.
|
||||
- self.pos_emb = nn.Module()
|
||||
- freqs = torch.empty(getattr(config, "num_mel_bins", 128))
|
||||
- self.pos_emb.register_parameter(
|
||||
- "freqs", nn.Parameter(freqs, requires_grad=False)
|
||||
- )
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -21,6 +21,7 @@ from vllm.multimodal.processing import BaseProcessingInfo
|
||||
from .audioflamingo3 import (
|
||||
AudioFlamingo3DummyInputsBuilder,
|
||||
AudioFlamingo3ForConditionalGeneration,
|
||||
+ AudioFlamingo3MultiModalDataParser,
|
||||
AudioFlamingo3MultiModalProcessor,
|
||||
)
|
||||
|
||||
@@ -53,8 +54,16 @@ class MusicFlamingoProcessingInfo(BaseProcessingInfo):
|
||||
hf_processor = self.get_hf_processor(**kwargs)
|
||||
return hf_processor.feature_extractor
|
||||
|
||||
+ def get_data_parser(self):
|
||||
+ feature_extractor = self.get_feature_extractor()
|
||||

|
||||
+ return AudioFlamingo3MultiModalDataParser(
|
||||
+ target_sr=feature_extractor.sampling_rate,
|
||||
+ expected_hidden_size=self._get_expected_hidden_size(),
|
||||
+ )
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
- return {"audio": None}
|
||||
+ return {"audio": 1}
|
||||
|
||||
|
||||
class MusicFlamingoDummyInputsBuilder(AudioFlamingo3DummyInputsBuilder):
|
||||
|
||||
Reference in New Issue
Block a user