[Model] Use mm_position to compute mrope positions for Qwen3-Omni (#33010)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Author: Itay Etelis
Date: 2026-01-26 15:48:07 +02:00
Committed by: GitHub
parent e33192b269
commit 6ca2c91b96
2 changed files with 302 additions and 307 deletions
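Summary: the Qwen3-Omni thinker previously re-derived M-RoPE positions by scanning `input_tokens` for vision/audio start tokens; this change computes them directly from each feature's `mm_position.offset`, moves the position arithmetic from torch to numpy, and splits the logic into small helpers (`iter_mm_features`, `_compute_interleaved_positions`, etc.). The offline-inference example gains a `multi_images` query and `--model`/`-tp`/`--gpu-memory-utilization`/`--max-model-len` flags.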


@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
 This example shows how to use vLLM for running offline inference
-with the correct prompt format on Qwen2.5-Omni (thinker only).
+with the correct prompt format on Qwen3-Omni (thinker only).
 """
 from typing import NamedTuple
@@ -112,23 +112,51 @@ def get_multi_audios_query() -> QueryResult:
     )
+def get_multi_images_query() -> QueryResult:
+    question = "What are the differences between these two images?"
+    prompt = (
+        f"<|im_start|>system\n{default_system}<|im_end|>\n"
+        "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+        "<|vision_start|><|image_pad|><|vision_end|>"
+        f"{question}<|im_end|>\n"
+        f"<|im_start|>assistant\n"
+    )
+    return QueryResult(
+        inputs={
+            "prompt": prompt,
+            "multi_modal_data": {
+                "image": [
+                    convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB"),
+                    convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB"),
+                ],
+            },
+        },
+        limit_mm_per_prompt={
+            "image": 2,
+        },
+    )
 query_map = {
     "mixed_modalities": get_mixed_modalities_query,
     "use_audio_in_video": get_use_audio_in_video_query,
     "multi_audios": get_multi_audios_query,
+    "multi_images": get_multi_images_query,
 }
 def main(args):
-    model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
+    model_name = args.model
     query_result = query_map[args.query_type]()
     llm = LLM(
         model=model_name,
-        max_model_len=12800,
+        max_model_len=args.max_model_len,
         max_num_seqs=5,
         limit_mm_per_prompt=query_result.limit_mm_per_prompt,
         seed=args.seed,
+        tensor_parallel_size=args.tensor_parallel_size,
+        gpu_memory_utilization=args.gpu_memory_utilization,
     )
     # We set temperature to 0.2 so that outputs can be different
@@ -161,6 +189,31 @@ def parse_args():
         default=0,
         help="Set the seed when initializing `vllm.LLM`.",
     )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
+        help="Model name or path.",
+    )
+    parser.add_argument(
+        "--tensor-parallel-size",
+        "-tp",
+        type=int,
+        default=1,
+        help="Tensor parallel size for distributed inference.",
+    )
+    parser.add_argument(
+        "--gpu-memory-utilization",
+        type=float,
+        default=0.9,
+        help="GPU memory utilization (0.0 to 1.0).",
+    )
+    parser.add_argument(
+        "--max-model-len",
+        type=int,
+        default=12800,
+        help="Maximum model context length.",
+    )
     return parser.parse_args()
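With these flags the example can be pointed at any checkpoint and sharded across GPUs, e.g. (the script filename is not shown in this view, so the path below is a placeholder; the flags come from the diff above):

    python <example_script>.py --query-type multi_images \
        --model Qwen/Qwen3-Omni-30B-A3B-Instruct -tp 2 --max-model-len 12800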


@@ -22,7 +22,7 @@
 # limitations under the License.
 """Inference-only Qwen3-Omni-Moe model (thinker part)."""
-from collections.abc import Callable, Iterable, Mapping, Sequence
+from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
 from functools import partial
 from typing import Any
@@ -104,10 +104,7 @@ from .utils import (
     _merge_multimodal_embeddings,
     maybe_prefix,
 )
-from .vision import (
-    get_llm_pos_ids_for_vision,
-    get_vit_attn_backend,
-)
+from .vision import get_vit_attn_backend
 logger = init_logger(__name__)
@@ -1867,323 +1864,268 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
         return loaded_weights
+    def _compute_audio_token_count(self, audio_feature_length: int) -> int:
+        """Compute audio tokens from feature length using Qwen3-Omni formula."""
+        return _get_feat_extract_output_lengths(
+            torch.tensor([audio_feature_length])
+        ).item()
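Note on the wrapper: `_get_feat_extract_output_lengths` operates on tensors, so the scalar feature length is boxed into a one-element tensor and the result unboxed with `.item()`, letting every call site below stay in plain-int arithmetic.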
+    def _get_audio_for_video_mapping(
+        self, mm_features: list[MultiModalFeatureSpec]
+    ) -> tuple[dict[int, int], set[int]]:
+        """
+        Map video offset -> paired audio_feature_length for use_audio_in_video.
+        When use_audio_in_video=True, audio is interleaved within video.
+        The pairing is based on feature order in mm_features.
+        Returns:
+            Tuple of (video_offset -> audio_feature_length mapping,
+            set of paired audio offsets to skip)
+        """
+        videos_with_audio = [
+            f
+            for f in mm_features
+            if f.modality == "video"
+            and f.data.get("use_audio_in_video")
+            and f.data["use_audio_in_video"].data.item()
+        ]
+        audios = [f for f in mm_features if f.modality == "audio"]
+        mapping: dict[int, int] = {}
+        paired_audio_offsets: set[int] = set()
+        for i, video_f in enumerate(videos_with_audio):
+            if i < len(audios):
+                audio_len = audios[i].data["audio_feature_lengths"].data.item()
+                mapping[video_f.mm_position.offset] = audio_len
+                paired_audio_offsets.add(audios[i].mm_position.offset)
+        return mapping, paired_audio_offsets
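The pairing is purely positional: the i-th video flagged with `use_audio_in_video` is matched to the i-th audio feature. A minimal, self-contained sketch of that invariant (the mock features and numbers are hypothetical; the real code reads tensors out of `MultiModalFeatureSpec.data`):

from types import SimpleNamespace

# Hypothetical features: two flagged videos and their two audio tracks.
feats = [
    SimpleNamespace(offset=3, modality="video"),    # use_audio_in_video=True
    SimpleNamespace(offset=40, modality="audio", feat_len=1200),
    SimpleNamespace(offset=90, modality="video"),   # use_audio_in_video=True
    SimpleNamespace(offset=130, modality="audio", feat_len=800),
]
videos = [f for f in feats if f.modality == "video"]
audios = [f for f in feats if f.modality == "audio"]
# video offset -> paired audio feature length, by order of appearance
mapping = {v.offset: a.feat_len for v, a in zip(videos, audios)}
# audio offsets to skip when later iterating standalone audio
paired = {a.offset for _, a in zip(videos, audios)}
assert mapping == {3: 1200, 90: 800}
assert paired == {40, 130}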
+    def iter_mm_features(
+        self, mm_features: list[MultiModalFeatureSpec]
+    ) -> Iterator[tuple[int, str, dict[str, Any]]]:
+        """
+        Iterate over multimodal features sorted by position offset.
+        Yields: (offset, modality, feature_data) where feature_data contains:
+        - image: {"grid_t", "grid_h", "grid_w", "t_factor"}
+        - video: {"grid_t", "grid_h", "grid_w", "t_factor",
+                  "use_audio_in_video", "audio_feature_length"}
+        - audio: {"audio_feature_length"}
+        """
+        config = self.config
+        spatial_merge_size = config.vision_config.spatial_merge_size
+        position_id_per_seconds = config.position_id_per_seconds
+        sorted_features = sorted(mm_features, key=lambda f: f.mm_position.offset)
+        audio_for_video, paired_audio_offsets = self._get_audio_for_video_mapping(
+            sorted_features
+        )
+        for mm_feature in sorted_features:
+            offset = mm_feature.mm_position.offset
+            modality = mm_feature.modality
+            if modality == "image":
+                t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
+                yield (
+                    offset,
+                    "image",
+                    {
+                        "grid_t": t,
+                        "grid_h": h // spatial_merge_size,
+                        "grid_w": w // spatial_merge_size,
+                        "t_factor": position_id_per_seconds,
+                    },
+                )
+            elif modality == "video":
+                t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
+                second_per_grid_ts = 2.0
+                if mm_feature.data.get("second_per_grid_ts"):
+                    second_per_grid_ts = mm_feature.data[
+                        "second_per_grid_ts"
+                    ].data.item()
+                use_audio_in_video = bool(
+                    mm_feature.data.get("use_audio_in_video")
+                    and mm_feature.data["use_audio_in_video"].data.item()
+                )
+                yield (
+                    offset,
+                    "video",
+                    {
+                        "grid_t": t,
+                        "grid_h": h // spatial_merge_size,
+                        "grid_w": w // spatial_merge_size,
+                        "t_factor": second_per_grid_ts * position_id_per_seconds,
+                        "use_audio_in_video": use_audio_in_video,
+                        "audio_feature_length": audio_for_video.get(offset),
+                    },
+                )
+            elif modality == "audio":
+                if offset not in paired_audio_offsets:
+                    audio_len = mm_feature.data["audio_feature_lengths"].data.item()
+                    yield offset, "audio", {"audio_feature_length": audio_len}
+    def _compute_interleaved_positions(
+        self, start_idx: int, data: dict[str, Any]
+    ) -> tuple[np.ndarray, int]:
+        """
+        Compute positions for interleaved video+audio using Qwen3 token-by-token
+        interleaving logic.
+        Returns: (position_ids [3, N], total_token_count)
+        """
+        grid_t = data["grid_t"]
+        grid_h = data["grid_h"]
+        grid_w = data["grid_w"]
+        t_factor = data["t_factor"]
+        audio_feature_length = data["audio_feature_length"]
+        audio_len = self._compute_audio_token_count(audio_feature_length)
+        h_index = np.tile(
+            np.arange(grid_h).reshape(1, -1, 1), (grid_t, 1, grid_w)
+        ).flatten()
+        w_index = np.tile(
+            np.arange(grid_w).reshape(1, 1, -1), (grid_t, grid_h, 1)
+        ).flatten()
+        t_index_raw = np.arange(grid_t)
+        t_index_scaled = (t_index_raw * t_factor).astype(np.int64)
+        t_index = np.repeat(t_index_scaled, grid_h * grid_w)
+        video_pos = np.stack([t_index, h_index, w_index]) + start_idx
+        audio_pos = np.broadcast_to(np.arange(audio_len), (3, audio_len)) + start_idx
+        video_t_values = video_pos[0]
+        audio_t_values = audio_pos[0]
+        pos_ids_list: list[np.ndarray] = []
+        video_idx, audio_idx = 0, 0
+        num_video = grid_t * grid_h * grid_w
+        while video_idx < num_video and audio_idx < audio_len:
+            if video_t_values[video_idx] <= audio_t_values[audio_idx]:
+                pos_ids_list.append(video_pos[:, video_idx : video_idx + 1])
+                video_idx += 1
+            else:
+                pos_ids_list.append(audio_pos[:, audio_idx : audio_idx + 1])
+                audio_idx += 1
+        if video_idx < num_video:
+            pos_ids_list.append(video_pos[:, video_idx:])
+        if audio_idx < audio_len:
+            pos_ids_list.append(audio_pos[:, audio_idx:])
+        total_tokens = num_video + audio_len
+        return np.concatenate(pos_ids_list, axis=1), total_tokens
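The interleaving is a two-pointer merge on the temporal coordinate: video tokens carry a scaled t-position per frame, audio tokens advance one position per token, and whichever stream has the smaller t-value goes next. A runnable toy-sized sketch of the merge rule (grid and scale values are illustrative, not taken from any real config):

import numpy as np

# grid_t=2, grid_h=1, grid_w=2, t_factor=2  ->  video t per token: [0, 0, 2, 2]
video_t = np.repeat(np.arange(2) * 2, 1 * 2)
# audio_len=3  ->  audio t per token: [0, 1, 2]
audio_t = np.arange(3)
merged, vi, ai = [], 0, 0
while vi < len(video_t) and ai < len(audio_t):
    if video_t[vi] <= audio_t[ai]:
        merged.append(("v", int(video_t[vi])))
        vi += 1
    else:
        merged.append(("a", int(audio_t[ai])))
        ai += 1
merged += [("v", int(t)) for t in video_t[vi:]]
merged += [("a", int(t)) for t in audio_t[ai:]]
# merged == [('v',0), ('v',0), ('a',0), ('a',1), ('v',2), ('v',2), ('a',2)]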
     def get_mrope_input_positions(
         self,
         input_tokens: list[int],
         mm_features: list[MultiModalFeatureSpec],
     ) -> tuple[torch.Tensor, int]:
-        kwargs = MultiModalFeatureSpec.gather_kwargs(
-            mm_features,
-            {
-                "image_grid_thw",
-                "video_grid_thw",
-                "second_per_grid_ts",
-                "audio_feature_lengths",
-                "use_audio_in_video",
-            },
-        )
-        image_grid_thw = kwargs.get("image_grid_thw", [])
-        video_grid_thw = kwargs.get("video_grid_thw", [])
-        second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
-        audio_feature_lengths = kwargs.get("audio_feature_lengths", [])
-        use_audio_in_video = any(kwargs.get("use_audio_in_video", []))
-        image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
-            image_grid_thw
-        )
-        video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(
-            video_grid_thw
-        )
-        input_ids = torch.tensor(input_tokens)
-        if input_ids is None or input_ids.ndim != 1:
-            raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")
-        seq_len = input_ids.shape[0]
-        if isinstance(audio_feature_lengths, list):
-            audio_feature_lengths = torch.tensor(
-                audio_feature_lengths, dtype=torch.long
-            )
-        if not len(second_per_grid_ts) and len(video_grid_thw):
-            second_per_grid_ts = 2.0
-            second_per_grids = (
-                torch.ones(len(video_grid_thw), dtype=torch.float32)
-                * second_per_grid_ts
-            )
-        else:
-            second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
-        config = self.config
-        spatial_merge_size = config.vision_config.spatial_merge_size
-        image_token_id = config.image_token_id
-        video_token_id = config.video_token_id
-        audio_token_id = config.audio_token_id
-        vision_start_token_id = config.vision_start_token_id
-        audio_start_token_id = config.audio_start_token_id
-        position_id_per_seconds = config.position_id_per_seconds
-        vision_start_indices = torch.argwhere(
-            input_ids == vision_start_token_id
-        ).squeeze(1)
-        if vision_start_indices.numel() > 0:
-            vision_tokens = input_ids[vision_start_indices + 1]
-        else:
-            vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
-        audio_nums = torch.sum(input_ids == audio_start_token_id)
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (
-            (vision_tokens == audio_start_token_id).sum()
-            if use_audio_in_video
-            else (vision_tokens == video_token_id).sum()
-        )
-        llm_pos_ids_list: list[torch.Tensor] = []
-        st = 0
-        image_idx = 0
-        video_idx = 0
-        audio_idx = 0
-        remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums  # noqa: E501
-        multimodal_nums = (
-            image_nums + audio_nums
-            if use_audio_in_video
-            else image_nums + video_nums + audio_nums
-        )  # noqa: E501
-        for _ in range(multimodal_nums):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-            if (image_token_id in input_tokens or video_token_id in input_tokens) and (
-                remain_videos > 0 or remain_images > 0
-            ):
-                ed_vision_start = input_tokens.index(vision_start_token_id, st)
-            else:
-                ed_vision_start = len(input_tokens) + 1
-            if audio_token_id in input_tokens and remain_audios > 0:
-                ed_audio_start = input_tokens.index(audio_start_token_id, st)
-            else:
-                ed_audio_start = len(input_tokens) + 1
-            min_ed = min(ed_vision_start, ed_audio_start)
-            if min_ed == ed_audio_start:
-                text_len = min_ed - st
-                if text_len != 0:
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, dtype=torch.long)
-                        .view(1, -1)
-                        .expand(3, -1)
-                        + st_idx
-                    )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                bos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
-                    + st_idx
-                )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                audio_len = _get_feat_extract_output_lengths(
-                    audio_feature_lengths[audio_idx]
-                )
-                llm_pos_ids = (
-                    torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
-                    + st_idx
-                )
-                llm_pos_ids_list.append(llm_pos_ids)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                eos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
-                    + st_idx
-                )
-                st += text_len + bos_len + audio_len + eos_len
-                audio_idx += 1
-                remain_audios -= 1
-            elif (
-                min_ed == ed_vision_start
-                and input_ids[ed_vision_start + 1] == image_token_id
-            ):
-                text_len = min_ed - st
-                if text_len != 0:
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, dtype=torch.long)
-                        .view(1, -1)
-                        .expand(3, -1)
-                        + st_idx
-                    )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                bos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
-                    + st_idx
-                )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                grid_t = image_grid_thw[image_idx][0]
-                grid_hs = image_grid_thw[:, 1]
-                grid_ws = image_grid_thw[:, 2]
-                t_index = torch.arange(grid_t) * position_id_per_seconds
-                llm_pos_ids = get_llm_pos_ids_for_vision(
-                    st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
-                )
-                image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
-                llm_pos_ids_list.append(llm_pos_ids)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                eos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
-                    + st_idx
-                )
-                st += text_len + bos_len + image_len + eos_len
-                image_idx += 1
-                remain_images -= 1
-            elif (
-                min_ed == ed_vision_start
-                and input_ids[ed_vision_start + 1] == video_token_id
-                and not use_audio_in_video
-            ):
-                text_len = min_ed - st
-                if text_len != 0:
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, dtype=torch.long)
-                        .view(1, -1)
-                        .expand(3, -1)
-                        + st_idx
-                    )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                bos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
-                    + st_idx
-                )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                grid_t = video_grid_thw[video_idx][0]
-                grid_hs = video_grid_thw[:, 1]
-                grid_ws = video_grid_thw[:, 2]
-                t_index = (
-                    torch.arange(grid_t)
-                    * float(second_per_grids[video_idx].item())
-                    * position_id_per_seconds
-                )
-                llm_pos_ids = get_llm_pos_ids_for_vision(
-                    st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
-                )
-                video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
-                llm_pos_ids_list.append(llm_pos_ids)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                eos_len = 1
-                llm_pos_ids_list.append(
-                    torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
-                    + st_idx
-                )
-                st += text_len + bos_len + video_len + eos_len
-                video_idx += 1
-                remain_videos -= 1
-            elif (
-                min_ed == ed_vision_start
-                and ed_vision_start + 1 == ed_audio_start
-                and use_audio_in_video
-            ):
-                text_len = min_ed - st
-                if text_len != 0:
-                    st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                    llm_pos_ids_list.append(
-                        torch.arange(text_len, dtype=torch.long)
-                        .view(1, -1)
-                        .expand(3, -1)
-                        + st_idx
-                    )
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                bos_len = 1
-                bos_block = (
-                    torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
-                    + st_idx
-                )
-                llm_pos_ids_list.append(bos_block)
-                llm_pos_ids_list.append(bos_block)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                audio_len = _get_feat_extract_output_lengths(
-                    audio_feature_lengths[audio_idx]
-                )
-                audio_llm_pos_ids = (
-                    torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
-                    + st_idx
-                )
-                grid_t = video_grid_thw[video_idx][0]
-                grid_hs = video_grid_thw[:, 1]
-                grid_ws = video_grid_thw[:, 2]
-                t_index = (
-                    torch.arange(grid_t)
-                    * float(second_per_grids[video_idx].item())
-                    * position_id_per_seconds
-                )
-                video_llm_pos_ids = get_llm_pos_ids_for_vision(
-                    st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
-                )
-                video_data_index, audio_data_index = 0, 0
-                while (
-                    video_data_index < video_llm_pos_ids.shape[-1]
-                    and audio_data_index < audio_llm_pos_ids.shape[-1]
-                ):
-                    if (
-                        video_llm_pos_ids[0][video_data_index]
-                        <= audio_llm_pos_ids[0][audio_data_index]
-                    ):
-                        llm_pos_ids_list.append(
-                            video_llm_pos_ids[
-                                :, video_data_index : video_data_index + 1
-                            ]
-                        )
-                        video_data_index += 1
-                    else:
-                        llm_pos_ids_list.append(
-                            audio_llm_pos_ids[
-                                :, audio_data_index : audio_data_index + 1
-                            ]
-                        )
-                        audio_data_index += 1
-                if video_data_index < video_llm_pos_ids.shape[-1]:
-                    llm_pos_ids_list.append(
-                        video_llm_pos_ids[
-                            :, video_data_index : video_llm_pos_ids.shape[-1]
-                        ]
-                    )
-                if audio_data_index < audio_llm_pos_ids.shape[-1]:
-                    llm_pos_ids_list.append(
-                        audio_llm_pos_ids[
-                            :, audio_data_index : audio_llm_pos_ids.shape[-1]
-                        ]
-                    )
-                video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
-                st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-                eos_len = 1
-                eos_block = (
-                    torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
-                    + st_idx
-                )
-                llm_pos_ids_list.append(eos_block)
-                llm_pos_ids_list.append(eos_block)
-                st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2  # noqa: E501
-                audio_idx += 1
-                video_idx += 1
-                remain_videos -= 1
-                remain_audios -= 1
-        if st < len(input_tokens):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
-            text_len = len(input_tokens) - st
-            llm_pos_ids_list.append(
-                torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
-                + st_idx
-            )
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        if llm_positions.shape[1] != seq_len:
-            raise RuntimeError("Position ids length mismatch with input ids length")
-        mrope_position_delta = llm_positions.max() + 1 - seq_len
-        return llm_positions, mrope_position_delta
+        """Compute M-RoPE input positions using mm_features directly."""
+        seq_len = len(input_tokens)
+        llm_pos_ids_list: list[np.ndarray] = []
+        st = 0
+        for offset, modality, data in self.iter_mm_features(mm_features):
+            text_len = offset - st
+            st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
+            if text_len > 0:
+                llm_pos_ids_list.append(
+                    np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
+                )
+                st_idx += text_len
+            bos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
+            llm_pos_ids_list.append(bos_pos)
+            st_idx += 1
+            if modality == "audio":
+                audio_tokens = self._compute_audio_token_count(
+                    data["audio_feature_length"]
+                )
+                audio_pos = (
+                    np.broadcast_to(np.arange(audio_tokens), (3, audio_tokens)) + st_idx
+                )
+                llm_pos_ids_list.append(audio_pos)
+                st_idx = int(audio_pos.max()) + 1
+                eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
+                llm_pos_ids_list.append(eos_pos)
+                st = offset + 1 + audio_tokens + 1
+            elif modality == "image":
+                grid_t = data["grid_t"]
+                grid_h = data["grid_h"]
+                grid_w = data["grid_w"]
+                t_factor = data["t_factor"]
+                grid_indices = np.indices((grid_t, grid_h, grid_w))
+                if t_factor != 1.0:
+                    grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
+                llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
+                image_len = grid_t * grid_h * grid_w
+                st_idx = int(llm_pos_ids_list[-1].max()) + 1
+                eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
+                llm_pos_ids_list.append(eos_pos)
+                st = offset + 1 + image_len + 1
+            elif modality == "video":
+                grid_t = data["grid_t"]
+                grid_h = data["grid_h"]
+                grid_w = data["grid_w"]
+                t_factor = data["t_factor"]
+                if not data["use_audio_in_video"]:
+                    grid_indices = np.indices((grid_t, grid_h, grid_w))
+                    if t_factor != 1.0:
+                        grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
+                    llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
+                    video_len = grid_t * grid_h * grid_w
+                    st_idx = int(llm_pos_ids_list[-1].max()) + 1
+                    eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
+                    llm_pos_ids_list.append(eos_pos)
+                    st = offset + 1 + video_len + 1
+                else:
+                    audio_bos_pos = np.broadcast_to(np.array([st_idx - 1]), (3, 1))
+                    llm_pos_ids_list.append(audio_bos_pos)
+                    pos_ids, _ = self._compute_interleaved_positions(st_idx, data)
+                    llm_pos_ids_list.append(pos_ids)
+                    st_idx = int(pos_ids.max()) + 1
+                    eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
+                    llm_pos_ids_list.append(eos_pos)
+                    llm_pos_ids_list.append(eos_pos)
+                    video_len = grid_t * grid_h * grid_w
+                    audio_len = self._compute_audio_token_count(
+                        data["audio_feature_length"]
+                    )
+                    st = offset + 2 + video_len + audio_len + 2
+        if st < seq_len:
+            st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
+            text_len = seq_len - st
+            llm_pos_ids_list.append(
+                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
+            )
+        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+        if llm_positions.shape[1] != seq_len:
+            raise RuntimeError("Position ids length mismatch with input ids length")
+        mrope_position_delta = int(llm_positions.max()) + 1 - seq_len
+        return torch.from_numpy(llm_positions), mrope_position_delta
     def get_mm_mapping(self) -> MultiModelKeys:
         """