# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on audio language models.

For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
2025-05-26 17:57:54 +01:00
2025-03-08 01:28:52 +08:00
import os
2025-07-15 16:35:30 +02:00
from typing import Any , NamedTuple
2025-03-08 01:28:52 +08:00
from huggingface_hub import snapshot_download
2024-08-21 15:49:39 -07:00
from transformers import AutoTokenizer
2025-03-17 18:00:17 +08:00
from vllm import LLM , EngineArgs , SamplingParams
2024-08-21 15:49:39 -07:00
from vllm . assets . audio import AudioAsset
2025-03-08 01:28:52 +08:00
from vllm . lora . request import LoRARequest
2025-10-26 16:33:32 +05:30
from vllm . utils . argparse_utils import FlexibleArgumentParser
2024-08-21 15:49:39 -07:00
2024-09-03 21:38:21 -07:00
# Bundled example clips, ordered so that slicing yields a stable selection.
audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]

# Question to ask for each possible number of audio clips per prompt.
question_per_audio_count = {
    0: "What is 1+1?",
    1: "What is recited in the audio?",
    2: "What sport and what nursery rhyme are referenced?",
}
2024-08-21 15:49:39 -07:00
2025-03-17 18:00:17 +08:00
class ModelRequestData(NamedTuple):
    """Everything needed to run one example: engine config plus either a
    text prompt or pre-tokenized prompt ids, with optional extras.

    Exactly one of `prompt` / `prompt_token_ids` is expected to be set by
    the `run_*` helpers; `main` prefers `prompt` when present.
    """

    # Engine configuration used to construct `vllm.LLM`.
    engine_args: EngineArgs
    # Text prompt, already in the model's expected chat format.
    prompt: str | None = None
    # Pre-tokenized prompt (used by Voxtral). Annotation fixed from the
    # incorrect `dict[str, list[int]]`: callers assign a flat token list
    # (`tokens.tokens`) and `main` passes it as `prompt_token_ids`.
    prompt_token_ids: list[int] | None = None
    # Extra multi-modal inputs keyed by modality, e.g. {"audio": [...]}.
    multi_modal_data: dict[str, Any] | None = None
    # Token ids that terminate generation, forwarded to SamplingParams.
    stop_token_ids: list[int] | None = None
    # LoRA adapters to apply at generation time.
    lora_requests: list[LoRARequest] | None = None
2024-12-19 14:14:17 +08:00
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
2024-08-21 15:49:39 -07:00
2025-12-14 05:14:55 -05:00
# AudioFlamingo3
def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and a ChatML-style prompt for AudioFlamingo3."""
    engine_args = EngineArgs(
        model="nvidia/audio-flamingo-3-hf",
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )

    # AudioFlamingo3 uses <sound> token for audio; repeat once per clip.
    sound_tokens = "<sound>" * audio_count

    prompt = (
        "<|im_start|>system\n"
        "You are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{sound_tokens}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(engine_args=engine_args, prompt=prompt)
2026-03-17 17:04:17 -04:00
# CohereASR
def run_cohere_asr(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and the fixed transcription prompt for CohereASR."""
    assert audio_count == 1, "CohereASR only support single audio input per prompt"

    # Control-token preamble selecting English, punctuation, no ITN,
    # no timestamps and no diarization.
    prompt = (
        "<|startofcontext|><|startoftranscript|>"
        "<|emo:undefined|><|en|><|en|><|pnc|><|noitn|>"
        "<|notimestamp|><|nodiarize|>"
    )
    engine_args = EngineArgs(
        model="CohereLabs/cohere-transcribe-03-2026",
        limit_mm_per_prompt={"audio": audio_count},
        trust_remote_code=True,
    )
    return ModelRequestData(engine_args=engine_args, prompt=prompt)
2026-01-30 11:01:29 +08:00
# MusicFlamingo
def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and a ChatML prompt for MusicFlamingo."""
    engine_args = EngineArgs(
        model="nvidia/music-flamingo-2601-hf",
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )

    # MusicFlamingo prompt placeholders use <sound>; vLLM's MusicFlamingo
    # multimodal processor expands each one into <|sound_bos|> + audio tokens +
    # <|sound_eos|> based on extracted audio feature lengths.
    sound_tokens = "<sound>" * audio_count

    system_prompt = (
        "You are Music Flamingo, a multimodal assistant for language and music. "
        "On each turn you receive an audio clip which contains music and optional "
        "text, you will receive at least one or both; use your world knowledge and "
        "reasoning to help the user with any task. Interpret the entirety of the "
        "content any input music--regardlenss of whether the user calls it audio, "
        "music, or sound."
    )

    prompt = (
        "<|im_start|>system\n"
        f"{system_prompt}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{sound_tokens}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(engine_args=engine_args, prompt=prompt)
2025-08-09 18:56:25 +02:00
# Gemma3N
def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and prompt for Gemma-3n audio input.

    Fixes a bug where the "<end_of_turn>\\n<start_of_turn>model\\n" tail was
    a dangling expression statement on its own line instead of part of the
    `prompt` assignment, so the prompt was silently truncated right after
    the question and never opened the model turn.
    """
    model_name = "google/gemma-3n-E2B-it"
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=2048,
        max_num_batched_tokens=2048,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )
    # Parenthesized so both literals concatenate into one prompt string.
    prompt = (
        f"<start_of_turn>user\n<audio_soft_token>{question}"
        "<end_of_turn>\n<start_of_turn>model\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )
2026-01-18 19:17:59 +08:00
# GLM-ASR
def run_glmasr(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and a chat-template prompt for GLM-ASR."""
    model_name = "zai-org/GLM-ASR-Nano-2512"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # GLM-ASR uses <|pad|> token for audio
    placeholders = "<|pad|>" * audio_count
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": f"{placeholders}{question}"}],
        tokenize=False,
        add_generation_prompt=True,
    )

    return ModelRequestData(
        engine_args=EngineArgs(
            model=model_name,
            trust_remote_code=True,
            max_model_len=4096,
            max_num_seqs=2,
            limit_mm_per_prompt={"audio": audio_count},
        ),
        prompt=prompt,
    )
2026-01-28 13:18:09 +08:00
# FunAudioChat
def run_funaudiochat(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and prompt for FunAudioChat.

    NOTE: FunAudioChat is not available on the HuggingFace Hub at the time
    of writing. Pass a local model path via `--model` (enforced in `main`).
    """
    engine_args = EngineArgs(
        model="funaudiochat",  # placeholder; overridden by --model in main()
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        enforce_eager=True,
    )
    # One <|audio_bos|><|AUDIO|><|audio_eos|> marker line per audio clip.
    audio_markers = "".join(
        "<|audio_bos|><|AUDIO|><|audio_eos|>\n" for _ in range(audio_count)
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=f"{audio_markers}{question}",
    )
2025-04-28 04:05:00 -06:00
# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args, prompt and LoRA request for Granite Speech.

    NOTE - the setting in this example are somewhat different from what is
    optimal for granite speech, and it is generally recommended to use beam
    search. Check the model README for suggested settings.
    https://huggingface.co/ibm-granite/granite-speech-3.3-8b
    """
    model_name = "ibm-granite/granite-speech-3.3-8b"
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=2048,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=64,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # The model has an audio-specific lora directly in its model dir;
    # it should be enabled whenever you pass audio inputs to the model.
    speech_lora_path = model_name

    audio_placeholder = "<|audio|>" * audio_count
    prompt = (
        "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\n"
        "Today's Date: December 19, 2024.\n"
        "You are Granite, developed by IBM. You are a helpful AI assistant"
        "<|end_of_text|>\n"
        f"<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}"
        "<|end_of_text|>\n"
        "<|start_of_role|>assistant<|end_of_role|>"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
    )
2026-03-11 12:24:48 +08:00
# Kimi-Audio-7B-Instruct
def run_kimi_audio(question: str, audio_count: int) -> ModelRequestData:
    """Kimi-Audio-7B-Instruct for audio transcription and understanding."""
    engine_args = EngineArgs(
        model="moonshotai/Kimi-Audio-7B-Instruct",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )
    # Kimi-Audio uses <|im_kimia_text_blank|> as placeholder for audio features
    placeholders = "<|im_kimia_text_blank|>" * audio_count
    # Default prompt for transcription
    question = question or "Please transcribe the audio"
    # Stop at EOS token (151644) to prevent repetition
    return ModelRequestData(
        engine_args=engine_args,
        prompt=f"{placeholders}{question}",
        stop_token_ids=[151644],
    )
2025-09-04 15:08:09 +08:00
# MiDashengLM
def run_midashenglm(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and a ChatML prompt for MiDashengLM.

    Adds the previously missing `-> ModelRequestData` return annotation for
    consistency with the other `run_*` helpers (behavior unchanged).
    """
    model_name = "mispeech/midashenglm-7b"
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )
    # One <|audio_bos|><|AUDIO|><|audio_eos|> marker per audio clip.
    audio_in_prompt = "".join(
        "<|audio_bos|><|AUDIO|><|audio_eos|>" for _ in range(audio_count)
    )
    default_system = "You are a helpful language and speech assistant."
    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )
2025-03-01 14:49:15 +08:00
# MiniCPM-O
def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and a templated prompt for MiniCPM-o."""
    model_name = "openbmb/MiniCPM-o-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    engine_args = EngineArgs(
        model=model_name,
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Resolve the textual stop tokens to ids for SamplingParams.
    stop_token_ids = [
        tokenizer.convert_tokens_to_ids(tok) for tok in ("<|im_end|>", "<|endoftext|>")
    ]

    audio_placeholder = "(<audio>./</audio>)" * audio_count
    # Custom template appends the speaker/TTS control tokens to the
    # generation prompt.
    audio_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{'<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>'}}{% endif %}"  # noqa: E501
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": f"{audio_placeholder}\n{question}"}],
        tokenize=False,
        add_generation_prompt=True,
        chat_template=audio_chat_template,
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        stop_token_ids=stop_token_ids,
    )
2024-08-21 15:49:39 -07:00
2025-03-08 01:28:52 +08:00
# Phi-4-multimodal-instruct
def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process audio inputs.
    """
    model_path = snapshot_download("microsoft/Phi-4-multimodal-instruct")
    # Since the vision-lora and speech-lora co-exist with the base model,
    # we have to manually specify the path of the lora weights.
    speech_lora_path = os.path.join(model_path, "speech-lora")

    # Audio placeholders are 1-indexed: <|audio_1|>, <|audio_2|>, ...
    placeholders = "".join(f"<|audio_{i + 1}|>" for i in range(audio_count))
    prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"

    engine_args = EngineArgs(
        model=model_path,
        trust_remote_code=True,
        max_model_len=12800,
        max_num_seqs=2,
        enable_lora=True,
        max_lora_rank=320,
        limit_mm_per_prompt={"audio": audio_count},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
    )
2025-03-08 01:28:52 +08:00
2024-10-24 01:54:22 +08:00
# Qwen2-Audio
def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and a ChatML prompt for Qwen2-Audio."""
    engine_args = EngineArgs(
        model="Qwen/Qwen2-Audio-7B-Instruct",
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # Label each clip ("Audio 1:", "Audio 2:", ...) before its markers.
    audio_in_prompt = "".join(
        f"Audio {idx + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
        for idx in range(audio_count)
    )

    prompt = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )
2024-10-24 01:54:22 +08:00
2025-04-19 14:14:36 +08:00
# Qwen2.5-Omni
def run_qwen2_5_omni(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and a ChatML prompt for Qwen2.5-Omni.

    Adds the previously missing `-> ModelRequestData` return annotation for
    consistency with the other `run_*` helpers, and renames the unused loop
    variable to `_` (behavior unchanged).
    """
    model_name = "Qwen/Qwen2.5-Omni-7B"
    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )

    # One <|audio_bos|><|AUDIO|><|audio_eos|> marker line per audio clip.
    audio_in_prompt = "".join(
        "<|audio_bos|><|AUDIO|><|audio_eos|>\n" for _ in range(audio_count)
    )

    default_system = (
        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
        "Group, capable of perceiving auditory and visual inputs, as well as "
        "generating text and speech."
    )

    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n"
        f"{audio_in_prompt}{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
    )
2026-01-29 03:27:15 -08:00
# Qwen3-ASR
def run_qwen3_asr(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and a ChatML prompt for Qwen3-ASR transcription.

    Note: the question text is intentionally unused — the model transcribes
    whatever audio is supplied.
    """
    audio_in_prompt = "<|audio_start|><|audio_pad|><|audio_end|>\n" * audio_count
    prompt = f"<|im_start|>user\n{audio_in_prompt}<|im_end|>\n<|im_start|>assistant\n"
    engine_args = EngineArgs(
        model="Qwen/Qwen3-Asr-1.7B",
        max_model_len=4096,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )
    return ModelRequestData(engine_args=engine_args, prompt=prompt)
2025-03-01 14:49:15 +08:00
# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and a chat-template prompt for Ultravox."""
    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # One <|audio|> line per clip, followed by the question text.
    content = "<|audio|>\n" * audio_count + question
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": content}],
        tokenize=False,
        add_generation_prompt=True,
    )

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=4096,
        max_num_seqs=5,
        trust_remote_code=True,
        limit_mm_per_prompt={"audio": audio_count},
    )
    return ModelRequestData(engine_args=engine_args, prompt=prompt)
2025-03-01 14:49:15 +08:00
2025-12-14 05:14:55 -05:00
# Voxtral
# Make sure to install mistral-common[audio].
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
    """Tokenize a Voxtral chat request with mistral-common and return the
    prompt token ids plus raw audio arrays as multi-modal data."""
    from mistral_common.audio import Audio
    from mistral_common.protocol.instruct.chunk import (
        AudioChunk,
        RawAudio,
        TextChunk,
    )
    from mistral_common.protocol.instruct.messages import (
        UserMessage,
    )
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    model_name = "mistralai/Voxtral-Mini-3B-2507"
    tokenizer = MistralTokenizer.from_hf_hub(model_name)

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        enforce_eager=True,
        enable_chunked_prefill=False,
    )

    # Build a single user message: all audio chunks first, then the text.
    clips = [
        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
        for i in range(audio_count)
    ]
    chunks = [AudioChunk(input_audio=RawAudio.from_audio(clip)) for clip in clips]
    request = ChatCompletionRequest(
        messages=[UserMessage(content=[*chunks, TextChunk(text=question)])],
        model=model_name,
    )

    # mistral-common returns both the token ids and the decoded audio.
    tokenized = tokenizer.encode_chat_completion(request)
    audios_and_sr = [(a.audio_array, a.sampling_rate) for a in tokenized.audios]

    return ModelRequestData(
        engine_args=engine_args,
        prompt_token_ids=tokenized.tokens,
        multi_modal_data={"audio": audios_and_sr},
    )
2025-03-01 14:49:15 +08:00
# Whisper
def run_whisper(question: str, audio_count: int) -> ModelRequestData:
    """Build engine args and the transcription prompt for Whisper.

    Note: the question text is unused — Whisper only transcribes.
    """
    assert audio_count == 1, "Whisper only support single audio input per prompt"

    engine_args = EngineArgs(
        model="openai/whisper-large-v3-turbo",
        max_model_len=448,
        max_num_seqs=5,
        limit_mm_per_prompt={"audio": audio_count},
    )
    return ModelRequestData(
        engine_args=engine_args,
        prompt="<|startoftranscript|>",
    )
2025-01-29 17:24:59 +08:00
# Registry mapping the --model-type CLI choice to its setup helper.
model_example_map = {
    "audioflamingo3": run_audioflamingo3,
    "cohere_asr": run_cohere_asr,
    "funaudiochat": run_funaudiochat,
    "gemma3n": run_gemma3n,
    "glmasr": run_glmasr,
    "granite_speech": run_granite_speech,
    "kimi_audio": run_kimi_audio,
    "midashenglm": run_midashenglm,
    "minicpmo": run_minicpmo,
    "musicflamingo": run_musicflamingo,
    "phi4_mm": run_phi4mm,
    "qwen2_audio": run_qwen2_audio,
    "qwen2_5_omni": run_qwen2_5_omni,
    "qwen3_asr": run_qwen3_asr,
    "ultravox": run_ultravox,
    "voxtral": run_voxtral,
    "whisper": run_whisper,
}
2024-08-21 15:49:39 -07:00
2025-04-15 16:05:30 +08:00
def parse_args():
    """Parse CLI arguments for the audio-language offline-inference demo."""
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "audio language models"
    )
    parser.add_argument(
        "--model-type",
        "-m",
        type=str,
        default="ultravox",
        choices=model_example_map.keys(),
        help='Huggingface "model_type".',
    )
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        help="Model ID or local path override. Required for funaudiochat.",
    )
    parser.add_argument(
        "--num-prompts", type=int, default=1, help="Number of prompts to run."
    )
    parser.add_argument(
        "--num-audios",
        type=int,
        default=1,
        choices=[0, 1, 2],
        help="Number of audio items per prompt.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Set the seed when initializing `vllm.LLM`.",
    )
    parser.add_argument(
        "--tensor-parallel-size",
        "-tp",
        type=int,
        default=None,
        help="Tensor parallel size to override the model's default setting.",
    )
    return parser.parse_args()
2024-08-21 15:49:39 -07:00
def main(args):
    """Run offline batch inference for the selected audio language model."""
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")
    if model == "funaudiochat" and not args.model:
        raise ValueError("--model is required when --model-type=funaudiochat")
    if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1:
        raise ValueError(
            f"tensor_parallel_size must be a positive integer, "
            f"got {args.tensor_parallel_size}"
        )

    audio_count = args.num_audios
    req_data = model_example_map[model](
        question_per_audio_count[audio_count], audio_count
    )
    if model == "funaudiochat":
        # FunAudioChat ships no Hub checkpoint; point at the local path.
        req_data.engine_args.model = args.model

    # Disable other modalities to save memory
    default_limits = {"image": 0, "video": 0, "audio": 0}
    req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
        req_data.engine_args.limit_mm_per_prompt or {}
    )

    engine_args = vars(req_data.engine_args) | {"seed": args.seed}
    if args.tensor_parallel_size is not None:
        engine_args["tensor_parallel_size"] = args.tensor_parallel_size

    llm = LLM(**engine_args)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(
        temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
    )

    def get_input(start, end):
        # Prefer pre-built multi-modal data (e.g. Voxtral); otherwise load
        # the requested slice of the bundled audio assets.
        mm_data = req_data.multi_modal_data
        if not mm_data:
            mm_data = {}
            if end - start > 0:
                mm_data = {
                    "audio": [
                        asset.audio_and_sample_rate
                        for asset in audio_assets[start:end]
                    ]
                }
        inputs = {"multi_modal_data": mm_data}
        if req_data.prompt:
            inputs["prompt"] = req_data.prompt
        else:
            inputs["prompt_token_ids"] = req_data.prompt_token_ids
        return inputs

    # Batch inference
    assert args.num_prompts > 0
    if audio_count != 1:
        inputs = [get_input(0, audio_count)] * args.num_prompts
    else:
        # For single audio input, we need to vary the audio input
        # to avoid deduplication in vLLM engine.
        inputs = [
            get_input(i % len(audio_assets), i % len(audio_assets) + 1)
            for i in range(args.num_prompts)
        ]

    # Add LoRA request if applicable
    lora_request = (
        req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
    )

    outputs = llm.generate(
        inputs,
        sampling_params=sampling_params,
        lora_request=lora_request,
    )

    for o in outputs:
        print(o.outputs[0].text)
if __name__ == "__main__":
    main(parse_args())