[Multimodal][Qwen3 Omni] Make Qwen3 Omni work with audio-in-video inputs in V1 engine. (#27721)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
examples/offline_inference/qwen3_omni/only_thinker.py (new file, 170 lines)
@@ -0,0 +1,170 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen3-Omni (thinker only).
"""

from typing import NamedTuple

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.multimodal.image import convert_image_mode
from vllm.utils.argparse_utils import FlexibleArgumentParser


class QueryResult(NamedTuple):
    inputs: dict
    limit_mm_per_prompt: dict[str, int]


# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

default_system = (
    "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
    "Group, capable of perceiving auditory and visual inputs, as well as "
    "generating text and speech."
)


def get_mixed_modalities_query() -> QueryResult:
    question = (
        "What is recited in the audio? "
        "What is the content of this image? Why is this video funny?"
    )
    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
        "<|vision_start|><|image_pad|><|vision_end|>"
        "<|vision_start|><|video_pad|><|vision_end|>"
        f"{question}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    return QueryResult(
        inputs={
            "prompt": prompt,
            "multi_modal_data": {
                "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
                "image": convert_image_mode(
                    ImageAsset("cherry_blossom").pil_image, "RGB"
                ),
                "video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
            },
        },
        limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
    )


def get_use_audio_in_video_query() -> QueryResult:
    question = (
        "Describe the content of the video in detail, then convert what the "
        "baby says into text."
    )
    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n<|vision_start|><|video_pad|><|vision_end|>"
        f"{question}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    asset = VideoAsset(name="baby_reading", num_frames=16)
    audio = asset.get_audio(sampling_rate=16000)
    return QueryResult(
        inputs={
            "prompt": prompt,
            "multi_modal_data": {
                "video": asset.np_ndarrays,
                "audio": audio,
            },
            "mm_processor_kwargs": {
                "use_audio_in_video": True,
            },
        },
        limit_mm_per_prompt={"audio": 1, "video": 1},
    )


def get_multi_audios_query() -> QueryResult:
    question = "Are these two audio clips the same?"
    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|>"
        "<|audio_start|><|audio_pad|><|audio_end|>"
        f"{question}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    return QueryResult(
        inputs={
            "prompt": prompt,
            "multi_modal_data": {
                "audio": [
                    AudioAsset("winning_call").audio_and_sample_rate,
                    AudioAsset("mary_had_lamb").audio_and_sample_rate,
                ],
            },
        },
        limit_mm_per_prompt={
            "audio": 2,
        },
    )


query_map = {
    "mixed_modalities": get_mixed_modalities_query,
    "use_audio_in_video": get_use_audio_in_video_query,
    "multi_audios": get_multi_audios_query,
}


def main(args):
    model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
    query_result = query_map[args.query_type]()

    llm = LLM(
        model=model_name,
        max_model_len=12800,
        max_num_seqs=5,
        limit_mm_per_prompt=query_result.limit_mm_per_prompt,
        seed=args.seed,
    )

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2, max_tokens=256)

    outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


def parse_args():
    parser = FlexibleArgumentParser(
        description="Demo on using vLLM for offline inference with "
        "audio language models"
    )
    parser.add_argument(
        "--query-type",
        "-q",
        type=str,
        default="mixed_modalities",
        choices=query_map.keys(),
        help="Query type.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Set the seed when initializing `vllm.LLM`.",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    main(args)
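A quick usage note (flag names come from `parse_args` above; downloaded model weights and sufficient GPU memory are assumed): the new audio-in-video path is exercised with `python only_thinker.py -q use_audio_in_video`, and `--seed` can be passed for reproducible sampling at `temperature=0.2`.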
tests/model_executor/test_qwen3_omni.py (new file, 221 lines)
@@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from unittest.mock import Mock

import pytest
from transformers import PretrainedConfig

from vllm.multimodal.processing import InputProcessingContext


# Helper function to print input IDs with coalesced audio/video tokens.
def print_input_ids(input_ids):
    """
    Print input IDs, compressing consecutive special tokens.
    - 151675: <|audio_pad|>
    - 151656: <|video_pad|>
    """
    if not input_ids:
        print("[]")
        return

    result = []
    i = 0

    while i < len(input_ids):
        current_id = input_ids[i]

        # Check if it's a special token that should be compressed
        if current_id in [151675, 151656]:
            # Count consecutive occurrences
            count = 1
            while i + count < len(input_ids) and input_ids[i + count] == current_id:
                count += 1

            # Add compressed representation
            token_name = "<|audio_pad|>" if current_id == 151675 else "<|video_pad|>"
            result.append(f"{token_name} * {count}")
            i += count
        else:
            # Regular token, just add it
            result.append(str(current_id))
            i += 1

    print(", ".join(result))


@pytest.fixture
def mock_qwen3_omni_config():
    """Create a mock Qwen3OmniMoeThinker config."""
    config = Mock(spec=PretrainedConfig)
    # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
    config.audio_token_id = 151675  # <|audio_pad|>
    config.video_token_id = 151656  # <|video_pad|>
    config.image_token_id = 151655  # <|image_pad|>
    config.audio_start_token_id = 151669  # <|audio_start|>
    config.audio_end_token_id = 151670  # <|audio_end|>
    config.vision_start_token_id = 151652  # <|vision_start|>
    config.position_id_per_seconds = 12.5

    # Vision config
    vision_config = Mock()
    vision_config.spatial_merge_size = 2
    config.vision_config = vision_config

    return config


@pytest.fixture
def mock_processor():
    """Create a mock HF processor."""
    from transformers.models.whisper import WhisperFeatureExtractor

    processor = Mock()
    processor.audio_token = "<|audio_pad|>"
    processor.image_token = "<|image_pad|>"
    processor.video_token = "<|video_pad|>"

    # Create a real WhisperFeatureExtractor instance for the feature_extractor attribute
    feature_extractor = WhisperFeatureExtractor()
    processor.feature_extractor = feature_extractor

    return processor


@pytest.fixture
def mock_tokenizer():
    """Create a mock tokenizer."""
    tokenizer = Mock()
    # Token IDs from https://huggingface.co/Qwen/Qwen3-Omni-30B-A3B-Instruct/blob/main/tokenizer_config.json
    tokenizer.get_vocab = Mock(
        return_value={
            "<|audio_pad|>": 151675,
            "<|video_pad|>": 151656,
            "<|image_pad|>": 151655,
            "<|audio_start|>": 151669,
            "<|audio_end|>": 151670,
            "<|vision_start|>": 151652,
            "<|vision_end|>": 151653,
        }
    )
    tokenizer.encode = Mock(
        side_effect=lambda x: {
            "<|vision_start|>": [151652],
            "<|vision_end|>": [151653],
            "<|audio_start|>": [151669],
            "<|audio_end|>": [151670],
            "<|audio_pad|>": [151675],
            "<|image_pad|>": [151655],
            "<|video_pad|>": [151656],
        }.get(x, [0])
    )
    tokenizer.vision_bos_token = "<|vision_start|>"
    tokenizer.vision_eos_token = "<|vision_end|>"
    tokenizer.audio_bos_token = "<|audio_start|>"
    tokenizer.audio_eos_token = "<|audio_end|>"
    return tokenizer


@pytest.fixture
def mock_image_processor():
    """Create a mock image processor."""
    image_processor = Mock()
    image_processor.merge_size = 2
    return image_processor


def test_qwen3_omni_get_updates_use_audio_in_video(
    mock_qwen3_omni_config,
    mock_processor,
    mock_tokenizer,
    mock_image_processor,
):
    """Test the get_updates_use_audio_in_video method directly."""

    from vllm.model_executor.models.qwen3_omni_moe_thinker import (
        Qwen3OmniMoeThinkerMultiModalProcessor,
        Qwen3OmniMoeThinkerProcessingInfo,
    )

    # Create a mock context
    mock_ctx = Mock(spec=InputProcessingContext)

    # Create processing info
    info = Qwen3OmniMoeThinkerProcessingInfo(mock_ctx)
    info.get_hf_config = Mock(return_value=mock_qwen3_omni_config)
    info.get_hf_processor = Mock(return_value=mock_processor)
    info.get_tokenizer = Mock(return_value=mock_tokenizer)
    info.get_image_processor = Mock(return_value=mock_image_processor)

    # Create a mock dummy_inputs builder
    mock_dummy_inputs = Mock()

    # Create the processor
    processor = Qwen3OmniMoeThinkerMultiModalProcessor(info, mock_dummy_inputs)

    # Test parameters from reference video
    # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-Omni/demo/draw.mp4
    audio_len = 85
    video_grid_thw = [6, 36, 64]
    video_second_per_grid_t = 2.0

    # Call the method
    updates = processor.get_updates_use_audio_in_video(
        thinker_config=mock_qwen3_omni_config,
        audio_len=audio_len,
        video_grid_thw=video_grid_thw,
        video_second_per_grid_t=video_second_per_grid_t,
    )

    # Updated input ids should align with HF implementation.
    # 151669,
    # <|video_pad|> * 576, <|audio_pad|> * 25,
    # <|video_pad|> * 576, <|audio_pad|> * 25,
    # <|video_pad|> * 576, <|audio_pad|> * 25,
    # <|video_pad|> * 576, <|audio_pad|> * 10,
    # <|video_pad|> * 1152,
    # 151670
    print_input_ids(updates)

    # Verify structure
    assert isinstance(updates, list)
    assert len(updates) > 0

    # Verify start and end tokens
    audio_start_token_id = mock_qwen3_omni_config.audio_start_token_id
    audio_end_token_id = mock_qwen3_omni_config.audio_end_token_id

    assert updates[0] == audio_start_token_id
    assert updates[-1] == audio_end_token_id

    # Verify both audio and video tokens are present
    audio_token_id = mock_qwen3_omni_config.audio_token_id
    video_token_id = mock_qwen3_omni_config.video_token_id

    audio_count = updates.count(audio_token_id)
    video_count = updates.count(video_token_id)

    assert audio_count == audio_len, (
        f"Expected {audio_len} audio tokens, got {audio_count}"
    )

    # Calculate expected video token count
    spatial_merge_size = mock_qwen3_omni_config.vision_config.spatial_merge_size
    height = video_grid_thw[1] // spatial_merge_size
    width = video_grid_thw[2] // spatial_merge_size
    expected_video_count = video_grid_thw[0] * height * width

    assert video_count == expected_video_count, (
        f"Expected {expected_video_count} video tokens, got {video_count}"
    )

    # Total tokens should be: 1 (start) + audio_len + video_count + 1 (end)
    expected_total = 1 + audio_len + expected_video_count + 1
    assert len(updates) == expected_total, (
        f"Expected {expected_total} total tokens, got {len(updates)}"
    )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
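The expected pattern in the comment above can be cross-checked with a few lines of arithmetic. A minimal sketch (the 25-token audio chunks follow from `position_id_per_seconds = 12.5` times `video_second_per_grid_t = 2.0`; the exact chunking policy of `get_updates_use_audio_in_video` is assumed here, not re-implemented):

# Video: grid_thw = [6, 36, 64] with spatial_merge_size = 2
video_tokens_per_t = (36 // 2) * (64 // 2)  # 576 tokens per temporal step
total_video = 6 * video_tokens_per_t  # 3456 = 4 * 576 + 1152

# Audio: 12.5 position ids/sec * 2.0 sec per grid step = 25 audio frames
# interleaved after each video chunk, until the 85 frames are exhausted.
assert sum([25, 25, 25, 10]) == 85

# Total update length: <|audio_start|> + audio + video + <|audio_end|>
assert 1 + 85 + total_video + 1 == 3543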
vllm/model_executor/models/qwen2_5_omni_thinker.py
@@ -23,7 +23,6 @@
 """Inference-only Qwen2.5-Omni model (thinker part)."""
 
 from collections.abc import Callable, Iterable, Mapping, Sequence
-from copy import copy
 from functools import partial
 from typing import Annotated, Any, Literal
 
@@ -387,15 +386,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
         self._validate_mm_updates(mm_prompt_updates, mm_item_counts)
 
-        use_audio_in_video = False
-        if "video" in mm_kwargs:
-            video_items = [item for item in mm_kwargs["video"] if item is not None]
-            # only check video items (if there are any)
-            if video_items:
-                use_audio_in_video = all(
-                    item["use_audio_in_video"].data for item in video_items
-                )
-
         if is_update_applied:
             mm_placeholders = self._find_mm_placeholders(
                 prompt_ids,
@@ -404,7 +394,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
             self._validate_mm_placeholders(
                 mm_placeholders,
                 mm_item_counts,
-                use_audio_in_video=use_audio_in_video,
             )
         else:
             prompt_ids, mm_placeholders = self._apply_prompt_updates(
@@ -414,7 +403,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
             self._validate_mm_placeholders(
                 mm_placeholders,
                 mm_item_counts,
-                use_audio_in_video=use_audio_in_video,
             )
 
         return prompt_ids, mm_placeholders
@@ -640,19 +628,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
 
         return mm_processed_data
 
-    def _validate_mm_placeholders(
-        self,
-        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
-        mm_item_counts: Mapping[str, int],
-        use_audio_in_video: bool = False,
-    ) -> None:
-        if use_audio_in_video:
-            mm_item_counts = copy(mm_item_counts)
-            if "video" in mm_item_counts:
-                assert "audio" in mm_item_counts
-                mm_item_counts["audio"] -= mm_item_counts["video"]
-        super()._validate_mm_placeholders(mm_placeholders, mm_item_counts)
-
 
 class Qwen2_5OmniConditionalGenerationMixin:
     def _parse_and_validate_audio_input(
vllm/model_executor/models/qwen3_omni_moe_thinker.py
@@ -68,11 +68,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
 from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems
 from vllm.multimodal.processing import (
-    BaseMultiModalProcessor,
     MultiModalPromptUpdates,
     PlaceholderFeaturesInfo,
     PromptReplacement,
     PromptUpdate,
+    PromptUpdateDetails,
 )
 from vllm.sequence import IntermediateTensors
 
@@ -87,7 +87,6 @@ from .qwen2_5_omni_thinker import (
     Qwen2_5OmniConditionalGenerationMixin,
     Qwen2_5OmniThinkerDummyInputsBuilder,
     Qwen2_5OmniThinkerMultiModalProcessor,
-    Qwen2_5OmniThinkerProcessingInfo,
 )
 from .qwen2_5_vl import (
     Qwen2_5_VisionAttention,
@@ -807,24 +806,8 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
         else:
             use_audio_in_video = False
 
-        if use_audio_in_video and "video" in mm_item_counts:
-            assert "audio" in mm_item_counts
-            mm_item_counts["audio"] -= mm_item_counts["video"]
-
-        # Special case with `use_audio_in_video=True`
-        if use_audio_in_video:
-            if is_update_applied:
-                prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video)
-            (
-                prompt_ids,
-                mm_placeholders,
-            ) = self._apply_prompt_updates(
-                prompt_ids,
-                mm_prompt_updates,
-            )
-            self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
         # normal case with `use_audio_in_video=False`
-        elif is_update_applied:
+        if is_update_applied:
             mm_placeholders = self._find_mm_placeholders(
                 prompt_ids,
                 mm_prompt_updates,
@@ -833,11 +816,25 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
                 mm_placeholders,
                 mm_item_counts,
             )
+        else:
+            if use_audio_in_video and "audio" in mm_prompt_updates:
+                filtered_updates = {
+                    k: v for k, v in mm_prompt_updates.items() if k != "audio"
+                }
+                prompt_ids, mm_placeholders = self._apply_prompt_updates(
+                    prompt_ids,
+                    filtered_updates,
+                )
+                # Derive audio placeholders from video placeholders
+                mm_placeholders = self._derive_audio_from_video_placeholders(
+                    mm_placeholders, mm_prompt_updates
+                )
             else:
                 prompt_ids, mm_placeholders = self._apply_prompt_updates(
                     prompt_ids,
                     mm_prompt_updates,
                 )
 
         self._validate_mm_placeholders(
             mm_placeholders,
             mm_item_counts,
@@ -962,7 +959,9 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
 
         def get_replacement_qwen2_use_audio_in_video(item_idx: int):
             nonlocal audio_in_video_item_idx
-            audio_num_features = audio_output_lengths[audio_item_idx + item_idx]
+            audio_num_features = audio_output_lengths[
+                audio_in_video_item_idx + item_idx
+            ]
             video_grid_thw = out_mm_data["video_grid_thw"][item_idx]
 
             audio_in_video_item_idx += 1
@@ -971,14 +970,17 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
             if second_per_grid_ts:
                 video_second_per_grid_t = second_per_grid_ts[item_idx]
             else:
-                video_second_per_grid_t = 1.0
+                video_second_per_grid_t = 2.0
 
-            return self.get_updates_use_audio_in_video(
+            placeholder = self.get_updates_use_audio_in_video(
                 thinker_config=thinker_config,
                 audio_len=audio_num_features,
                 video_grid_thw=video_grid_thw,
                 video_second_per_grid_t=video_second_per_grid_t,
             )
+            return PromptUpdateDetails.select_token_id(
+                placeholder, embed_token_id=video_token_id
+            )
 
         video_replacement_fn = (
             get_replacement_qwen2_use_audio_in_video
@@ -1004,14 +1006,50 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
             ),
         ]
 
-    def _validate_mm_placeholders(
+    def _derive_audio_from_video_placeholders(
         self,
-        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
-        mm_item_counts: Mapping[str, int],
-    ) -> None:
-        BaseMultiModalProcessor[
-            Qwen2_5OmniThinkerProcessingInfo
-        ]._validate_mm_placeholders(self, mm_placeholders, mm_item_counts)
+        placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
+        mm_prompt_updates: MultiModalPromptUpdates,
+    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
+        """
+        Helper to derive audio placeholders from video placeholders when
+        use_audio_in_video=True.
+        """
+        if "video" not in placeholders:
+            return placeholders
+
+        # Validate audio and video counts match
+        num_videos = len(placeholders["video"])
+        num_audios = len(mm_prompt_updates.get("audio", []))
+        if num_audios != num_videos:
+            raise ValueError(
+                f"use_audio_in_video requires equal number of audio and video items, "
+                f"got {num_audios=}, {num_videos=}"
+            )
+
+        tokenizer = self.info.get_tokenizer()
+        processor = self.info.get_hf_processor()
+        audio_token_id = tokenizer.get_vocab()[processor.audio_token]
+
+        result_placeholders = dict(placeholders)
+        audio_placeholders = []
+
+        # Each video is paired with one audio
+        for video_idx, video_placeholder in enumerate(placeholders["video"]):
+            # Create is_embed mask selecting only audio tokens
+            audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id
+
+            audio_placeholder = PlaceholderFeaturesInfo(
+                modality="audio",
+                item_idx=video_idx,
+                start_idx=video_placeholder.start_idx,
+                tokens=video_placeholder.tokens,
+                is_embed=audio_is_embed,
+            )
+            audio_placeholders.append(audio_placeholder)
+
+        result_placeholders["audio"] = audio_placeholders
+        return result_placeholders
 
     def _get_raw_input_ids(
         self,
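The core of the helper above is the `is_embed` mask: the shared interleaved token sequence is attributed to the audio modality only at `<|audio_pad|>` positions. A minimal standalone sketch of that masking step (token IDs as in the test file; the `PlaceholderFeaturesInfo` wrapper is omitted):

import torch

AUDIO_PAD, VIDEO_PAD = 151675, 151656
# An interleaved placeholder: a video chunk, an audio chunk, another video chunk.
tokens = [VIDEO_PAD] * 4 + [AUDIO_PAD] * 2 + [VIDEO_PAD] * 4

audio_is_embed = torch.tensor(tokens) == AUDIO_PAD
assert int(audio_is_embed.sum()) == 2  # only audio positions feed audio embeddings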
@@ -1454,7 +1492,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         )
 
         if not len(second_per_grid_ts) and len(video_grid_thw):
-            second_per_grids = torch.ones(len(video_grid_thw), dtype=torch.float32)
+            second_per_grid_ts = 2.0
+            second_per_grids = (
+                torch.ones(len(video_grid_thw), dtype=torch.float32)
+                * second_per_grid_ts
+            )
         else:
             second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
 
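This last hunk makes the fallback consistent with the processor-side default introduced above (`video_second_per_grid_t = 2.0`) instead of an implicit 1.0, which is my reading of why the MRoPE positions now line up; a small sanity check of the numbers involved, using only values that appear in this commit:

position_id_per_seconds = 12.5  # from the thinker config
video_second_per_grid_t = 2.0  # default set in this commit
# Positions spanned by one temporal grid step -> 25, matching the
# "<|audio_pad|> * 25" chunks expected in tests/model_executor/test_qwen3_omni.py.
assert position_id_per_seconds * video_second_per_grid_t == 25.0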