diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 9c38887bb..444d238c5 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -657,7 +657,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
|--------------|--------|--------|-------------------|----------------------|---------------------------|
| `AriaForConditionalGeneration` | Aria | T + I+ | `rhymes-ai/Aria` | | |
-| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A+ | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-hf` | ✅︎ | ✅︎ |
+| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A+ | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-2601-hf` | ✅︎ | ✅︎ |
| `AyaVisionForConditionalGeneration` | Aya Vision | T + I+ | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ |
| `BagelForConditionalGeneration` | BAGEL | T + I+ | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ |
| `BeeForConditionalGeneration` | Bee-8B | T + IE+ | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index a8f70c5b9..4bf4b4e1d 100755
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -70,6 +70,34 @@ def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
)
+# MusicFlamingo
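+# Usage sketch (assumes the script's existing --model-type selector; adjust if the
+# flag differs):
+#   python examples/offline_inference/audio_language.py --model-type musicflamingo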
+def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
+ model_name = "nvidia/music-flamingo-2601-hf"
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=4096,
+ max_num_seqs=2,
+ limit_mm_per_prompt={"audio": audio_count},
+ enforce_eager=True,
+ )
+
+ # MusicFlamingo uses the <sound> token as the audio placeholder.
+ audio_placeholder = "<sound>" * audio_count
+
+ prompt = (
+ "<|im_start|>system\n"
+ "You are a helpful assistant.<|im_end|>\n"
+ "<|im_start|>user\n"
+ f"{audio_placeholder}{question}<|im_end|>\n"
+ "<|im_start|>assistant\n"
+ )
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompt=prompt,
+ )
+
+
# Gemma3N
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
model_name = "google/gemma-3n-E2B-it"
@@ -452,6 +480,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = {
"audioflamingo3": run_audioflamingo3,
+ "musicflamingo": run_musicflamingo,
"gemma3n": run_gemma3n,
"glmasr": run_glmasr,
"funaudiochat": run_funaudiochat,
diff --git a/tests/models/registry.py b/tests/models/registry.py
index cc031cc74..82f0254b6 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -657,6 +657,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"AudioFlamingo3ForConditionalGeneration": _HfExamplesInfo(
"nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0"
),
+ "MusicFlamingoForConditionalGeneration": _HfExamplesInfo(
+ "nvidia/music-flamingo-2601-hf", min_transformers_version="5.0.0.dev"
+ ),
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"),
"BagelForConditionalGeneration": _HfExamplesInfo("ByteDance-Seed/BAGEL-7B-MoT"),
"BeeForConditionalGeneration": _HfExamplesInfo(
diff --git a/vllm/model_executor/models/audioflamingo3.py b/vllm/model_executor/models/audioflamingo3.py
index 3f1661abe..fa4f93b86 100644
--- a/vllm/model_executor/models/audioflamingo3.py
+++ b/vllm/model_executor/models/audioflamingo3.py
@@ -128,6 +128,12 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
super().__init__(config)
self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
# self.layer_norm is already initialized in super().__init__
+ # Keep a dummy freqs parameter for MusicFlamingo checkpoints.
+ self.pos_emb = nn.Module()
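+ # Registering the parameter on a bare nn.Module keeps the checkpoint key name
+ # "pos_emb.freqs", so MusicFlamingo weights load cleanly; the tensor itself is
+ # never used in forward().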
+ freqs = torch.empty(getattr(config, "num_mel_bins", 128))
+ self.pos_emb.register_parameter(
+ "freqs", nn.Parameter(freqs, requires_grad=False)
+ )
def forward(
self,
@@ -146,7 +152,8 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
).to(hidden_states.dtype)
for layer in self.layers:
- layer_outputs = layer(hidden_states, attention_mask)
+ # Qwen2AudioEncoderLayer expects layer_head_mask as third arg.
+ layer_outputs = layer(hidden_states, attention_mask, None)
hidden_states = layer_outputs[0]
# AvgPool (time/2) + LayerNorm
diff --git a/vllm/model_executor/models/musicflamingo.py b/vllm/model_executor/models/musicflamingo.py
new file mode 100644
index 000000000..161de4e24
--- /dev/null
+++ b/vllm/model_executor/models/musicflamingo.py
@@ -0,0 +1,70 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""MusicFlamingo model adapter.
+
+MusicFlamingo shares the AudioFlamingo3 architecture, so we reuse the same
+implementation and multimodal processor, while accepting MusicFlamingo config
+and processor classes when available.
+"""
+
+from collections.abc import Mapping
+
+from transformers.models.audioflamingo3 import (
+ AudioFlamingo3Config,
+ AudioFlamingo3Processor,
+)
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.processing import BaseProcessingInfo
+
+from .audioflamingo3 import (
+ AudioFlamingo3DummyInputsBuilder,
+ AudioFlamingo3ForConditionalGeneration,
+ AudioFlamingo3MultiModalProcessor,
+)
+
+try:
+ # Optional dependency: use MusicFlamingo classes when transformers provides them.
+ from transformers.models.musicflamingo import (
+ MusicFlamingoConfig,
+ MusicFlamingoProcessor,
+ )
+except Exception: # pragma: no cover - optional dependency
+ MusicFlamingoConfig = None
+ MusicFlamingoProcessor = None
+
+
+class MusicFlamingoProcessingInfo(BaseProcessingInfo):
+ def get_hf_config(self):
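+ # Prefer the MusicFlamingo config when transformers provides it; otherwise
+ # fall back to the shared AudioFlamingo3 config.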
+ if MusicFlamingoConfig is None:
+ return self.ctx.get_hf_config(AudioFlamingo3Config)
+ return self.ctx.get_hf_config((MusicFlamingoConfig, AudioFlamingo3Config))
+
+ def get_hf_processor(self, **kwargs: object):
+ if MusicFlamingoProcessor is None:
+ return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs)
+ # Tuple triggers AutoProcessor path and accepts either processor class.
+ return self.ctx.get_hf_processor(
+ (MusicFlamingoProcessor, AudioFlamingo3Processor), **kwargs
+ )
+
+ def get_feature_extractor(self, **kwargs: object):
+ hf_processor = self.get_hf_processor(**kwargs)
+ return hf_processor.feature_extractor
+
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
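+ # None means no fixed limit on the number of audio clips per prompt.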
+ return {"audio": None}
+
+
+class MusicFlamingoDummyInputsBuilder(AudioFlamingo3DummyInputsBuilder):
+ pass
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ AudioFlamingo3MultiModalProcessor,
+ info=MusicFlamingoProcessingInfo,
+ dummy_inputs=MusicFlamingoDummyInputsBuilder,
+)
+class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
+ """MusicFlamingo model for conditional generation."""
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 8b0085205..4373e3a42 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -286,6 +286,10 @@ _MULTIMODAL_MODELS = {
"audioflamingo3",
"AudioFlamingo3ForConditionalGeneration",
),
+ "MusicFlamingoForConditionalGeneration": (
+ "musicflamingo",
+ "MusicFlamingoForConditionalGeneration",
+ ),
"AyaVisionForConditionalGeneration": (
"aya_vision",
"AyaVisionForConditionalGeneration",