[Multimodal][Qwen3 Omni] Make Qwen3 Omni work with audio-in-video inputs in V1 engine. (#27721)

Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
Chenheli Hua
2025-11-24 11:24:37 -08:00
committed by GitHub
parent 8f066146c3
commit 839c6b7b72
4 changed files with 467 additions and 59 deletions


@@ -0,0 +1,170 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen3-Omni (thinker only).
"""
from typing import NamedTuple
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.multimodal.image import convert_image_mode
from vllm.utils.argparse_utils import FlexibleArgumentParser
class QueryResult(NamedTuple):
inputs: dict
limit_mm_per_prompt: dict[str, int]
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech."
)
def get_mixed_modalities_query() -> QueryResult:
question = (
"What is recited in the audio? "
"What is the content of this image? Why is this video funny?"
)
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
"<|vision_start|><|image_pad|><|vision_end|>"
"<|vision_start|><|video_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
"image": convert_image_mode(
ImageAsset("cherry_blossom").pil_image, "RGB"
),
"video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
},
},
limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
)
def get_use_audio_in_video_query() -> QueryResult:
    question = (
        "Describe the content of the video in detail, then convert what the "
        "baby says into text."
    )
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
asset = VideoAsset(name="baby_reading", num_frames=16)
audio = asset.get_audio(sampling_rate=16000)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"video": asset.np_ndarrays,
"audio": audio,
},
"mm_processor_kwargs": {
"use_audio_in_video": True,
},
},
limit_mm_per_prompt={"audio": 1, "video": 1},
)
def get_multi_audios_query() -> QueryResult:
question = "Are these two audio clips the same?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
"<|audio_start|><|audio_pad|><|audio_end|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": [
AudioAsset("winning_call").audio_and_sample_rate,
AudioAsset("mary_had_lamb").audio_and_sample_rate,
],
},
},
limit_mm_per_prompt={
"audio": 2,
},
)
query_map = {
"mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query,
"multi_audios": get_multi_audios_query,
}
def main(args):
model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
query_result = query_map[args.query_type]()
llm = LLM(
model=model_name,
max_model_len=12800,
max_num_seqs=5,
limit_mm_per_prompt=query_result.limit_mm_per_prompt,
seed=args.seed,
)
    # We set temperature to 0.2 so that outputs can differ even when all
    # prompts are identical during batch inference.
sampling_params = SamplingParams(temperature=0.2, max_tokens=256)
outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
def parse_args():
parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "omni-modal language models"
)
parser.add_argument(
"--query-type",
"-q",
type=str,
default="mixed_modalities",
choices=query_map.keys(),
help="Query type.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
main(args)
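# Example invocations (the script filename below is illustrative; use the actual
# path where this example lives in the repo):
#   python qwen3_omni_audio_in_video_example.py --query-type use_audio_in_video
#   python qwen3_omni_audio_in_video_example.py -q mixed_modalities --seed 0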


@@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import Mock
import pytest
from transformers import PretrainedConfig
from vllm.multimodal.processing import InputProcessingContext
# Helper function to print input IDs with coalesced audio/video tokens.
def print_input_ids(input_ids):
"""
Print input IDs, compressing consecutive special tokens.
- 151675: <|audio_pad|>
- 151656: <|video_pad|>
"""
if not input_ids:
print("[]")
return
result = []
i = 0
while i < len(input_ids):
current_id = input_ids[i]
# Check if it's a special token that should be compressed
if current_id in [151675, 151656]:
# Count consecutive occurrences
count = 1
while i + count < len(input_ids) and input_ids[i + count] == current_id:
count += 1
# Add compressed representation
token_name = "<|audio_pad|>" if current_id == 151675 else "<|video_pad|>"
result.append(f"{token_name} * {count}")
i += count
else:
# Regular token, just add it
result.append(str(current_id))
i += 1
print(", ".join(result))
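# Example (token IDs are illustrative):
#   print_input_ids([151669, 151656, 151656, 151675, 151670])
# prints: 151669, <|video_pad|> * 2, <|audio_pad|> * 1, 151670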
@pytest.fixture
def mock_qwen3_omni_config():
"""Create a mock Qwen3OmniMoeThinker config."""
config = Mock(spec=PretrainedConfig)
# Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
config.audio_token_id = 151675 # <|audio_pad|>
config.video_token_id = 151656 # <|video_pad|>
config.image_token_id = 151655 # <|image_pad|>
config.audio_start_token_id = 151669 # <|audio_start|>
config.audio_end_token_id = 151670 # <|audio_end|>
config.vision_start_token_id = 151652 # <|vision_start|>
config.position_id_per_seconds = 12.5
# Vision config
vision_config = Mock()
vision_config.spatial_merge_size = 2
config.vision_config = vision_config
return config
@pytest.fixture
def mock_processor():
"""Create a mock HF processor."""
from transformers.models.whisper import WhisperFeatureExtractor
processor = Mock()
processor.audio_token = "<|audio_pad|>"
processor.image_token = "<|image_pad|>"
processor.video_token = "<|video_pad|>"
# Create a real WhisperFeatureExtractor instance for the feature_extractor attribute
feature_extractor = WhisperFeatureExtractor()
processor.feature_extractor = feature_extractor
return processor
@pytest.fixture
def mock_tokenizer():
"""Create a mock tokenizer."""
tokenizer = Mock()
# Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
tokenizer.get_vocab = Mock(
return_value={
"<|audio_pad|>": 151675,
"<|video_pad|>": 151656,
"<|image_pad|>": 151655,
"<|audio_start|>": 151669,
"<|audio_end|>": 151670,
"<|vision_start|>": 151652,
"<|vision_end|>": 151653,
}
)
tokenizer.encode = Mock(
side_effect=lambda x: {
"<|vision_start|>": [151652],
"<|vision_end|>": [151653],
"<|audio_start|>": [151669],
"<|audio_end|>": [151670],
"<|audio_pad|>": [151675],
"<|image_pad|>": [151655],
"<|video_pad|>": [151656],
}.get(x, [0])
)
tokenizer.vision_bos_token = "<|vision_start|>"
tokenizer.vision_eos_token = "<|vision_end|>"
tokenizer.audio_bos_token = "<|audio_start|>"
tokenizer.audio_eos_token = "<|audio_end|>"
return tokenizer
@pytest.fixture
def mock_image_processor():
"""Create a mock image processor."""
image_processor = Mock()
image_processor.merge_size = 2
return image_processor
def test_qwen3_omni_get_updates_use_audio_in_video(
mock_qwen3_omni_config,
mock_processor,
mock_tokenizer,
mock_image_processor,
):
"""Test the get_updates_use_audio_in_video method directly."""
from vllm.model_executor.models.qwen3_omni_moe_thinker import (
Qwen3OmniMoeThinkerMultiModalProcessor,
Qwen3OmniMoeThinkerProcessingInfo,
)
# Create a mock context
mock_ctx = Mock(spec=InputProcessingContext)
# Create processing info
info = Qwen3OmniMoeThinkerProcessingInfo(mock_ctx)
info.get_hf_config = Mock(return_value=mock_qwen3_omni_config)
info.get_hf_processor = Mock(return_value=mock_processor)
info.get_tokenizer = Mock(return_value=mock_tokenizer)
info.get_image_processor = Mock(return_value=mock_image_processor)
# Create a mock dummy_inputs builder
mock_dummy_inputs = Mock()
# Create the processor
processor = Qwen3OmniMoeThinkerMultiModalProcessor(info, mock_dummy_inputs)
# Test parameters from reference video
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/draw.mp4
audio_len = 85
video_grid_thw = [6, 36, 64]
video_second_per_grid_t = 2.0
# Call the method
updates = processor.get_updates_use_audio_in_video(
thinker_config=mock_qwen3_omni_config,
audio_len=audio_len,
video_grid_thw=video_grid_thw,
video_second_per_grid_t=video_second_per_grid_t,
)
# Updated input ids should align with HF implementation.
# 151669,
# <|video_pad|> * 576, <|audio_pad|> * 25,
# <|video_pad|> * 576, <|audio_pad|> * 25,
# <|video_pad|> * 576, <|audio_pad|> * 25,
# <|video_pad|> * 576, <|audio_pad|> * 10,
# <|video_pad|> * 1152,
# 151670
print_input_ids(updates)
# Verify structure
assert isinstance(updates, list)
assert len(updates) > 0
# Verify start and end tokens
audio_start_token_id = mock_qwen3_omni_config.audio_start_token_id
audio_end_token_id = mock_qwen3_omni_config.audio_end_token_id
assert updates[0] == audio_start_token_id
assert updates[-1] == audio_end_token_id
# Verify both audio and video tokens are present
audio_token_id = mock_qwen3_omni_config.audio_token_id
video_token_id = mock_qwen3_omni_config.video_token_id
audio_count = updates.count(audio_token_id)
video_count = updates.count(video_token_id)
assert audio_count == audio_len, (
f"Expected {audio_len} audio tokens, got {audio_count}"
)
# Calculate expected video token count
spatial_merge_size = mock_qwen3_omni_config.vision_config.spatial_merge_size
height = video_grid_thw[1] // spatial_merge_size
width = video_grid_thw[2] // spatial_merge_size
expected_video_count = video_grid_thw[0] * height * width
assert video_count == expected_video_count, (
f"Expected {expected_video_count} video tokens, got {video_count}"
)
# Total tokens should be: 1 (start) + audio_len + video_count + 1 (end)
expected_total = 1 + audio_len + expected_video_count + 1
assert len(updates) == expected_total, (
f"Expected {expected_total} total tokens, got {len(updates)}"
)
if __name__ == "__main__":
pytest.main([__file__, "-v"])
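As a cross-check on the counts asserted above, here is a minimal sketch of the
arithmetic, assuming the same test parameters (audio_len=85,
video_grid_thw=[6, 36, 64], spatial_merge_size=2); the variable names below are
illustrative:

# Expected token counts for the audio-in-video placeholder sequence.
audio_len = 85
t, h, w = 6, 36, 64  # video_grid_thw
spatial_merge_size = 2

video_tokens = t * (h // spatial_merge_size) * (w // spatial_merge_size)
assert video_tokens == 6 * 18 * 32 == 3456

# 1 <|audio_start|> + interleaved audio/video pads + 1 <|audio_end|>
total = 1 + audio_len + video_tokens + 1
assert total == 3543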


vllm/model_executor/models/qwen2_5_omni_thinker.py

@@ -23,7 +23,6 @@
 """Inference-only Qwen2.5-Omni model (thinker part)."""
 from collections.abc import Callable, Iterable, Mapping, Sequence
-from copy import copy
 from functools import partial
 from typing import Annotated, Any, Literal
@@ -387,15 +386,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
         self._validate_mm_updates(mm_prompt_updates, mm_item_counts)

-        use_audio_in_video = False
-        if "video" in mm_kwargs:
-            video_items = [item for item in mm_kwargs["video"] if item is not None]
-            # only check video items (if there are any)
-            if video_items:
-                use_audio_in_video = all(
-                    item["use_audio_in_video"].data for item in video_items
-                )
         if is_update_applied:
             mm_placeholders = self._find_mm_placeholders(
                 prompt_ids,
@@ -404,7 +394,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
             self._validate_mm_placeholders(
                 mm_placeholders,
                 mm_item_counts,
-                use_audio_in_video=use_audio_in_video,
             )
         else:
             prompt_ids, mm_placeholders = self._apply_prompt_updates(
@@ -414,7 +403,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
             self._validate_mm_placeholders(
                 mm_placeholders,
                 mm_item_counts,
-                use_audio_in_video=use_audio_in_video,
             )

         return prompt_ids, mm_placeholders
@@ -640,19 +628,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         return mm_processed_data

-    def _validate_mm_placeholders(
-        self,
-        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
-        mm_item_counts: Mapping[str, int],
-        use_audio_in_video: bool = False,
-    ) -> None:
-        if use_audio_in_video:
-            mm_item_counts = copy(mm_item_counts)
-            if "video" in mm_item_counts:
-                assert "audio" in mm_item_counts
-                mm_item_counts["audio"] -= mm_item_counts["video"]
-        super()._validate_mm_placeholders(mm_placeholders, mm_item_counts)

 class Qwen2_5OmniConditionalGenerationMixin:
     def _parse_and_validate_audio_input(


vllm/model_executor/models/qwen3_omni_moe_thinker.py

@@ -68,11 +68,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
 from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems
 from vllm.multimodal.processing import (
-    BaseMultiModalProcessor,
     MultiModalPromptUpdates,
     PlaceholderFeaturesInfo,
     PromptReplacement,
     PromptUpdate,
+    PromptUpdateDetails,
 )
 from vllm.sequence import IntermediateTensors
@@ -87,7 +87,6 @@ from .qwen2_5_omni_thinker import (
     Qwen2_5OmniConditionalGenerationMixin,
     Qwen2_5OmniThinkerDummyInputsBuilder,
     Qwen2_5OmniThinkerMultiModalProcessor,
-    Qwen2_5OmniThinkerProcessingInfo,
 )
 from .qwen2_5_vl import (
     Qwen2_5_VisionAttention,
@@ -807,24 +806,8 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
         else:
             use_audio_in_video = False

-        if use_audio_in_video and "video" in mm_item_counts:
-            assert "audio" in mm_item_counts
-            mm_item_counts["audio"] -= mm_item_counts["video"]
-
-        # Special case with `use_audio_in_video=True`
-        if use_audio_in_video:
-            if is_update_applied:
-                prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video)
-            (
-                prompt_ids,
-                mm_placeholders,
-            ) = self._apply_prompt_updates(
-                prompt_ids,
-                mm_prompt_updates,
-            )
-            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
         # normal case with `use_audio_in_video=False`
-        elif is_update_applied:
+        if is_update_applied:
             mm_placeholders = self._find_mm_placeholders(
                 prompt_ids,
                 mm_prompt_updates,
@@ -834,10 +817,24 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
                 mm_item_counts,
             )
         else:
-            prompt_ids, mm_placeholders = self._apply_prompt_updates(
-                prompt_ids,
-                mm_prompt_updates,
-            )
+            if use_audio_in_video and "audio" in mm_prompt_updates:
+                filtered_updates = {
+                    k: v for k, v in mm_prompt_updates.items() if k != "audio"
+                }
+                prompt_ids, mm_placeholders = self._apply_prompt_updates(
+                    prompt_ids,
+                    filtered_updates,
+                )
+                # Derive audio placeholders from video placeholders
+                mm_placeholders = self._derive_audio_from_video_placeholders(
+                    mm_placeholders, mm_prompt_updates
+                )
+            else:
+                prompt_ids, mm_placeholders = self._apply_prompt_updates(
+                    prompt_ids,
+                    mm_prompt_updates,
+                )
         self._validate_mm_placeholders(
             mm_placeholders,
             mm_item_counts,
@@ -962,7 +959,9 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
         def get_replacement_qwen2_use_audio_in_video(item_idx: int):
             nonlocal audio_in_video_item_idx
-            audio_num_features = audio_output_lengths[audio_item_idx + item_idx]
+            audio_num_features = audio_output_lengths[
+                audio_in_video_item_idx + item_idx
+            ]
             video_grid_thw = out_mm_data["video_grid_thw"][item_idx]

             audio_in_video_item_idx += 1
@@ -971,14 +970,17 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
             if second_per_grid_ts:
                 video_second_per_grid_t = second_per_grid_ts[item_idx]
             else:
-                video_second_per_grid_t = 1.0
+                video_second_per_grid_t = 2.0

-            return self.get_updates_use_audio_in_video(
+            placeholder = self.get_updates_use_audio_in_video(
                 thinker_config=thinker_config,
                 audio_len=audio_num_features,
                 video_grid_thw=video_grid_thw,
                 video_second_per_grid_t=video_second_per_grid_t,
             )
+            return PromptUpdateDetails.select_token_id(
+                placeholder, embed_token_id=video_token_id
+            )

         video_replacement_fn = (
             get_replacement_qwen2_use_audio_in_video
@@ -1004,14 +1006,50 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
             ),
         ]

-    def _validate_mm_placeholders(
+    def _derive_audio_from_video_placeholders(
         self,
-        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
-        mm_item_counts: Mapping[str, int],
-    ) -> None:
-        BaseMultiModalProcessor[
-            Qwen2_5OmniThinkerProcessingInfo
-        ]._validate_mm_placeholders(self, mm_placeholders, mm_item_counts)
+        placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
+        mm_prompt_updates: MultiModalPromptUpdates,
+    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
+        """
+        Helper to derive audio placeholders from video placeholders when
+        use_audio_in_video=True.
+        """
+        if "video" not in placeholders:
+            return placeholders
+
+        # Validate audio and video counts match
+        num_videos = len(placeholders["video"])
+        num_audios = len(mm_prompt_updates.get("audio", []))
+        if num_audios != num_videos:
+            raise ValueError(
+                f"use_audio_in_video requires equal number of audio and video items, "
+                f"got {num_audios=}, {num_videos=}"
+            )
+
+        tokenizer = self.info.get_tokenizer()
+        processor = self.info.get_hf_processor()
+        audio_token_id = tokenizer.get_vocab()[processor.audio_token]
+
+        result_placeholders = dict(placeholders)
+        audio_placeholders = []
+
+        # Each video is paired with one audio
+        for video_idx, video_placeholder in enumerate(placeholders["video"]):
+            # Create is_embed mask selecting only audio tokens
+            audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id
+            audio_placeholder = PlaceholderFeaturesInfo(
+                modality="audio",
+                item_idx=video_idx,
+                start_idx=video_placeholder.start_idx,
+                tokens=video_placeholder.tokens,
+                is_embed=audio_is_embed,
+            )
+            audio_placeholders.append(audio_placeholder)
+
+        result_placeholders["audio"] = audio_placeholders
+        return result_placeholders

     def _get_raw_input_ids(
         self,
@@ -1454,7 +1492,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
             )

             if not len(second_per_grid_ts) and len(video_grid_thw):
-                second_per_grids = torch.ones(len(video_grid_thw), dtype=torch.float32)
+                second_per_grid_ts = 2.0
+                second_per_grids = (
+                    torch.ones(len(video_grid_thw), dtype=torch.float32)
+                    * second_per_grid_ts
+                )
             else:
                 second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)