diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 07e7da344..09051a37f 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -535,7 +535,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) |
 | ------------ | ------ | ------ | ----------------- | -------------------- | ------------------------- |
 | `AriaForConditionalGeneration` | Aria | T + I+ | `rhymes-ai/Aria` | | |
-| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-2601-hf` | ✅︎ | ✅︎ |
+| `AudioFlamingo3ForConditionalGeneration` | AudioFlamingo3 | T + A | `nvidia/audio-flamingo-3-hf`, `nvidia/music-flamingo-hf` | ✅︎ | ✅︎ |
 | `AyaVisionForConditionalGeneration` | Aya Vision | T + I+ | `CohereLabs/aya-vision-8b`, `CohereLabs/aya-vision-32b`, etc. | | ✅︎ |
 | `BagelForConditionalGeneration` | BAGEL | T + I+ | `ByteDance-Seed/BAGEL-7B-MoT` | ✅︎ | ✅︎ |
 | `BeeForConditionalGeneration` | Bee-8B | T + IE+ | `Open-Bee/Bee-8B-RL`, `Open-Bee/Bee-8B-SFT` | | ✅︎ |
@@ -586,6 +586,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ |
 | `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ |
 | `Molmo2ForConditionalGeneration` | Molmo2 | T + I+ / V | `allenai/Molmo2-4B`, `allenai/Molmo2-8B`, `allenai/Molmo2-O-7B` | ✅︎ | ✅︎ |
+| `MusicFlamingoForConditionalGeneration` | MusicFlamingo | T + A | `nvidia/music-flamingo-2601-hf`, `nvidia/music-flamingo-think-2601-hf` | ✅︎ | ✅︎ |
 | `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ |
 | `OpenCUAForConditionalGeneration` | OpenCUA-7B | T + IE+ | `xlangai/OpenCUA-7B` | ✅︎ | ✅︎ |
 | `OpenPanguVLForConditionalGeneration` | openpangu-VL | T + IE+ + VE+ | `FreedomIntelligence/openPangu-VL-7B` | ✅︎ | ✅︎ |
diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py
index 780ddb90e..b7e49d2c9 100755
--- a/examples/offline_inference/audio_language.py
+++ b/examples/offline_inference/audio_language.py
@@ -104,12 +104,22 @@ def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
         enforce_eager=True,
     )
 
-    # MusicFlamingo uses <sound> token for audio
+    # MusicFlamingo prompt placeholders use <sound>; vLLM's MusicFlamingo
+    # multimodal processor expands each one into <|sound_bos|> + audio tokens +
+    # <|sound_eos|> based on extracted audio feature lengths.
     audio_placeholder = "<sound>" * audio_count
+    system_prompt = (
+        "You are Music Flamingo, a multimodal assistant for language and music. "
+        "On each turn you receive an audio clip which contains music and optional "
+        "text; you will receive at least one or both. Use your world knowledge and "
+        "reasoning to help the user with any task. Interpret the entirety of the "
+        "content of any input music, regardless of whether the user calls it audio, "
+        "music, or sound."
+ ) prompt = ( "<|im_start|>system\n" - "You are a helpful assistant.<|im_end|>\n" + f"{system_prompt}<|im_end|>\n" "<|im_start|>user\n" f"{audio_placeholder}{question}<|im_end|>\n" "<|im_start|>assistant\n" diff --git a/tests/models/fixtures/audioflamingo3/expected_results_single.json b/tests/models/fixtures/audioflamingo3/expected_results_single.json index be9233467..1e54d3006 100644 --- a/tests/models/fixtures/audioflamingo3/expected_results_single.json +++ b/tests/models/fixtures/audioflamingo3/expected_results_single.json @@ -1 +1 @@ -{"transcriptions": ["The content of the input audio is 'you can ask why over and over and over again forever even if one day we explain every physical interaction and scientific law and hope and dream and regret with a single elegant equation'."], "token_ids": [[785, 2213, 315, 279, 1946, 7699, 374, 364, 9330, 646, 2548, 3170, 916, 323, 916, 323, 916, 1549, 15683, 1496, 421, 825, 1899, 582, 10339, 1449, 6961, 16230, 323, 12344, 2329, 323, 3900, 323, 7904, 323, 22231, 448, 264, 3175, 25777, 23606, 4427, 151645]]} \ No newline at end of file +{"transcriptions": ["There is no clear relationship between the barking and the music, as they seem to be independent of each other."], "token_ids": [[3862, 374, 902, 2797, 5025, 1948, 279, 293, 33452, 323, 279, 4627, 11, 438, 807, 2803, 311, 387, 9489, 315, 1817, 1008, 13, 151645]]} diff --git a/tests/models/fixtures/musicflamingo/expected_results_batched.json b/tests/models/fixtures/musicflamingo/expected_results_batched.json new file mode 100644 index 000000000..797d9dafc --- /dev/null +++ b/tests/models/fixtures/musicflamingo/expected_results_batched.json @@ -0,0 +1 @@ +{"transcriptions": ["This track is an energetic Eurodance / Dance‑Pop anthem that blends the bright, melodic sensibilities of mainstream pop with the driving, club‑ready pulse of classic Eurodance. The duration of the piece is ", "**Verse 1**\nMidnight cravings in bloom, lights flicker in the room, pepperoni dreams arise, pizza party on your skies\n\n**Verse 2**\nCheese melts on the crust, in flavor we trust, boxes stacked to the"], "token_ids": [[1986, 3754, 374, 458, 44855, 19461, 98875, 378, 107, 14, 378, 107, 35, 681, 55964, 11598, 55564, 429, 57843, 279, 9906, 11, 10581, 52760, 6097, 13450, 315, 20729, 2420, 448, 279, 9842, 11, 6335, 55964, 2307, 27235, 315, 11416, 19461, 98875, 13, 220, 576, 8090, 315, 279, 6573, 374, 220], [334, 68043, 220, 16, 1019, 33648, 9287, 88828, 304, 51454, 11, 12711, 28347, 261, 304, 279, 3054, 11, 24353, 20783, 18707, 30789, 11, 22502, 4614, 389, 697, 49293, 271, 334, 68043, 220, 17, 1019, 26843, 2367, 98091, 389, 279, 39612, 11, 304, 17172, 582, 6950, 11, 14697, 41315, 311, 279]]} diff --git a/tests/models/fixtures/musicflamingo/expected_results_single.json b/tests/models/fixtures/musicflamingo/expected_results_single.json new file mode 100644 index 000000000..99d4fb370 --- /dev/null +++ b/tests/models/fixtures/musicflamingo/expected_results_single.json @@ -0,0 +1 @@ +{"transcriptions": ["This track is an energetic Eurodance / Dance‑Pop anthem that blends the bright, melodic sensibilities of mainstream pop with the driving, club‑ready pulse of classic Eurodance. 
The duration of the piece is "], "token_ids": [[1986, 3754, 374, 458, 44855, 19461, 98875, 378, 107, 14, 378, 107, 35, 681, 55964, 11598, 55564, 429, 57843, 279, 9906, 11, 10581, 52760, 6097, 13450, 315, 20729, 2420, 448, 279, 9842, 11, 6335, 55964, 2307, 27235, 315, 11416, 19461, 98875, 13, 220, 576, 8090, 315, 279, 6573, 374, 220]]} diff --git a/tests/models/multimodal/generation/test_audioflamingo3.py b/tests/models/multimodal/generation/test_audioflamingo3.py index d14291a62..187300e34 100644 --- a/tests/models/multimodal/generation/test_audioflamingo3.py +++ b/tests/models/multimodal/generation/test_audioflamingo3.py @@ -26,6 +26,54 @@ from tests.models.registry import HF_EXAMPLE_MODELS from vllm import LLM, SamplingParams MODEL_NAME = "nvidia/audio-flamingo-3-hf" +SINGLE_CONVERSATION = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What is surprising about the relationship between " + "the barking and the music?", + }, + { + "type": "audio_url", + "audio_url": { + "url": "https://huggingface.co/datasets/nvidia/AudioSkills/" + "resolve/main/assets/" + "dogs_barking_in_sync_with_the_music.wav", + }, + }, + ], + } +] +BATCHED_CONVERSATIONS = [ + SINGLE_CONVERSATION, + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Why is the philosopher's name mentioned in the " + "lyrics? (A) To express a sense of nostalgia " + "(B) To indicate that language cannot express clearly, " + "satirizing the inversion of black and white in the world " + "(C) To add depth and complexity to the lyrics " + "(D) To showcase the wisdom and influence of the " + "philosopher", + }, + { + "type": "audio_url", + "audio_url": { + "url": "https://huggingface.co/datasets/nvidia/" + "AudioSkills/resolve/main/assets/" + "Ch6Ae9DT6Ko_00-04-03_00-04-31.wav", + }, + }, + ], + } + ], +] def get_fixture_path(filename): @@ -34,21 +82,29 @@ def get_fixture_path(filename): ) +def assert_output_matches(output, expected_text, expected_token_ids): + generated = output.outputs[0] + assert generated.text.strip() == expected_text + actual_token_ids = list(generated.token_ids) + assert ( + actual_token_ids == expected_token_ids + or actual_token_ids == expected_token_ids[:-1] + or actual_token_ids[:-1] == expected_token_ids + ) + + @pytest.fixture(scope="module") def llm(): - # Check if the model is supported by the current transformers version model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration") model_info.check_transformers_version(on_fail="skip") try: - llm = LLM( + return LLM( model=MODEL_NAME, - trust_remote_code=True, dtype="bfloat16", enforce_eager=True, limit_mm_per_prompt={"audio": 1}, ) - return llm except Exception as e: pytest.skip(f"Failed to load model {MODEL_NAME}: {e}") @@ -61,29 +117,17 @@ def test_single_generation(llm): with open(fixture_path) as f: expected = json.load(f) - audio_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Why_do_we_ask_questions_converted.wav" - - messages = [ - { - "role": "user", - "content": [ - {"type": "audio_url", "audio_url": {"url": audio_url}}, - {"type": "text", "text": "Transcribe the input speech."}, - ], - } - ] - sampling_params = SamplingParams(temperature=0.0, max_tokens=128) outputs = llm.chat( - messages=messages, + messages=SINGLE_CONVERSATION, sampling_params=sampling_params, ) - generated_text = outputs[0].outputs[0].text.strip() - - expected_text = expected["transcriptions"][0] - - assert expected_text in generated_text or generated_text in expected_text + 
assert_output_matches( + outputs[0], + expected["transcriptions"][0], + expected["token_ids"][0], + ) def test_batched_generation(llm): @@ -94,49 +138,34 @@ def test_batched_generation(llm): with open(fixture_path) as f: expected = json.load(f) - items = [ - { - "audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/dogs_barking_in_sync_with_the_music.wav", - "question": "What is surprising about the relationship " - "between the barking and the music?", - "expected_idx": 0, - }, - { - "audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Ch6Ae9DT6Ko_00-04-03_00-04-31.wav", - "question": ( - "Why is the philosopher's name mentioned in the lyrics? " - "(A) To express a sense of nostalgia " - "(B) To indicate that language cannot express clearly, " - "satirizing the inversion of black and white in the world " - "(C) To add depth and complexity to the lyrics " - "(D) To showcase the wisdom and influence of the philosopher" - ), - "expected_idx": 1, - }, - ] - - conversations = [] - for item in items: - messages = [ - { - "role": "user", - "content": [ - {"type": "audio_url", "audio_url": {"url": item["audio_url"]}}, - {"type": "text", "text": item["question"]}, - ], - } - ] - conversations.append(messages) - sampling_params = SamplingParams(temperature=0.0, max_tokens=128) outputs = llm.chat( - messages=conversations, + messages=BATCHED_CONVERSATIONS, sampling_params=sampling_params, ) for i, output in enumerate(outputs): - generated_text = output.outputs[0].text.strip() - expected_text = expected["transcriptions"][i] + assert_output_matches( + output, + expected["transcriptions"][i], + expected["token_ids"][i], + ) - assert expected_text in generated_text or generated_text in expected_text + +def test_single_and_batched_generation_match(llm): + sampling_params = SamplingParams(temperature=0.0, max_tokens=128) + + single_output = llm.chat( + messages=SINGLE_CONVERSATION, + sampling_params=sampling_params, + )[0] + batched_output = llm.chat( + messages=BATCHED_CONVERSATIONS, + sampling_params=sampling_params, + )[0] + + assert single_output.outputs[0].text == batched_output.outputs[0].text + assert list(single_output.outputs[0].token_ids) == list( + batched_output.outputs[0].token_ids + ) diff --git a/tests/models/multimodal/generation/test_musicflamingo.py b/tests/models/multimodal/generation/test_musicflamingo.py new file mode 100644 index 000000000..c87c46a7c --- /dev/null +++ b/tests/models/multimodal/generation/test_musicflamingo.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import os + +import pytest + +from tests.models.registry import HF_EXAMPLE_MODELS +from vllm import LLM, SamplingParams + +MODEL_NAME = "nvidia/music-flamingo-2601-hf" +SINGLE_CONVERSATION = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Describe this track in full detail - tell me the " + "genre, tempo, and key, then dive into the instruments, " + "production style, and overall mood it creates.", + }, + { + "type": "audio_url", + "audio_url": { + "url": "https://huggingface.co/datasets/nvidia/AudioSkills/" + "resolve/main/assets/song_1.mp3", + }, + }, + ], + } +] +BATCHED_CONVERSATIONS = [ + SINGLE_CONVERSATION, + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Generate a structured lyric sheet from the input music.", + }, + { + "type": "audio_url", + "audio_url": { + "url": 
"https://huggingface.co/datasets/nvidia/" + "AudioSkills/resolve/main/assets/song_2.mp3", + }, + }, + ], + } + ], +] + + +def get_fixture_path(filename): + return os.path.join( + os.path.dirname(__file__), "../../fixtures/musicflamingo", filename + ) + + +def assert_output_matches(output, expected_text, expected_token_ids): + generated = output.outputs[0] + assert generated.text == expected_text + actual_token_ids = list(generated.token_ids) + assert ( + actual_token_ids == expected_token_ids + or actual_token_ids == expected_token_ids[:-1] + or actual_token_ids[:-1] == expected_token_ids + ) + + +@pytest.fixture(scope="module") +def llm(): + model_info = HF_EXAMPLE_MODELS.get_hf_info("MusicFlamingoForConditionalGeneration") + model_info.check_transformers_version(on_fail="skip") + + try: + return LLM( + model=MODEL_NAME, + dtype="bfloat16", + enforce_eager=True, + max_model_len=8192, + limit_mm_per_prompt={"audio": 1}, + ) + except Exception as e: + pytest.skip(f"Failed to load model {MODEL_NAME}: {e}") + + +def test_single_generation(llm): + fixture_path = get_fixture_path("expected_results_single.json") + if not os.path.exists(fixture_path): + pytest.skip(f"Fixture not found: {fixture_path}") + + with open(fixture_path) as f: + expected = json.load(f) + + outputs = llm.chat( + messages=SINGLE_CONVERSATION, + sampling_params=SamplingParams(temperature=0.0, max_tokens=50), + ) + + assert_output_matches( + outputs[0], + expected["transcriptions"][0], + expected["token_ids"][0], + ) + + +def test_batched_generation(llm): + fixture_path = get_fixture_path("expected_results_batched.json") + if not os.path.exists(fixture_path): + pytest.skip(f"Fixture not found: {fixture_path}") + + with open(fixture_path) as f: + expected = json.load(f) + + outputs = llm.chat( + messages=BATCHED_CONVERSATIONS, + sampling_params=SamplingParams(temperature=0.0, max_tokens=50), + ) + + for i, output in enumerate(outputs): + assert_output_matches( + output, + expected["transcriptions"][i], + expected["token_ids"][i], + ) + + +def test_single_and_batched_generation_match(llm): + sampling_params = SamplingParams(temperature=0.0, max_tokens=50) + + single_output = llm.chat( + messages=SINGLE_CONVERSATION, + sampling_params=sampling_params, + )[0] + batched_output = llm.chat( + messages=BATCHED_CONVERSATIONS, + sampling_params=sampling_params, + )[0] + + assert single_output.outputs[0].text == batched_output.outputs[0].text + assert list(single_output.outputs[0].token_ids) == list( + batched_output.outputs[0].token_ids + ) diff --git a/tests/models/multimodal/processing/test_audioflamingo3.py b/tests/models/multimodal/processing/test_audioflamingo3.py index 428fd9c6e..24311e521 100644 --- a/tests/models/multimodal/processing/test_audioflamingo3.py +++ b/tests/models/multimodal/processing/test_audioflamingo3.py @@ -40,6 +40,7 @@ class MockAudioFlamingo3Processor: def __init__(self): self.audio_token = "" self.audio_token_id = 12345 + self.max_audio_len = 60 self.feature_extractor = MockFeatureExtractor() def __call__(self, text=None, audios=None, **kwargs): @@ -65,7 +66,6 @@ def mock_ctx(): @pytest.fixture(autouse=True) def check_transformers_version(): - # Check if the model is supported by the current transformers version model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration") model_info.check_transformers_version(on_fail="skip") @@ -84,7 +84,7 @@ def test_audio_chunk_counting(mock_ctx): sr = 16000 audio_1 = np.zeros(30 * sr) - audio_2 = np.zeros(45 * sr) + audio_2 = np.zeros(75 * sr) 
mm_data = {"audio": [audio_1, audio_2]} prompt = "<|user|>Listen.<|end|>" @@ -121,5 +121,107 @@ def test_dummy_data_generation(mock_ctx): assert "audio" in dummy_data assert len(dummy_data["audio"]) == 2 - expected_len = 600 * 16000 + expected_len = 60 * 16000 assert len(dummy_data["audio"][0]) == expected_len + + +def test_audio_token_count_matches_hf_processor_math(): + from vllm.model_executor.models.audioflamingo3 import ( + _count_audio_tokens_from_mask, + ) + + feature_attention_mask = torch.zeros((3, 3000), dtype=torch.long) + feature_attention_mask[0, :2999] = 1 + feature_attention_mask[1, :2999] = 1 + feature_attention_mask[2, :1500] = 1 + chunk_counts = torch.tensor([2, 1], dtype=torch.long) + + assert ( + _count_audio_tokens_from_mask(feature_attention_mask, chunk_counts, 0) == 1499 + ) + assert _count_audio_tokens_from_mask(feature_attention_mask, chunk_counts, 1) == 375 + + +def test_audio_feature_pipeline_matches_hf_small_config(): + from transformers.models.audioflamingo3 import ( + modeling_audioflamingo3 as hf_audioflamingo3_modeling, + ) + from transformers.models.audioflamingo3.configuration_audioflamingo3 import ( + AudioFlamingo3Config, + ) + + from vllm.model_executor.models.audioflamingo3 import ( + AudioFlamingo3Encoder, + AudioFlamingo3MultiModalProjector, + _build_audio_encoder_attention_mask, + _flatten_valid_audio_embeddings, + ) + + text_config = { + "model_type": "qwen2", + "intermediate_size": 64, + "initializer_range": 0.02, + "hidden_size": 32, + "max_position_embeddings": 1024, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "vocab_size": 128, + "pad_token_id": 1, + "use_mrope": False, + } + audio_config = { + "hidden_size": 16, + "num_attention_heads": 4, + "intermediate_size": 32, + "num_hidden_layers": 2, + "num_mel_bins": 80, + "max_source_positions": 1500, + "dropout": 0.0, + "attention_dropout": 0.0, + "activation_dropout": 0.0, + "encoder_layerdrop": 0.0, + } + + torch.manual_seed(0) + config = AudioFlamingo3Config( + text_config=text_config, + audio_config=audio_config, + audio_token_id=0, + ) + hf_model = hf_audioflamingo3_modeling.AudioFlamingo3ForConditionalGeneration( + config + ).eval() + + vllm_encoder = AudioFlamingo3Encoder(config.audio_config).eval() + vllm_encoder.load_state_dict(hf_model.audio_tower.state_dict()) + + vllm_projector = AudioFlamingo3MultiModalProjector(config).eval() + vllm_projector.load_state_dict(hf_model.multi_modal_projector.state_dict()) + + input_features = torch.randn(3, 80, 3000) + feature_attention_mask = torch.zeros(3, 3000, dtype=torch.bool) + feature_attention_mask[0, :3000] = True + feature_attention_mask[1, :2500] = True + feature_attention_mask[2, :1500] = True + + hf_output = hf_model.get_audio_features( + input_features, + feature_attention_mask, + return_dict=True, + ).pooler_output + vllm_attention_mask = _build_audio_encoder_attention_mask( + feature_attention_mask, + dtype=vllm_encoder.conv1.weight.dtype, + device=vllm_encoder.conv1.weight.device, + ) + vllm_hidden_states = vllm_encoder( + input_features, + attention_mask=vllm_attention_mask, + ) + vllm_output, _ = _flatten_valid_audio_embeddings( + vllm_projector(vllm_hidden_states), + feature_attention_mask, + ) + + torch.testing.assert_close(vllm_output, hf_output) diff --git a/tests/models/multimodal/processing/test_musicflamingo.py b/tests/models/multimodal/processing/test_musicflamingo.py new file mode 100644 index 000000000..625e1ad8d --- /dev/null +++ 
b/tests/models/multimodal/processing/test_musicflamingo.py
@@ -0,0 +1,222 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Copyright 2026 The vLLM team.
+# Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
+# reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock
+
+import numpy as np
+import pytest
+import torch
+from transformers import PretrainedConfig
+
+from tests.models.registry import HF_EXAMPLE_MODELS
+
+
+class MockMusicFlamingoConfig(PretrainedConfig):
+    model_type = "musicflamingo"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.audio_config = PretrainedConfig()
+        self.text_config = PretrainedConfig()
+
+
+class MockMusicFlamingoProcessor:
+    def __init__(self):
+        self.audio_token = "<sound>"
+        self.audio_token_id = 12345
+        self.audio_bos_token = "<|sound_bos|>"
+        self.audio_bos_token_id = 12346
+        self.audio_eos_token = "<|sound_eos|>"
+        self.audio_eos_token_id = 12347
+        self.max_audio_len = 1200
+        self.feature_extractor = MockFeatureExtractor()
+
+
+class MockFeatureExtractor:
+    def __init__(self):
+        self.sampling_rate = 16000
+        self.chunk_length = 30
+
+
+@pytest.fixture
+def mock_ctx():
+    config = MockMusicFlamingoConfig()
+
+    ctx = MagicMock()
+    ctx.get_hf_config.return_value = config
+    ctx.get_hf_processor.return_value = MockMusicFlamingoProcessor()
+    ctx.model_config.hf_config = config
+    return ctx
+
+
+@pytest.fixture(autouse=True)
+def check_transformers_version():
+    model_info = HF_EXAMPLE_MODELS.get_hf_info("MusicFlamingoForConditionalGeneration")
+    model_info.check_transformers_version(on_fail="skip")
+
+
+def test_musicflamingo_chunk_counting_uses_rote_timestamps(mock_ctx, monkeypatch):
+    from vllm.model_executor.models.musicflamingo import (
+        MusicFlamingoDummyInputsBuilder,
+        MusicFlamingoMultiModalProcessor,
+        MusicFlamingoProcessingInfo,
+    )
+
+    info = MusicFlamingoProcessingInfo(mock_ctx)
+    processor = MusicFlamingoMultiModalProcessor(
+        info, MusicFlamingoDummyInputsBuilder(info)
+    )
+
+    sr = 16000
+    audio_1 = np.zeros(30 * sr)
+    audio_2 = np.zeros(45 * sr)
+
+    mm_data = {"audio": [audio_1, audio_2]}
+    prompt = "<|user|>Listen.<|end|>"
+
+    from vllm.multimodal.processing import BaseMultiModalProcessor
+
+    def mock_base_call(self, prompt, mm_data, mm_kwargs, tok_kwargs):
+        del self, prompt, mm_data, mm_kwargs, tok_kwargs
+        return {
+            "input_ids": [1, 2, 3],
+            "input_features": torch.randn(3, 80, 3000),
+            "rote_timestamps": torch.randn(3, 750),
+        }
+
+    monkeypatch.setattr(BaseMultiModalProcessor, "_call_hf_processor", mock_base_call)
+
+    processed = processor._call_hf_processor(prompt, mm_data, {}, {})
+
+    chunk_counts = processed["chunk_counts"]
+
+    assert chunk_counts.tolist() == [1, 2]
+    assert "rote_timestamps" in processed
+
+
+def test_musicflamingo_dummy_text_uses_plain_audio_tokens(mock_ctx):
+    from vllm.model_executor.models.musicflamingo import (
+        MusicFlamingoDummyInputsBuilder,
+        MusicFlamingoProcessingInfo,
+    )
+
+    info = MusicFlamingoProcessingInfo(mock_ctx)
+    builder = MusicFlamingoDummyInputsBuilder(info)
+
+    assert builder.get_dummy_text({"audio": 2}) == "<sound><sound>"
+
+
+def test_musicflamingo_audio_feature_pipeline_matches_hf_small_config():
+    from transformers.models.musicflamingo import (
+        modeling_musicflamingo as hf_musicflamingo_modeling,
+    )
+    from transformers.models.musicflamingo.configuration_musicflamingo import (
+        MusicFlamingoConfig,
+    )
+
+    from vllm.model_executor.models.audioflamingo3 import (
+        _build_audio_encoder_attention_mask,
+        _flatten_valid_audio_embeddings,
+    )
+    from vllm.model_executor.models.musicflamingo import (
+        MusicFlamingoEncoder,
+        MusicFlamingoMultiModalProjector,
+        MusicFlamingoRotaryEmbedding,
+        apply_rotary_time_emb,
+    )
+
+    text_config = {
+        "model_type": "qwen2",
+        "intermediate_size": 64,
+        "initializer_range": 0.02,
+        "hidden_size": 32,
+        "max_position_embeddings": 1024,
+        "num_hidden_layers": 2,
+        "num_attention_heads": 4,
+        "num_key_value_heads": 2,
+        "vocab_size": 128,
+        "pad_token_id": 1,
+        "use_mrope": False,
+    }
+    audio_config = {
+        "hidden_size": 16,
+        "num_attention_heads": 4,
+        "intermediate_size": 32,
+        "num_hidden_layers": 2,
+        "num_mel_bins": 80,
+        "max_source_positions": 1500,
+        "dropout": 0.0,
+        "attention_dropout": 0.0,
+        "activation_dropout": 0.0,
+        "encoder_layerdrop": 0.0,
+    }
+
+    torch.manual_seed(0)
+    config = MusicFlamingoConfig(
+        text_config=text_config,
+        audio_config=audio_config,
+        audio_token_id=0,
+        head_dim=8,
+        rope_parameters={"rope_type": "default", "rope_theta": 2048},
+    )
+    hf_model = hf_musicflamingo_modeling.MusicFlamingoForConditionalGeneration(
+        config
+    ).eval()
+
+    vllm_encoder = MusicFlamingoEncoder(config.audio_config).eval()
+    vllm_encoder.load_state_dict(hf_model.audio_tower.state_dict())
+
+    vllm_projector = MusicFlamingoMultiModalProjector(config).eval()
+    vllm_projector.load_state_dict(hf_model.multi_modal_projector.state_dict())
+
+    vllm_rope = MusicFlamingoRotaryEmbedding(config).eval()
+    vllm_rope.load_state_dict(hf_model.pos_emb.state_dict(), strict=False)
+
+    input_features = torch.randn(3, 80, 3000)
+    feature_attention_mask = torch.zeros(3, 3000, dtype=torch.bool)
+    feature_attention_mask[0, :3000] = True
+    feature_attention_mask[1, :2500] = True
+    feature_attention_mask[2, :1500] = True
+    rote_timestamps = (
+        torch.arange(750, dtype=torch.float32).unsqueeze(0).repeat(3, 1) * 0.04
+    )
+
+    hf_output = hf_model.get_audio_features(
+        input_features,
+        feature_attention_mask,
+        rote_timestamps=rote_timestamps,
+        return_dict=True,
+    ).pooler_output
+    vllm_attention_mask = _build_audio_encoder_attention_mask(
+        feature_attention_mask,
+        dtype=vllm_encoder.conv1.weight.dtype,
+        device=vllm_encoder.conv1.weight.device,
+    )
+    vllm_hidden_states = vllm_encoder(
+        input_features,
+        attention_mask=vllm_attention_mask,
+    )
+    cos, sin = vllm_rope(rote_timestamps, seq_len=vllm_hidden_states.shape[-2])
+    vllm_hidden_states = apply_rotary_time_emb(vllm_hidden_states, cos, sin)
+    vllm_output, _ = _flatten_valid_audio_embeddings(
+        vllm_projector(vllm_hidden_states),
+        feature_attention_mask,
+    )
+
+    torch.testing.assert_close(vllm_output, hf_output)
diff --git a/tests/models/registry.py b/tests/models/registry.py
index ff997706c..0d1e8e348 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -752,7 +752,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
         "nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0"
     ),
     "MusicFlamingoForConditionalGeneration": _HfExamplesInfo(
-
"nvidia/music-flamingo-2601-hf", min_transformers_version="5.0.0.dev" + "nvidia/music-flamingo-2601-hf", min_transformers_version="5.3.0" ), "AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"), "BagelForConditionalGeneration": _HfExamplesInfo("ByteDance-Seed/BAGEL-7B-MoT"), diff --git a/vllm/model_executor/models/audioflamingo3.py b/vllm/model_executor/models/audioflamingo3.py index 1a25dca2d..82906a6fa 100644 --- a/vllm/model_executor/models/audioflamingo3.py +++ b/vllm/model_executor/models/audioflamingo3.py @@ -69,10 +69,7 @@ from .utils import ( maybe_prefix, ) -MAX_AUDIO_LEN = 10 * 60 - -# === Audio Inputs === # class AudioFlamingo3FeatureInputs(TensorSchema): """ Dimensions: @@ -127,14 +124,12 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder): ): super().__init__(config) self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2) - # self.layer_norm is already initialized in super().__init__ def forward( self, input_features: torch.Tensor | list[torch.Tensor], attention_mask: torch.Tensor = None, ): - # input_features: (batch, num_mel_bins, seq_len) if isinstance(input_features, list): input_features = torch.stack(input_features) @@ -146,17 +141,14 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder): ).to(hidden_states.dtype) for layer in self.layers: - # Qwen2AudioEncoderLayer expects layer_head_mask as third arg. - layer_outputs = layer(hidden_states, attention_mask, None) - hidden_states = layer_outputs[0] + layer_outputs = layer(hidden_states, attention_mask) + hidden_states = ( + layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs + ) - # AvgPool (time/2) + LayerNorm - # hidden_states: (batch, seq_len, hidden_size) - hidden_states = hidden_states.permute(0, 2, 1) # (batch, hidden_size, seq_len) + hidden_states = hidden_states.permute(0, 2, 1) hidden_states = self.avg_pooler(hidden_states) - hidden_states = hidden_states.permute( - 0, 2, 1 - ) # (batch, seq_len/2, hidden_size) + hidden_states = hidden_states.permute(0, 2, 1) hidden_states = self.layer_norm(hidden_states) return hidden_states @@ -193,22 +185,6 @@ class AudioFlamingo3MultiModalProjector(nn.Module): return hidden_states -class AudioFlamingo3MultiModalDataParser(MultiModalDataParser): - def _parse_audio_data( - self, - data: dict[str, torch.Tensor] | ModalityData[Any], - ) -> ModalityDataItems[Any, Any] | None: - if isinstance(data, dict): - return DictEmbeddingItems( - data, - modality="audio", - required_fields={"audio_embeds"}, - fields_factory=_audioflamingo3_field_config, - ) - - return super()._parse_audio_data(data) - - class AudioFlamingo3ProcessingInfo(BaseProcessingInfo): def get_hf_config(self): return self.ctx.get_hf_config(AudioFlamingo3Config) @@ -217,20 +193,17 @@ class AudioFlamingo3ProcessingInfo(BaseProcessingInfo): return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs) def get_feature_extractor(self, **kwargs: object): - hf_processor = self.get_hf_processor(**kwargs) - feature_extractor = hf_processor.feature_extractor - return feature_extractor + return self.get_hf_processor(**kwargs).feature_extractor - def get_data_parser(self): + def get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.get_feature_extractor() - return AudioFlamingo3MultiModalDataParser( target_sr=feature_extractor.sampling_rate, expected_hidden_size=self._get_expected_hidden_size(), ) def get_supported_mm_limits(self) -> Mapping[str, int | None]: - return {"audio": 1} + return {"audio": None} class AudioFlamingo3DummyInputsBuilder( @@ -248,9 +221,10 @@ 
class AudioFlamingo3DummyInputsBuilder( mm_counts: Mapping[str, int], mm_options: Mapping[str, BaseDummyOptions], ) -> MultiModalDataDict: + hf_processor = self.info.get_hf_processor() feature_extractor = self.info.get_feature_extractor() sampling_rate = feature_extractor.sampling_rate - audio_len = MAX_AUDIO_LEN * sampling_rate + audio_len = int(hf_processor.max_audio_len * sampling_rate) num_audios = mm_counts.get("audio", 0) audio_overrides = mm_options.get("audio") @@ -284,6 +258,118 @@ def _audioflamingo3_field_config(hf_inputs: Mapping[str, torch.Tensor]): ) +def _get_audio_post_pool_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor: + conv_lengths = (input_lengths - 1) // 2 + 1 + return (conv_lengths - 2) // 2 + 1 + + +def _build_audio_encoder_attention_mask( + feature_attention_mask: torch.Tensor, + *, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + input_lengths = feature_attention_mask.sum(-1).to(torch.long) + conv_lengths = (input_lengths - 1) // 2 + 1 + + batch_size, max_mel_seq_len = feature_attention_mask.shape + max_seq_len = (max_mel_seq_len - 1) // 2 + 1 + + seq_range = ( + torch.arange( + max_seq_len, + dtype=conv_lengths.dtype, + device=conv_lengths.device, + ) + .unsqueeze(0) + .expand(batch_size, max_seq_len) + ) + padding_mask = seq_range >= conv_lengths[:, None] + + attention_mask = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( + batch_size, 1, max_seq_len, max_seq_len + ) + attention_mask = attention_mask.to(dtype=dtype, device=device) + attention_mask.masked_fill_(padding_mask[:, None, None, :], float("-inf")) + + return attention_mask + + +def _flatten_valid_audio_embeddings( + audio_embeddings: torch.Tensor, + feature_attention_mask: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + input_lengths = feature_attention_mask.sum(-1).to(torch.long) + output_lengths = _get_audio_post_pool_output_lengths(input_lengths) + valid_mask = ( + torch.arange(audio_embeddings.shape[1], device=output_lengths.device)[None, :] + < output_lengths[:, None] + ) + + return audio_embeddings[valid_mask], output_lengths + + +def _count_audio_tokens_from_mask( + feature_attention_mask: torch.Tensor | list[torch.Tensor], + chunk_counts: torch.Tensor | list[torch.Tensor] | list[int] | None, + item_idx: int, +) -> int: + if chunk_counts is not None: + if isinstance(chunk_counts, torch.Tensor): + counts = chunk_counts.tolist() + elif chunk_counts and isinstance(chunk_counts[0], torch.Tensor): + counts = [count.item() for count in chunk_counts] + else: + counts = chunk_counts + + start_idx = sum(counts[:item_idx]) + count = counts[item_idx] + end_idx = start_idx + count + + if isinstance(feature_attention_mask, list): + sample_mask = feature_attention_mask[start_idx:end_idx] + if len(sample_mask) == 0: + raise ValueError("Expected non-empty audio mask slice.") + if isinstance(sample_mask[0], torch.Tensor): + sample_mask = torch.stack(sample_mask) + else: + sample_mask = torch.tensor(sample_mask) + else: + sample_mask = feature_attention_mask[start_idx:end_idx] + else: + if isinstance(feature_attention_mask, list): + sample_mask = feature_attention_mask[item_idx] + else: + sample_mask = feature_attention_mask[item_idx] + + if sample_mask.ndim == 1: + sample_input_lengths = sample_mask.sum().unsqueeze(0) + else: + # Match the HF processor, which derives placeholder lengths from the + # total pre-encoder feature length for each original audio sample. 
+ sample_input_lengths = sample_mask.sum().reshape(1) + + post_lengths = _get_audio_post_pool_output_lengths( + sample_input_lengths.to(torch.long) + ) + return int(post_lengths[0].item()) + + +class AudioFlamingo3MultiModalDataParser(MultiModalDataParser): + def _parse_audio_data( + self, + data: dict[str, torch.Tensor] | ModalityData[Any], + ) -> ModalityDataItems[Any, Any] | None: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={"audio_embeds"}, + fields_factory=_audioflamingo3_field_config, + ) + return super()._parse_audio_data(data) + + class AudioFlamingo3MultiModalProcessor( BaseMultiModalProcessor[AudioFlamingo3ProcessingInfo] ): @@ -303,13 +389,13 @@ class AudioFlamingo3MultiModalProcessor( prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") - feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + processor = self.info.get_hf_processor(**mm_kwargs) + feature_extractor = processor.feature_extractor mm_kwargs = dict( **mm_kwargs, sampling_rate=feature_extractor.sampling_rate, ) - # Calculate chunk counts audio_list = mm_data.get("audio") if not isinstance(audio_list, list): audio_list = [audio_list] @@ -318,8 +404,7 @@ class AudioFlamingo3MultiModalProcessor( sampling_rate = feature_extractor.sampling_rate chunk_length = feature_extractor.chunk_length window_size = int(sampling_rate * chunk_length) - # MAX_AUDIO_LEN is 10 * 60 in HF processor. - max_windows = int(MAX_AUDIO_LEN // chunk_length) + max_windows = int(processor.max_audio_len // chunk_length) for audio in audio_list: # audio is numpy array or list @@ -364,7 +449,6 @@ class AudioFlamingo3MultiModalProcessor( audio_token = getattr(processor, "audio_token", "") audio_token_id = vocab.get(audio_token) if audio_token_id is None: - # Fallback if not found, though it should be there audio_token_id = processor.audio_token_id out_mm_data = out_mm_kwargs.get_data() @@ -373,38 +457,11 @@ class AudioFlamingo3MultiModalProcessor( def get_replacement_audioflamingo3(item_idx: int): if feature_attention_mask is not None: - if chunk_counts is not None: - counts = ( - chunk_counts.tolist() - if isinstance(chunk_counts, torch.Tensor) - else chunk_counts - ) - start_idx = sum(counts[:item_idx]) - count = counts[item_idx] - end_idx = start_idx + count - - if isinstance(feature_attention_mask, list): - mask_list = feature_attention_mask[start_idx:end_idx] - if len(mask_list) > 0 and isinstance( - mask_list[0], torch.Tensor - ): - mask = torch.stack(mask_list) - else: - mask = torch.tensor(mask_list) - else: - mask = feature_attention_mask[start_idx:end_idx] - else: - # feature_attention_mask is list[Tensor] or Tensor - if isinstance(feature_attention_mask, list): - mask = feature_attention_mask[item_idx] - else: - mask = feature_attention_mask[item_idx].unsqueeze(0) - - # mask shape: (num_chunks, 3000) - input_lengths = mask.sum(-1) - conv_lengths = (input_lengths - 1) // 2 + 1 - audio_output_lengths = (conv_lengths - 2) // 2 + 1 - num_features = audio_output_lengths.sum().item() + num_features = _count_audio_tokens_from_mask( + feature_attention_mask, + chunk_counts, + item_idx, + ) else: audio_embeds = out_mm_data["audio_embeds"][item_idx] num_features = audio_embeds.shape[0] @@ -435,13 +492,6 @@ class AudioFlamingo3MultiModalProcessor( class AudioFlamingo3ForConditionalGeneration( nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA ): - """ - AudioFlamingo3 model for conditional generation. 
- - This model integrates a Whisper-based audio encoder with a Qwen2 language model. - It supports multi-chunk audio processing. - """ - packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], "gate_up_proj": ["gate_proj", "up_proj"], @@ -517,6 +567,25 @@ class AudioFlamingo3ForConditionalGeneration( audio_embeds = audio_input["audio_embeds"] return tuple(audio_embeds) + ( + input_features, + feature_attention_mask, + chunk_counts, + ) = self._normalize_audio_feature_inputs(audio_input) + audio_hidden_states = self._encode_audio_features( + input_features, + feature_attention_mask, + ) + audio_features = self.multi_modal_projector(audio_hidden_states) + return self._group_audio_embeddings( + audio_features, + feature_attention_mask, + chunk_counts, + ) + + def _normalize_audio_feature_inputs( + self, audio_input: AudioFlamingo3FeatureInputs + ) -> tuple[torch.Tensor, torch.Tensor, list[int]]: input_features = audio_input["input_features"] feature_attention_mask = audio_input["feature_attention_mask"] chunk_counts = audio_input.get("chunk_counts") @@ -534,66 +603,36 @@ class AudioFlamingo3ForConditionalGeneration( and chunk_counts and isinstance(chunk_counts[0], torch.Tensor) ): - chunk_counts = [c.item() for c in chunk_counts] + chunk_counts = [count.item() for count in chunk_counts] - # Calculate output lengths - input_lengths = feature_attention_mask.sum(-1) - # Conv downsampling - conv_lengths = (input_lengths - 1) // 2 + 1 - # AvgPool downsampling - audio_output_lengths = (conv_lengths - 2) // 2 + 1 + return input_features, feature_attention_mask, chunk_counts - batch_size, _, max_mel_seq_len = input_features.shape - - # Calculate max_seq_len after convs (before pooling) for attention mask - max_seq_len = (max_mel_seq_len - 1) // 2 + 1 - - # Create a sequence tensor of shape (batch_size, max_seq_len) - seq_range = ( - torch.arange( - 0, - max_seq_len, - dtype=conv_lengths.dtype, - device=conv_lengths.device, - ) - .unsqueeze(0) - .expand(batch_size, max_seq_len) - ) - lengths_expand = conv_lengths.unsqueeze(-1).expand(batch_size, max_seq_len) - # Create mask - padding_mask = seq_range >= lengths_expand - - audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand( - batch_size, 1, max_seq_len, max_seq_len - ) - audio_attention_mask = audio_attention_mask_.to( + def _encode_audio_features( + self, + input_features: torch.Tensor, + feature_attention_mask: torch.Tensor, + ) -> torch.Tensor: + audio_attention_mask = _build_audio_encoder_attention_mask( + feature_attention_mask, dtype=self.audio_tower.conv1.weight.dtype, device=self.audio_tower.conv1.weight.device, ) - audio_attention_mask[audio_attention_mask_] = float("-inf") - # Forward pass - audio_features = self.audio_tower( - input_features, attention_mask=audio_attention_mask + return self.audio_tower(input_features, attention_mask=audio_attention_mask) + + def _group_audio_embeddings( + self, + audio_features: torch.Tensor, + feature_attention_mask: torch.Tensor, + chunk_counts: list[int], + ) -> tuple[torch.Tensor, ...]: + masked_audio_features, audio_output_lengths = _flatten_valid_audio_embeddings( + audio_features, + feature_attention_mask, ) - - # Project - audio_features = self.multi_modal_projector(audio_features) - - # Masking after pooling - num_audios, max_audio_tokens, embed_dim = audio_features.shape - audio_output_lengths = audio_output_lengths.unsqueeze(1) - audio_features_mask = ( - torch.arange(max_audio_tokens) - .expand(num_audios, max_audio_tokens) - 
.to(audio_output_lengths.device) - < audio_output_lengths - ) - masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim) - - # Split to tuple of embeddings for individual audio input. chunk_embeddings = torch.split( - masked_audio_features, audio_output_lengths.flatten().tolist() + masked_audio_features, + audio_output_lengths.tolist(), ) grouped_embeddings = [] @@ -613,7 +652,7 @@ class AudioFlamingo3ForConditionalGeneration( def forward( self, - input_ids: torch.Tensor | None, + input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, diff --git a/vllm/model_executor/models/musicflamingo.py b/vllm/model_executor/models/musicflamingo.py index 84328d4cd..f4e3bbe37 100644 --- a/vllm/model_executor/models/musicflamingo.py +++ b/vllm/model_executor/models/musicflamingo.py @@ -1,63 +1,209 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""MusicFlamingo model adapter. +# Copyright 2026 The vLLM team. +# Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -MusicFlamingo shares the AudioFlamingo3 architecture, so we reuse the same -implementation and multimodal processor, while accepting MusicFlamingo config -and processor classes when available. 
-""" +from collections.abc import Callable, Mapping, Sequence +from math import pi +from typing import Annotated, Any, Optional, TypeAlias -from collections.abc import Mapping - -from transformers.models.audioflamingo3 import ( - AudioFlamingo3Config, - AudioFlamingo3Processor, +import torch +from torch import Tensor, broadcast_tensors, nn +from transformers import BatchFeature +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.models.musicflamingo import ( + MusicFlamingoConfig, + MusicFlamingoProcessor, ) +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.processing import BaseProcessingInfo +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + DictEmbeddingItems, + ModalityData, + ModalityDataItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.utils.tensor_schema import TensorShape from .audioflamingo3 import ( AudioFlamingo3DummyInputsBuilder, + AudioFlamingo3EmbeddingInputs, + AudioFlamingo3Encoder, + AudioFlamingo3FeatureInputs, AudioFlamingo3ForConditionalGeneration, AudioFlamingo3MultiModalDataParser, AudioFlamingo3MultiModalProcessor, + AudioFlamingo3MultiModalProjector, + AudioFlamingo3ProcessingInfo, + _audioflamingo3_field_config, + _count_audio_tokens_from_mask, ) -try: - # Optional dependency: use MusicFlamingo classes when transformers provides them. - from transformers.models.musicflamingo import ( - MusicFlamingoConfig, - MusicFlamingoProcessor, - ) -except Exception: # pragma: no cover - optional dependency - MusicFlamingoConfig = None - MusicFlamingoProcessor = None + +def rotate_half(x): + x = x.reshape(*x.shape[:-1], -1, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) -class MusicFlamingoProcessingInfo(BaseProcessingInfo): - def get_hf_config(self): - if MusicFlamingoConfig is None: - return self.ctx.get_hf_config(AudioFlamingo3Config) - return self.ctx.get_hf_config((MusicFlamingoConfig, AudioFlamingo3Config)) - - def get_hf_processor(self, **kwargs: object): - if MusicFlamingoProcessor is None: - return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs) - # Tuple triggers AutoProcessor path and accepts either processor class. 
- return self.ctx.get_hf_processor( - (MusicFlamingoProcessor, AudioFlamingo3Processor), **kwargs +def apply_rotary_time_emb(hidden_states, cos, sin): + original_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float64) + cos = cos.to(hidden_states) + sin = sin.to(hidden_states) + rot_dim = cos.shape[-1] + if rot_dim > hidden_states.shape[-1]: + raise ValueError( + f"feature dimension {hidden_states.shape[-1]} is not of " + f"sufficient size to rotate in all the positions {rot_dim}" ) - def get_feature_extractor(self, **kwargs: object): - hf_processor = self.get_hf_processor(**kwargs) - return hf_processor.feature_extractor + rotated = hidden_states[..., :rot_dim] + passthrough = hidden_states[..., rot_dim:] + rotated = (rotated * cos) + (rotate_half(rotated) * sin) + return torch.cat((rotated, passthrough), dim=-1).to(original_dtype) - def get_data_parser(self): + +class MusicFlamingoRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, config: MusicFlamingoConfig, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + position_angles = self._compute_position_angles(self.inv_freq) + self.register_buffer("position_angles", position_angles, persistent=False) + + @staticmethod + def compute_default_rope_parameters( + config: MusicFlamingoConfig | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + del seq_len + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or ( + config.hidden_size // config.num_attention_heads + ) + attention_factor = 1.0 + + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, dim, 2, dtype=torch.int64).to( + device=device, + dtype=torch.float, + ) + / dim + ) + ) + return inv_freq, attention_factor + + def _compute_position_angles(self, inv_freq): + positions = torch.arange( + int(self.max_seq_len_cached), + device=inv_freq.device, + dtype=inv_freq.dtype, + ) + positions = positions / self.max_seq_len_cached * (2 * pi) + position_angles = positions.unsqueeze(-1) * inv_freq + position_angles = torch.repeat_interleave(position_angles, 2, dim=-1) + return position_angles.to(dtype=inv_freq.dtype) + + @torch.no_grad() + def forward(self, timestamps: Tensor, seq_len: int) -> tuple[Tensor, Tensor]: + batch_positions = torch.arange( + timestamps.shape[0], + device=self.inv_freq.device, + dtype=self.inv_freq.dtype, + ) + batch_positions = batch_positions / self.max_seq_len_cached + batch_freqs = batch_positions.unsqueeze(-1) * self.inv_freq + batch_freqs = torch.repeat_interleave(batch_freqs, 2, dim=-1) + + batch_freqs = batch_freqs[:, None, :] + time_freqs = self.position_angles[:seq_len][None, :, :] + batch_freqs, time_freqs = broadcast_tensors(batch_freqs, time_freqs) + freqs = torch.cat((batch_freqs, time_freqs), dim=-1) + angle = (-timestamps * 2 * pi).to(freqs) + freqs = freqs * angle.unsqueeze(-1) + return freqs.cos(), freqs.sin() + + +class 
MusicFlamingoFeatureInputs(AudioFlamingo3FeatureInputs): + rote_timestamps: Annotated[ + torch.Tensor, + TensorShape( + "num_chunks", + "num_audio_time_steps", + dynamic_dims={"num_audio_time_steps"}, + ), + ] + + +MusicFlamingoEmbeddingInputs = AudioFlamingo3EmbeddingInputs + +MusicFlamingoInputs: TypeAlias = ( + MusicFlamingoFeatureInputs | MusicFlamingoEmbeddingInputs +) + + +class MusicFlamingoEncoder(AudioFlamingo3Encoder): + pass + + +class MusicFlamingoMultiModalProjector(AudioFlamingo3MultiModalProjector): + pass + + +class MusicFlamingoProcessingInfo(AudioFlamingo3ProcessingInfo): + def get_hf_config(self) -> MusicFlamingoConfig: + return self.ctx.get_hf_config(MusicFlamingoConfig) + + def get_hf_processor(self, **kwargs: object) -> MusicFlamingoProcessor: + return self.ctx.get_hf_processor(MusicFlamingoProcessor, **kwargs) + + def get_data_parser(self) -> MultiModalDataParser: feature_extractor = self.get_feature_extractor() - - return AudioFlamingo3MultiModalDataParser( + return MusicFlamingoMultiModalDataParser( target_sr=feature_extractor.sampling_rate, expected_hidden_size=self._get_expected_hidden_size(), ) @@ -67,13 +213,230 @@ class MusicFlamingoProcessingInfo(BaseProcessingInfo): class MusicFlamingoDummyInputsBuilder(AudioFlamingo3DummyInputsBuilder): - pass + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + hf_processor = self.info.get_hf_processor() + return hf_processor.audio_token * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions], + ) -> MultiModalDataDict: + hf_processor = self.info.get_hf_processor() + feature_extractor = self.info.get_feature_extractor() + sampling_rate = feature_extractor.sampling_rate + audio_len = int(hf_processor.max_audio_len * sampling_rate) + num_audios = mm_counts.get("audio", 0) + audio_overrides = mm_options.get("audio") + + return { + "audio": self._get_dummy_audios( + length=audio_len, + num_audios=num_audios, + overrides=audio_overrides, + ) + } + + +def _musicflamingo_field_config(hf_inputs: Mapping[str, torch.Tensor]): + fields = dict(_audioflamingo3_field_config(hf_inputs)) + chunk_counts = hf_inputs.get("chunk_counts") + if chunk_counts is not None: + fields["rote_timestamps"] = MultiModalFieldConfig.flat_from_sizes( + "audio", chunk_counts, dim=0 + ) + else: + fields["rote_timestamps"] = MultiModalFieldConfig.batched("audio") + return fields + + +class MusicFlamingoMultiModalDataParser(AudioFlamingo3MultiModalDataParser): + def _parse_audio_data( + self, + data: dict[str, torch.Tensor] | ModalityData[Any], + ) -> ModalityDataItems[Any, Any] | None: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="audio", + required_fields={"audio_embeds"}, + fields_factory=_musicflamingo_field_config, + ) + return super()._parse_audio_data(data) + + +class MusicFlamingoMultiModalProcessor(AudioFlamingo3MultiModalProcessor): + def _call_hf_processor( + self, + prompt: str, + mm_data: dict[str, object], + mm_kwargs: Mapping[str, Any], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + audio_data = mm_data.get("audio") + if audio_data is None: + return outputs + + audio_list = audio_data if isinstance(audio_data, list) else [audio_data] + if len(audio_list) == 0: + return outputs + + processor = self.info.get_hf_processor(**mm_kwargs) + 
feature_extractor = processor.feature_extractor + sampling_rate = feature_extractor.sampling_rate + chunk_length = feature_extractor.chunk_length + window_size = int(sampling_rate * chunk_length) + max_windows = int(processor.max_audio_len // chunk_length) + + chunk_counts = [] + for audio in audio_list: + n_samples = len(audio) if isinstance(audio, list) else audio.shape[0] + n_win = max(1, (n_samples + window_size - 1) // window_size) + chunk_counts.append(min(n_win, max_windows)) + outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long) + + if "rote_timestamps" not in outputs: + raise KeyError( + "MusicFlamingoProcessor output must include `rote_timestamps`." + ) + + return outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _musicflamingo_field_config(hf_inputs) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + audio_token = processor.audio_token + audio_token_id = vocab.get(audio_token, processor.audio_token_id) + + audio_bos_token = processor.audio_bos_token + audio_bos_token_id = vocab.get(audio_bos_token, processor.audio_bos_token_id) + + audio_eos_token = processor.audio_eos_token + audio_eos_token_id = vocab.get(audio_eos_token, processor.audio_eos_token_id) + + out_mm_data = out_mm_kwargs.get_data() + feature_attention_mask = out_mm_data.get("feature_attention_mask") + chunk_counts = out_mm_data.get("chunk_counts") + + def get_replacement_musicflamingo(item_idx: int): + if feature_attention_mask is not None: + num_features = _count_audio_tokens_from_mask( + feature_attention_mask, + chunk_counts, + item_idx, + ) + else: + audio_embeds = out_mm_data["audio_embeds"][item_idx] + num_features = audio_embeds.shape[0] + + if num_features == 0: + raise ValueError("Audio is too short") + + full_tokens = [ + audio_bos_token_id, + *([audio_token_id] * int(num_features)), + audio_eos_token_id, + ] + + return PromptUpdateDetails.select_token_id( + full_tokens, + embed_token_id=audio_token_id, + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_musicflamingo, + ) + ] @MULTIMODAL_REGISTRY.register_processor( - AudioFlamingo3MultiModalProcessor, + MusicFlamingoMultiModalProcessor, info=MusicFlamingoProcessingInfo, dummy_inputs=MusicFlamingoDummyInputsBuilder, ) class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): - """MusicFlamingo model for conditional generation.""" + """vLLM MusicFlamingo model aligned with HF modular_musicflamingo.""" + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.audio_tower = MusicFlamingoEncoder(self.config.audio_config) + self.multi_modal_projector = MusicFlamingoMultiModalProjector(self.config) + self.pos_emb = MusicFlamingoRotaryEmbedding(self.config) + + def _parse_and_validate_audio_input( + self, **kwargs: object + ) -> MusicFlamingoInputs | None: + rote_timestamps = kwargs.pop("rote_timestamps", None) + audio_input = super()._parse_and_validate_audio_input(**kwargs) + if audio_input is None or audio_input["type"] == "audio_embeds": + return audio_input + + return 
MusicFlamingoFeatureInputs( + type="audio_features", + input_features=audio_input["input_features"], + feature_attention_mask=audio_input["feature_attention_mask"], + chunk_counts=audio_input["chunk_counts"], + rote_timestamps=rote_timestamps, + ) + + def _process_audio_input( + self, audio_input: MusicFlamingoInputs + ) -> torch.Tensor | tuple[torch.Tensor, ...]: + if audio_input["type"] == "audio_embeds": + return super()._process_audio_input(audio_input) + + rote_timestamps = audio_input["rote_timestamps"] + if rote_timestamps is None: + raise ValueError( + "MusicFlamingo audio feature inputs must include `rote_timestamps`." + ) + if isinstance(rote_timestamps, list): + rote_timestamps = torch.cat(rote_timestamps, dim=0) + + ( + input_features, + feature_attention_mask, + chunk_counts, + ) = self._normalize_audio_feature_inputs(audio_input) + hidden_states = self._encode_audio_features( + input_features, + feature_attention_mask, + ) + cos, sin = self.pos_emb( + rote_timestamps.to(hidden_states.device), + seq_len=hidden_states.shape[-2], + ) + hidden_states = apply_rotary_time_emb(hidden_states, cos, sin) + audio_features = self.multi_modal_projector(hidden_states) + + return self._group_audio_embeddings( + audio_features, + feature_attention_mask, + chunk_counts, + )