Fix AudioFlamingo3/MusicFlamingo HF parity and RoTE handling (#37643)
Signed-off-by: Lasha <26011196+lashahub@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
43877a620b
commit
e7767eccae
@@ -1 +1 @@
|
||||
{"transcriptions": ["The content of the input audio is 'you can ask why over and over and over again forever even if one day we explain every physical interaction and scientific law and hope and dream and regret with a single elegant equation'."], "token_ids": [[785, 2213, 315, 279, 1946, 7699, 374, 364, 9330, 646, 2548, 3170, 916, 323, 916, 323, 916, 1549, 15683, 1496, 421, 825, 1899, 582, 10339, 1449, 6961, 16230, 323, 12344, 2329, 323, 3900, 323, 7904, 323, 22231, 448, 264, 3175, 25777, 23606, 4427, 151645]]}
|
||||
{"transcriptions": ["There is no clear relationship between the barking and the music, as they seem to be independent of each other."], "token_ids": [[3862, 374, 902, 2797, 5025, 1948, 279, 293, 33452, 323, 279, 4627, 11, 438, 807, 2803, 311, 387, 9489, 315, 1817, 1008, 13, 151645]]}
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
{"transcriptions": ["This track is an energetic Eurodance / Dance‑Pop anthem that blends the bright, melodic sensibilities of mainstream pop with the driving, club‑ready pulse of classic Eurodance. The duration of the piece is ", "**Verse 1**\nMidnight cravings in bloom, lights flicker in the room, pepperoni dreams arise, pizza party on your skies\n\n**Verse 2**\nCheese melts on the crust, in flavor we trust, boxes stacked to the"], "token_ids": [[1986, 3754, 374, 458, 44855, 19461, 98875, 378, 107, 14, 378, 107, 35, 681, 55964, 11598, 55564, 429, 57843, 279, 9906, 11, 10581, 52760, 6097, 13450, 315, 20729, 2420, 448, 279, 9842, 11, 6335, 55964, 2307, 27235, 315, 11416, 19461, 98875, 13, 220, 576, 8090, 315, 279, 6573, 374, 220], [334, 68043, 220, 16, 1019, 33648, 9287, 88828, 304, 51454, 11, 12711, 28347, 261, 304, 279, 3054, 11, 24353, 20783, 18707, 30789, 11, 22502, 4614, 389, 697, 49293, 271, 334, 68043, 220, 17, 1019, 26843, 2367, 98091, 389, 279, 39612, 11, 304, 17172, 582, 6950, 11, 14697, 41315, 311, 279]]}
|
||||
@@ -0,0 +1 @@
|
||||
{"transcriptions": ["This track is an energetic Eurodance / Dance‑Pop anthem that blends the bright, melodic sensibilities of mainstream pop with the driving, club‑ready pulse of classic Eurodance. The duration of the piece is "], "token_ids": [[1986, 3754, 374, 458, 44855, 19461, 98875, 378, 107, 14, 378, 107, 35, 681, 55964, 11598, 55564, 429, 57843, 279, 9906, 11, 10581, 52760, 6097, 13450, 315, 20729, 2420, 448, 279, 9842, 11, 6335, 55964, 2307, 27235, 315, 11416, 19461, 98875, 13, 220, 576, 8090, 315, 279, 6573, 374, 220]]}
|
||||
@@ -26,6 +26,54 @@ from tests.models.registry import HF_EXAMPLE_MODELS
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
MODEL_NAME = "nvidia/audio-flamingo-3-hf"
|
||||
SINGLE_CONVERSATION = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "What is surprising about the relationship between "
|
||||
"the barking and the music?",
|
||||
},
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url": "https://huggingface.co/datasets/nvidia/AudioSkills/"
|
||||
"resolve/main/assets/"
|
||||
"dogs_barking_in_sync_with_the_music.wav",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
BATCHED_CONVERSATIONS = [
|
||||
SINGLE_CONVERSATION,
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Why is the philosopher's name mentioned in the "
|
||||
"lyrics? (A) To express a sense of nostalgia "
|
||||
"(B) To indicate that language cannot express clearly, "
|
||||
"satirizing the inversion of black and white in the world "
|
||||
"(C) To add depth and complexity to the lyrics "
|
||||
"(D) To showcase the wisdom and influence of the "
|
||||
"philosopher",
|
||||
},
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url": "https://huggingface.co/datasets/nvidia/"
|
||||
"AudioSkills/resolve/main/assets/"
|
||||
"Ch6Ae9DT6Ko_00-04-03_00-04-31.wav",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
def get_fixture_path(filename):
|
||||
@@ -34,21 +82,29 @@ def get_fixture_path(filename):
|
||||
)
|
||||
|
||||
|
||||
def assert_output_matches(output, expected_text, expected_token_ids):
|
||||
generated = output.outputs[0]
|
||||
assert generated.text.strip() == expected_text
|
||||
actual_token_ids = list(generated.token_ids)
|
||||
assert (
|
||||
actual_token_ids == expected_token_ids
|
||||
or actual_token_ids == expected_token_ids[:-1]
|
||||
or actual_token_ids[:-1] == expected_token_ids
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llm():
|
||||
# Check if the model is supported by the current transformers version
|
||||
model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
try:
|
||||
llm = LLM(
|
||||
return LLM(
|
||||
model=MODEL_NAME,
|
||||
trust_remote_code=True,
|
||||
dtype="bfloat16",
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"audio": 1},
|
||||
)
|
||||
return llm
|
||||
except Exception as e:
|
||||
pytest.skip(f"Failed to load model {MODEL_NAME}: {e}")
|
||||
|
||||
@@ -61,29 +117,17 @@ def test_single_generation(llm):
|
||||
with open(fixture_path) as f:
|
||||
expected = json.load(f)
|
||||
|
||||
audio_url = "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Why_do_we_ask_questions_converted.wav"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio_url", "audio_url": {"url": audio_url}},
|
||||
{"type": "text", "text": "Transcribe the input speech."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
|
||||
|
||||
outputs = llm.chat(
|
||||
messages=messages,
|
||||
messages=SINGLE_CONVERSATION,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
generated_text = outputs[0].outputs[0].text.strip()
|
||||
|
||||
expected_text = expected["transcriptions"][0]
|
||||
|
||||
assert expected_text in generated_text or generated_text in expected_text
|
||||
assert_output_matches(
|
||||
outputs[0],
|
||||
expected["transcriptions"][0],
|
||||
expected["token_ids"][0],
|
||||
)
|
||||
|
||||
|
||||
def test_batched_generation(llm):
|
||||
@@ -94,49 +138,34 @@ def test_batched_generation(llm):
|
||||
with open(fixture_path) as f:
|
||||
expected = json.load(f)
|
||||
|
||||
items = [
|
||||
{
|
||||
"audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/dogs_barking_in_sync_with_the_music.wav",
|
||||
"question": "What is surprising about the relationship "
|
||||
"between the barking and the music?",
|
||||
"expected_idx": 0,
|
||||
},
|
||||
{
|
||||
"audio_url": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/Ch6Ae9DT6Ko_00-04-03_00-04-31.wav",
|
||||
"question": (
|
||||
"Why is the philosopher's name mentioned in the lyrics? "
|
||||
"(A) To express a sense of nostalgia "
|
||||
"(B) To indicate that language cannot express clearly, "
|
||||
"satirizing the inversion of black and white in the world "
|
||||
"(C) To add depth and complexity to the lyrics "
|
||||
"(D) To showcase the wisdom and influence of the philosopher"
|
||||
),
|
||||
"expected_idx": 1,
|
||||
},
|
||||
]
|
||||
|
||||
conversations = []
|
||||
for item in items:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio_url", "audio_url": {"url": item["audio_url"]}},
|
||||
{"type": "text", "text": item["question"]},
|
||||
],
|
||||
}
|
||||
]
|
||||
conversations.append(messages)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
|
||||
|
||||
outputs = llm.chat(
|
||||
messages=conversations,
|
||||
messages=BATCHED_CONVERSATIONS,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
|
||||
for i, output in enumerate(outputs):
|
||||
generated_text = output.outputs[0].text.strip()
|
||||
expected_text = expected["transcriptions"][i]
|
||||
assert_output_matches(
|
||||
output,
|
||||
expected["transcriptions"][i],
|
||||
expected["token_ids"][i],
|
||||
)
|
||||
|
||||
assert expected_text in generated_text or generated_text in expected_text
|
||||
|
||||
def test_single_and_batched_generation_match(llm):
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
|
||||
|
||||
single_output = llm.chat(
|
||||
messages=SINGLE_CONVERSATION,
|
||||
sampling_params=sampling_params,
|
||||
)[0]
|
||||
batched_output = llm.chat(
|
||||
messages=BATCHED_CONVERSATIONS,
|
||||
sampling_params=sampling_params,
|
||||
)[0]
|
||||
|
||||
assert single_output.outputs[0].text == batched_output.outputs[0].text
|
||||
assert list(single_output.outputs[0].token_ids) == list(
|
||||
batched_output.outputs[0].token_ids
|
||||
)
|
||||
|
||||
146
tests/models/multimodal/generation/test_musicflamingo.py
Normal file
146
tests/models/multimodal/generation/test_musicflamingo.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.models.registry import HF_EXAMPLE_MODELS
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
MODEL_NAME = "nvidia/music-flamingo-2601-hf"
|
||||
SINGLE_CONVERSATION = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Describe this track in full detail - tell me the "
|
||||
"genre, tempo, and key, then dive into the instruments, "
|
||||
"production style, and overall mood it creates.",
|
||||
},
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url": "https://huggingface.co/datasets/nvidia/AudioSkills/"
|
||||
"resolve/main/assets/song_1.mp3",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
]
|
||||
BATCHED_CONVERSATIONS = [
|
||||
SINGLE_CONVERSATION,
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Generate a structured lyric sheet from the input music.",
|
||||
},
|
||||
{
|
||||
"type": "audio_url",
|
||||
"audio_url": {
|
||||
"url": "https://huggingface.co/datasets/nvidia/"
|
||||
"AudioSkills/resolve/main/assets/song_2.mp3",
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
def get_fixture_path(filename):
|
||||
return os.path.join(
|
||||
os.path.dirname(__file__), "../../fixtures/musicflamingo", filename
|
||||
)
|
||||
|
||||
|
||||
def assert_output_matches(output, expected_text, expected_token_ids):
|
||||
generated = output.outputs[0]
|
||||
assert generated.text == expected_text
|
||||
actual_token_ids = list(generated.token_ids)
|
||||
assert (
|
||||
actual_token_ids == expected_token_ids
|
||||
or actual_token_ids == expected_token_ids[:-1]
|
||||
or actual_token_ids[:-1] == expected_token_ids
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llm():
|
||||
model_info = HF_EXAMPLE_MODELS.get_hf_info("MusicFlamingoForConditionalGeneration")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
try:
|
||||
return LLM(
|
||||
model=MODEL_NAME,
|
||||
dtype="bfloat16",
|
||||
enforce_eager=True,
|
||||
max_model_len=8192,
|
||||
limit_mm_per_prompt={"audio": 1},
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.skip(f"Failed to load model {MODEL_NAME}: {e}")
|
||||
|
||||
|
||||
def test_single_generation(llm):
|
||||
fixture_path = get_fixture_path("expected_results_single.json")
|
||||
if not os.path.exists(fixture_path):
|
||||
pytest.skip(f"Fixture not found: {fixture_path}")
|
||||
|
||||
with open(fixture_path) as f:
|
||||
expected = json.load(f)
|
||||
|
||||
outputs = llm.chat(
|
||||
messages=SINGLE_CONVERSATION,
|
||||
sampling_params=SamplingParams(temperature=0.0, max_tokens=50),
|
||||
)
|
||||
|
||||
assert_output_matches(
|
||||
outputs[0],
|
||||
expected["transcriptions"][0],
|
||||
expected["token_ids"][0],
|
||||
)
|
||||
|
||||
|
||||
def test_batched_generation(llm):
|
||||
fixture_path = get_fixture_path("expected_results_batched.json")
|
||||
if not os.path.exists(fixture_path):
|
||||
pytest.skip(f"Fixture not found: {fixture_path}")
|
||||
|
||||
with open(fixture_path) as f:
|
||||
expected = json.load(f)
|
||||
|
||||
outputs = llm.chat(
|
||||
messages=BATCHED_CONVERSATIONS,
|
||||
sampling_params=SamplingParams(temperature=0.0, max_tokens=50),
|
||||
)
|
||||
|
||||
for i, output in enumerate(outputs):
|
||||
assert_output_matches(
|
||||
output,
|
||||
expected["transcriptions"][i],
|
||||
expected["token_ids"][i],
|
||||
)
|
||||
|
||||
|
||||
def test_single_and_batched_generation_match(llm):
|
||||
sampling_params = SamplingParams(temperature=0.0, max_tokens=50)
|
||||
|
||||
single_output = llm.chat(
|
||||
messages=SINGLE_CONVERSATION,
|
||||
sampling_params=sampling_params,
|
||||
)[0]
|
||||
batched_output = llm.chat(
|
||||
messages=BATCHED_CONVERSATIONS,
|
||||
sampling_params=sampling_params,
|
||||
)[0]
|
||||
|
||||
assert single_output.outputs[0].text == batched_output.outputs[0].text
|
||||
assert list(single_output.outputs[0].token_ids) == list(
|
||||
batched_output.outputs[0].token_ids
|
||||
)
|
||||
@@ -40,6 +40,7 @@ class MockAudioFlamingo3Processor:
|
||||
def __init__(self):
|
||||
self.audio_token = "<sound>"
|
||||
self.audio_token_id = 12345
|
||||
self.max_audio_len = 60
|
||||
self.feature_extractor = MockFeatureExtractor()
|
||||
|
||||
def __call__(self, text=None, audios=None, **kwargs):
|
||||
@@ -65,7 +66,6 @@ def mock_ctx():
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def check_transformers_version():
|
||||
# Check if the model is supported by the current transformers version
|
||||
model_info = HF_EXAMPLE_MODELS.get_hf_info("AudioFlamingo3ForConditionalGeneration")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
@@ -84,7 +84,7 @@ def test_audio_chunk_counting(mock_ctx):
|
||||
|
||||
sr = 16000
|
||||
audio_1 = np.zeros(30 * sr)
|
||||
audio_2 = np.zeros(45 * sr)
|
||||
audio_2 = np.zeros(75 * sr)
|
||||
|
||||
mm_data = {"audio": [audio_1, audio_2]}
|
||||
prompt = "<|user|>Listen.<|end|>"
|
||||
@@ -121,5 +121,107 @@ def test_dummy_data_generation(mock_ctx):
|
||||
assert "audio" in dummy_data
|
||||
assert len(dummy_data["audio"]) == 2
|
||||
|
||||
expected_len = 600 * 16000
|
||||
expected_len = 60 * 16000
|
||||
assert len(dummy_data["audio"][0]) == expected_len
|
||||
|
||||
|
||||
def test_audio_token_count_matches_hf_processor_math():
|
||||
from vllm.model_executor.models.audioflamingo3 import (
|
||||
_count_audio_tokens_from_mask,
|
||||
)
|
||||
|
||||
feature_attention_mask = torch.zeros((3, 3000), dtype=torch.long)
|
||||
feature_attention_mask[0, :2999] = 1
|
||||
feature_attention_mask[1, :2999] = 1
|
||||
feature_attention_mask[2, :1500] = 1
|
||||
chunk_counts = torch.tensor([2, 1], dtype=torch.long)
|
||||
|
||||
assert (
|
||||
_count_audio_tokens_from_mask(feature_attention_mask, chunk_counts, 0) == 1499
|
||||
)
|
||||
assert _count_audio_tokens_from_mask(feature_attention_mask, chunk_counts, 1) == 375
|
||||
|
||||
|
||||
def test_audio_feature_pipeline_matches_hf_small_config():
|
||||
from transformers.models.audioflamingo3 import (
|
||||
modeling_audioflamingo3 as hf_audioflamingo3_modeling,
|
||||
)
|
||||
from transformers.models.audioflamingo3.configuration_audioflamingo3 import (
|
||||
AudioFlamingo3Config,
|
||||
)
|
||||
|
||||
from vllm.model_executor.models.audioflamingo3 import (
|
||||
AudioFlamingo3Encoder,
|
||||
AudioFlamingo3MultiModalProjector,
|
||||
_build_audio_encoder_attention_mask,
|
||||
_flatten_valid_audio_embeddings,
|
||||
)
|
||||
|
||||
text_config = {
|
||||
"model_type": "qwen2",
|
||||
"intermediate_size": 64,
|
||||
"initializer_range": 0.02,
|
||||
"hidden_size": 32,
|
||||
"max_position_embeddings": 1024,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 2,
|
||||
"vocab_size": 128,
|
||||
"pad_token_id": 1,
|
||||
"use_mrope": False,
|
||||
}
|
||||
audio_config = {
|
||||
"hidden_size": 16,
|
||||
"num_attention_heads": 4,
|
||||
"intermediate_size": 32,
|
||||
"num_hidden_layers": 2,
|
||||
"num_mel_bins": 80,
|
||||
"max_source_positions": 1500,
|
||||
"dropout": 0.0,
|
||||
"attention_dropout": 0.0,
|
||||
"activation_dropout": 0.0,
|
||||
"encoder_layerdrop": 0.0,
|
||||
}
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = AudioFlamingo3Config(
|
||||
text_config=text_config,
|
||||
audio_config=audio_config,
|
||||
audio_token_id=0,
|
||||
)
|
||||
hf_model = hf_audioflamingo3_modeling.AudioFlamingo3ForConditionalGeneration(
|
||||
config
|
||||
).eval()
|
||||
|
||||
vllm_encoder = AudioFlamingo3Encoder(config.audio_config).eval()
|
||||
vllm_encoder.load_state_dict(hf_model.audio_tower.state_dict())
|
||||
|
||||
vllm_projector = AudioFlamingo3MultiModalProjector(config).eval()
|
||||
vllm_projector.load_state_dict(hf_model.multi_modal_projector.state_dict())
|
||||
|
||||
input_features = torch.randn(3, 80, 3000)
|
||||
feature_attention_mask = torch.zeros(3, 3000, dtype=torch.bool)
|
||||
feature_attention_mask[0, :3000] = True
|
||||
feature_attention_mask[1, :2500] = True
|
||||
feature_attention_mask[2, :1500] = True
|
||||
|
||||
hf_output = hf_model.get_audio_features(
|
||||
input_features,
|
||||
feature_attention_mask,
|
||||
return_dict=True,
|
||||
).pooler_output
|
||||
vllm_attention_mask = _build_audio_encoder_attention_mask(
|
||||
feature_attention_mask,
|
||||
dtype=vllm_encoder.conv1.weight.dtype,
|
||||
device=vllm_encoder.conv1.weight.device,
|
||||
)
|
||||
vllm_hidden_states = vllm_encoder(
|
||||
input_features,
|
||||
attention_mask=vllm_attention_mask,
|
||||
)
|
||||
vllm_output, _ = _flatten_valid_audio_embeddings(
|
||||
vllm_projector(vllm_hidden_states),
|
||||
feature_attention_mask,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(vllm_output, hf_output)
|
||||
|
||||
222
tests/models/multimodal/processing/test_musicflamingo.py
Normal file
222
tests/models/multimodal/processing/test_musicflamingo.py
Normal file
@@ -0,0 +1,222 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2026 The vLLM team.
|
||||
# Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from tests.models.registry import HF_EXAMPLE_MODELS
|
||||
|
||||
|
||||
class MockMusicFlamingoConfig(PretrainedConfig):
|
||||
model_type = "musicflamingo"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.audio_config = PretrainedConfig()
|
||||
self.text_config = PretrainedConfig()
|
||||
|
||||
|
||||
class MockMusicFlamingoProcessor:
|
||||
def __init__(self):
|
||||
self.audio_token = "<sound>"
|
||||
self.audio_token_id = 12345
|
||||
self.audio_bos_token = "<|sound_bos|>"
|
||||
self.audio_bos_token_id = 12346
|
||||
self.audio_eos_token = "<|sound_eos|>"
|
||||
self.audio_eos_token_id = 12347
|
||||
self.max_audio_len = 1200
|
||||
self.feature_extractor = MockFeatureExtractor()
|
||||
|
||||
|
||||
class MockFeatureExtractor:
|
||||
def __init__(self):
|
||||
self.sampling_rate = 16000
|
||||
self.chunk_length = 30
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ctx():
|
||||
config = MockMusicFlamingoConfig()
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.get_hf_config.return_value = config
|
||||
ctx.get_hf_processor.return_value = MockMusicFlamingoProcessor()
|
||||
ctx.model_config.hf_config = config
|
||||
return ctx
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def check_transformers_version():
|
||||
model_info = HF_EXAMPLE_MODELS.get_hf_info("MusicFlamingoForConditionalGeneration")
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
|
||||
def test_musicflamingo_chunk_counting_uses_rote_timestamps(mock_ctx, monkeypatch):
|
||||
from vllm.model_executor.models.musicflamingo import (
|
||||
MusicFlamingoDummyInputsBuilder,
|
||||
MusicFlamingoMultiModalProcessor,
|
||||
MusicFlamingoProcessingInfo,
|
||||
)
|
||||
|
||||
info = MusicFlamingoProcessingInfo(mock_ctx)
|
||||
processor = MusicFlamingoMultiModalProcessor(
|
||||
info, MusicFlamingoDummyInputsBuilder(info)
|
||||
)
|
||||
|
||||
sr = 16000
|
||||
audio_1 = np.zeros(30 * sr)
|
||||
audio_2 = np.zeros(45 * sr)
|
||||
|
||||
mm_data = {"audio": [audio_1, audio_2]}
|
||||
prompt = "<|user|>Listen.<|end|>"
|
||||
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
|
||||
def mock_base_call(self, prompt, mm_data, mm_kwargs, tok_kwargs):
|
||||
del self, prompt, mm_data, mm_kwargs, tok_kwargs
|
||||
return {
|
||||
"input_ids": [1, 2, 3],
|
||||
"input_features": torch.randn(3, 80, 3000),
|
||||
"rote_timestamps": torch.randn(3, 750),
|
||||
}
|
||||
|
||||
monkeypatch.setattr(BaseMultiModalProcessor, "_call_hf_processor", mock_base_call)
|
||||
|
||||
processed = processor._call_hf_processor(prompt, mm_data, {}, {})
|
||||
|
||||
chunk_counts = processed["chunk_counts"]
|
||||
|
||||
assert chunk_counts.tolist() == [1, 2]
|
||||
assert "rote_timestamps" in processed
|
||||
|
||||
|
||||
def test_musicflamingo_dummy_text_uses_plain_audio_tokens(mock_ctx):
|
||||
from vllm.model_executor.models.musicflamingo import (
|
||||
MusicFlamingoDummyInputsBuilder,
|
||||
MusicFlamingoProcessingInfo,
|
||||
)
|
||||
|
||||
info = MusicFlamingoProcessingInfo(mock_ctx)
|
||||
builder = MusicFlamingoDummyInputsBuilder(info)
|
||||
|
||||
assert builder.get_dummy_text({"audio": 2}) == "<sound><sound>"
|
||||
|
||||
|
||||
def test_musicflamingo_audio_feature_pipeline_matches_hf_small_config():
|
||||
from transformers.models.musicflamingo import (
|
||||
modeling_musicflamingo as hf_musicflamingo_modeling,
|
||||
)
|
||||
from transformers.models.musicflamingo.configuration_musicflamingo import (
|
||||
MusicFlamingoConfig,
|
||||
)
|
||||
|
||||
from vllm.model_executor.models.audioflamingo3 import (
|
||||
_build_audio_encoder_attention_mask,
|
||||
_flatten_valid_audio_embeddings,
|
||||
)
|
||||
from vllm.model_executor.models.musicflamingo import (
|
||||
MusicFlamingoEncoder,
|
||||
MusicFlamingoMultiModalProjector,
|
||||
MusicFlamingoRotaryEmbedding,
|
||||
apply_rotary_time_emb,
|
||||
)
|
||||
|
||||
text_config = {
|
||||
"model_type": "qwen2",
|
||||
"intermediate_size": 64,
|
||||
"initializer_range": 0.02,
|
||||
"hidden_size": 32,
|
||||
"max_position_embeddings": 1024,
|
||||
"num_hidden_layers": 2,
|
||||
"num_attention_heads": 4,
|
||||
"num_key_value_heads": 2,
|
||||
"vocab_size": 128,
|
||||
"pad_token_id": 1,
|
||||
"use_mrope": False,
|
||||
}
|
||||
audio_config = {
|
||||
"hidden_size": 16,
|
||||
"num_attention_heads": 4,
|
||||
"intermediate_size": 32,
|
||||
"num_hidden_layers": 2,
|
||||
"num_mel_bins": 80,
|
||||
"max_source_positions": 1500,
|
||||
"dropout": 0.0,
|
||||
"attention_dropout": 0.0,
|
||||
"activation_dropout": 0.0,
|
||||
"encoder_layerdrop": 0.0,
|
||||
}
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = MusicFlamingoConfig(
|
||||
text_config=text_config,
|
||||
audio_config=audio_config,
|
||||
audio_token_id=0,
|
||||
head_dim=8,
|
||||
rope_parameters={"rope_type": "default", "rope_theta": 2048},
|
||||
)
|
||||
hf_model = hf_musicflamingo_modeling.MusicFlamingoForConditionalGeneration(
|
||||
config
|
||||
).eval()
|
||||
|
||||
vllm_encoder = MusicFlamingoEncoder(config.audio_config).eval()
|
||||
vllm_encoder.load_state_dict(hf_model.audio_tower.state_dict())
|
||||
|
||||
vllm_projector = MusicFlamingoMultiModalProjector(config).eval()
|
||||
vllm_projector.load_state_dict(hf_model.multi_modal_projector.state_dict())
|
||||
|
||||
vllm_rope = MusicFlamingoRotaryEmbedding(config).eval()
|
||||
vllm_rope.load_state_dict(hf_model.pos_emb.state_dict(), strict=False)
|
||||
|
||||
input_features = torch.randn(3, 80, 3000)
|
||||
feature_attention_mask = torch.zeros(3, 3000, dtype=torch.bool)
|
||||
feature_attention_mask[0, :3000] = True
|
||||
feature_attention_mask[1, :2500] = True
|
||||
feature_attention_mask[2, :1500] = True
|
||||
rote_timestamps = (
|
||||
torch.arange(750, dtype=torch.float32).unsqueeze(0).repeat(3, 1) * 0.04
|
||||
)
|
||||
|
||||
hf_output = hf_model.get_audio_features(
|
||||
input_features,
|
||||
feature_attention_mask,
|
||||
rote_timestamps=rote_timestamps,
|
||||
return_dict=True,
|
||||
).pooler_output
|
||||
vllm_attention_mask = _build_audio_encoder_attention_mask(
|
||||
feature_attention_mask,
|
||||
dtype=vllm_encoder.conv1.weight.dtype,
|
||||
device=vllm_encoder.conv1.weight.device,
|
||||
)
|
||||
vllm_hidden_states = vllm_encoder(
|
||||
input_features,
|
||||
attention_mask=vllm_attention_mask,
|
||||
)
|
||||
cos, sin = vllm_rope(rote_timestamps, seq_len=vllm_hidden_states.shape[-2])
|
||||
vllm_hidden_states = apply_rotary_time_emb(vllm_hidden_states, cos, sin)
|
||||
vllm_output, _ = _flatten_valid_audio_embeddings(
|
||||
vllm_projector(vllm_hidden_states),
|
||||
feature_attention_mask,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(vllm_output, hf_output)
|
||||
@@ -752,7 +752,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"nvidia/audio-flamingo-3-hf", min_transformers_version="5.0.0"
|
||||
),
|
||||
"MusicFlamingoForConditionalGeneration": _HfExamplesInfo(
|
||||
"nvidia/music-flamingo-2601-hf", min_transformers_version="5.0.0.dev"
|
||||
"nvidia/music-flamingo-2601-hf", min_transformers_version="5.3.0"
|
||||
),
|
||||
"AyaVisionForConditionalGeneration": _HfExamplesInfo("CohereLabs/aya-vision-8b"),
|
||||
"BagelForConditionalGeneration": _HfExamplesInfo("ByteDance-Seed/BAGEL-7B-MoT"),
|
||||
|
||||
Reference in New Issue
Block a user