# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The vLLM team.
# Copyright 2025 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

import numpy as np
import pytest
import torch
from transformers import PretrainedConfig

from tests.models.registry import HF_EXAMPLE_MODELS


class MockAudioFlamingo3Config(PretrainedConfig):
    model_type = "audioflamingo3"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.audio_config = PretrainedConfig()
        self.text_config = PretrainedConfig()


class MockAudioFlamingo3Processor:
    def __init__(self):
        # AF3's audio placeholder token.
        self.audio_token = "<sound>"
        self.audio_token_id = 12345
        self.max_audio_len = 60
        self.feature_extractor = MockFeatureExtractor()

    def __call__(self, text=None, audios=None, **kwargs):
        return {
            "input_ids": [1, 2, 3],
            "input_features": [np.zeros((3000, 80))],
        }


class MockFeatureExtractor:
    def __init__(self):
        self.sampling_rate = 16000
        self.chunk_length = 30


@pytest.fixture
def mock_ctx():
    config = MockAudioFlamingo3Config()
    ctx = MagicMock()
    ctx.get_hf_config.return_value = config
    ctx.get_hf_processor.return_value = MockAudioFlamingo3Processor()
    ctx.model_config.hf_config = config
    return ctx


@pytest.fixture(autouse=True)
def check_transformers_version():
    model_info = HF_EXAMPLE_MODELS.get_hf_info(
        "AudioFlamingo3ForConditionalGeneration"
    )
    model_info.check_transformers_version(on_fail="skip")


def test_audio_chunk_counting(mock_ctx):
    from vllm.model_executor.models.audioflamingo3 import (
        AudioFlamingo3DummyInputsBuilder,
        AudioFlamingo3MultiModalProcessor,
        AudioFlamingo3ProcessingInfo,
    )
    from vllm.multimodal.processing import BaseMultiModalProcessor

    info = AudioFlamingo3ProcessingInfo(mock_ctx)
    processor = AudioFlamingo3MultiModalProcessor(
        info, AudioFlamingo3DummyInputsBuilder(info)
    )

    sr = 16000
    # 30 s fits in a single 30 s chunk; 75 s is capped at max_audio_len
    # (60 s in the mock) and split into two chunks.
    audio_1 = np.zeros(30 * sr)
    audio_2 = np.zeros(75 * sr)
    mm_data = {"audio": [audio_1, audio_2]}
    prompt = "<|user|>Listen.<|end|>"

    def mock_base_call(self, prompt, mm_data, mm_kwargs, tok_kwargs):
        return {
            "input_ids": [1, 2, 3],
            "input_features": torch.randn(1, 80, 3000),
        }

    with pytest.MonkeyPatch.context() as mp:
        mp.setattr(BaseMultiModalProcessor, "_call_hf_processor", mock_base_call)
        processed = processor._call_hf_processor(prompt, mm_data, {}, {})

    chunk_counts = processed["chunk_counts"]
    assert chunk_counts[0].item() == 1
    assert chunk_counts[1].item() == 2
    assert len(chunk_counts) == 2


def test_dummy_data_generation(mock_ctx):
    from vllm.model_executor.models.audioflamingo3 import (
        AudioFlamingo3DummyInputsBuilder,
        AudioFlamingo3ProcessingInfo,
    )

    info = AudioFlamingo3ProcessingInfo(mock_ctx)
    builder = AudioFlamingo3DummyInputsBuilder(info)

    mm_counts = {"audio": 2}
    dummy_data = builder.get_dummy_mm_data(100, mm_counts, {})

    assert "audio" in dummy_data
    assert len(dummy_data["audio"]) == 2

    # Dummy audio spans the full max_audio_len at the mock sampling rate.
    expected_len = 60 * 16000
    assert len(dummy_data["audio"][0]) == expected_len
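

# A minimal sketch of the chunking arithmetic that test_audio_chunk_counting
# above relies on: each audio is capped at the processor's max_audio_len (60 s
# in the mock) and split into feature-extractor chunk_length windows (30 s),
# rounding up. _expected_chunk_count is a hypothetical helper local to this
# file, not part of vLLM; it only cross-checks the asserted counts.
def _expected_chunk_count(
    duration_s: int, max_audio_len: int = 60, chunk_length: int = 30
) -> int:
    capped = min(duration_s, max_audio_len)
    return -(-capped // chunk_length)  # ceil division


def test_expected_chunk_count_reference_math():
    # Mirrors the 30 s and 75 s inputs used in test_audio_chunk_counting.
    assert _expected_chunk_count(30) == 1
    assert _expected_chunk_count(75) == 2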


def test_audio_token_count_matches_hf_processor_math():
    from vllm.model_executor.models.audioflamingo3 import (
        _count_audio_tokens_from_mask,
    )

    # Three feature chunks: the first two (2999 valid frames each) belong to
    # audio 0, the third (1500 valid frames) to audio 1.
    feature_attention_mask = torch.zeros((3, 3000), dtype=torch.long)
    feature_attention_mask[0, :2999] = 1
    feature_attention_mask[1, :2999] = 1
    feature_attention_mask[2, :1500] = 1
    chunk_counts = torch.tensor([2, 1], dtype=torch.long)

    # Valid frames are summed per audio, then reduced by two stride-2 stages:
    # (2999 + 2999) -> 1499 tokens, 1500 -> 375 tokens.
    assert (
        _count_audio_tokens_from_mask(feature_attention_mask, chunk_counts, 0) == 1499
    )
    assert _count_audio_tokens_from_mask(feature_attention_mask, chunk_counts, 1) == 375


def test_audio_feature_pipeline_matches_hf_small_config():
    from transformers.models.audioflamingo3 import (
        modeling_audioflamingo3 as hf_audioflamingo3_modeling,
    )
    from transformers.models.audioflamingo3.configuration_audioflamingo3 import (
        AudioFlamingo3Config,
    )

    from vllm.model_executor.models.audioflamingo3 import (
        AudioFlamingo3Encoder,
        AudioFlamingo3MultiModalProjector,
        _build_audio_encoder_attention_mask,
        _flatten_valid_audio_embeddings,
    )

    text_config = {
        "model_type": "qwen2",
        "intermediate_size": 64,
        "initializer_range": 0.02,
        "hidden_size": 32,
        "max_position_embeddings": 1024,
        "num_hidden_layers": 2,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "vocab_size": 128,
        "pad_token_id": 1,
        "use_mrope": False,
    }
    audio_config = {
        "hidden_size": 16,
        "num_attention_heads": 4,
        "intermediate_size": 32,
        "num_hidden_layers": 2,
        "num_mel_bins": 80,
        "max_source_positions": 1500,
        "dropout": 0.0,
        "attention_dropout": 0.0,
        "activation_dropout": 0.0,
        "encoder_layerdrop": 0.0,
    }

    torch.manual_seed(0)
    config = AudioFlamingo3Config(
        text_config=text_config,
        audio_config=audio_config,
        audio_token_id=0,
    )
    hf_model = hf_audioflamingo3_modeling.AudioFlamingo3ForConditionalGeneration(
        config
    ).eval()

    # Copy the HF weights into the vLLM modules so outputs are comparable.
    vllm_encoder = AudioFlamingo3Encoder(config.audio_config).eval()
    vllm_encoder.load_state_dict(hf_model.audio_tower.state_dict())
    vllm_projector = AudioFlamingo3MultiModalProjector(config).eval()
    vllm_projector.load_state_dict(hf_model.multi_modal_projector.state_dict())

    input_features = torch.randn(3, 80, 3000)
    feature_attention_mask = torch.zeros(3, 3000, dtype=torch.bool)
    feature_attention_mask[0, :3000] = True
    feature_attention_mask[1, :2500] = True
    feature_attention_mask[2, :1500] = True

    hf_output = hf_model.get_audio_features(
        input_features,
        feature_attention_mask,
        return_dict=True,
    ).pooler_output

    vllm_attention_mask = _build_audio_encoder_attention_mask(
        feature_attention_mask,
        dtype=vllm_encoder.conv1.weight.dtype,
        device=vllm_encoder.conv1.weight.device,
    )
    vllm_hidden_states = vllm_encoder(
        input_features,
        attention_mask=vllm_attention_mask,
    )
    vllm_output, _ = _flatten_valid_audio_embeddings(
        vllm_projector(vllm_hidden_states),
        feature_attention_mask,
    )

    torch.testing.assert_close(vllm_output, hf_output)
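

# A minimal, self-contained sketch of the token-count arithmetic asserted in
# test_audio_token_count_matches_hf_processor_math above, assuming
# Whisper-style downsampling: valid mel frames are summed across an audio's
# chunks, reduced once by the stride-2 conv stack and once more by stride-2
# pooling. _expected_audio_token_count is a hypothetical helper local to this
# file, not part of vLLM; it only cross-checks the asserted constants.
def _expected_audio_token_count(valid_frame_lengths: list[int]) -> int:
    total = sum(valid_frame_lengths)
    conv_out = (total - 1) // 2 + 1  # stride-2 conv over mel frames
    return (conv_out - 2) // 2 + 1  # stride-2 pooling over conv outputs


def test_expected_audio_token_count_reference_math():
    # Mirrors the masks used in test_audio_token_count_matches_hf_processor_math.
    assert _expected_audio_token_count([2999, 2999]) == 1499
    assert _expected_audio_token_count([1500]) == 375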