[Model] Use mm_position to compute mrope positions for Qwen2.5-Omni (#32772)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Author: Itay Etelis
Date: 2026-01-25 14:15:53 +02:00
Committed by: GitHub
Parent: 151e5451c2
Commit: a698e8e7ad
3 changed files with 385 additions and 200 deletions


@@ -112,10 +112,36 @@ def get_multi_audios_query() -> QueryResult:
)
def get_multi_images_query() -> QueryResult:
question = "What are the differences between these two images?"
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_bos|><|IMAGE|><|vision_eos|>"
"<|vision_bos|><|IMAGE|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n"
)
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"image": [
convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB"),
convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB"),
],
},
},
limit_mm_per_prompt={
"image": 2,
},
)
query_map = {
"mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query,
"multi_audios": get_multi_audios_query,
"multi_images": get_multi_images_query,
}
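# Illustrative usage of query_map (a sketch, not part of this diff): ``llm``
# and ``sampling_params`` are assumed to be constructed elsewhere in the
# example script, and the engine must be created with the query's
# limit_mm_per_prompt so that two images per prompt are accepted.
#
#     query = query_map["multi_images"]()
#     outputs = llm.generate(query.inputs, sampling_params=sampling_params)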


@@ -22,10 +22,11 @@
# limitations under the License.
"""Inference-only Qwen2.5-Omni model (thinker part)."""
from collections.abc import Callable, Iterable, Mapping, Sequence
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal
import numpy as np
import torch
import torch.nn as nn
from transformers import PretrainedConfig
@@ -85,6 +86,7 @@ from vllm.multimodal.processing.processor import (
PlaceholderFeaturesInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -103,7 +105,6 @@ from .utils import (
maybe_prefix,
split_list_into_ranges,
)
from .vision import get_llm_pos_ids_for_vision
try:
import flash_attn
@@ -374,6 +375,67 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self.info.get_hf_config().vision_config.spatial_merge_size
)(hf_inputs)
def _derive_audio_from_video_placeholders(
self,
placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
"""
Helper to derive audio placeholders from video placeholders when
use_audio_in_video=True.
"""
if "video" not in placeholders:
return placeholders
# Validate audio and video counts match
num_videos = len(placeholders["video"])
num_audios = len(mm_prompt_updates.get("audio", []))
if num_audios != num_videos:
raise ValueError(
f"use_audio_in_video requires equal number of audio and video "
f"items, got {num_audios=}, {num_videos=}"
)
tokenizer = self.info.get_tokenizer()
processor = self.info.get_hf_processor()
audio_token_id = tokenizer.get_vocab()[processor.audio_token]
video_token_id = tokenizer.get_vocab()[processor.video_token]
result_placeholders = dict(placeholders)
audio_placeholders = []
video_placeholders = []
# Each video is paired with one audio
for video_idx, video_placeholder in enumerate(placeholders["video"]):
# Create is_embed mask selecting only audio tokens
audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id
# Create is_embed mask selecting only video tokens
video_is_embed = torch.tensor(video_placeholder.tokens) == video_token_id
audio_placeholder = PlaceholderFeaturesInfo(
modality="audio",
item_idx=video_idx,
start_idx=video_placeholder.start_idx,
tokens=video_placeholder.tokens,
is_embed=audio_is_embed,
)
audio_placeholders.append(audio_placeholder)
# Update video placeholder with is_embed mask
video_placeholder_with_mask = PlaceholderFeaturesInfo(
modality="video",
item_idx=video_idx,
start_idx=video_placeholder.start_idx,
tokens=video_placeholder.tokens,
is_embed=video_is_embed,
)
video_placeholders.append(video_placeholder_with_mask)
result_placeholders["audio"] = audio_placeholders
result_placeholders["video"] = video_placeholders
return result_placeholders
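# Minimal sketch of the mask derivation above, using invented token ids (the
# real ids come from the tokenizer vocabulary as in the helper):
#
#     import torch
#
#     AUDIO_ID, VIDEO_ID = 11, 22  # illustrative placeholder ids only
#     tokens = [VIDEO_ID, VIDEO_ID, AUDIO_ID, AUDIO_ID, VIDEO_ID, VIDEO_ID]
#     audio_is_embed = torch.tensor(tokens) == AUDIO_ID  # [F, F, T, T, F, F]
#     video_is_embed = torch.tensor(tokens) == VIDEO_ID  # [T, T, F, F, T, T]
#
# Both derived placeholders keep the same start_idx and tokens; only is_embed
# differs, so audio and video embeddings land in disjoint slots of one span.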
def _maybe_apply_prompt_updates(
self,
mm_items: MultiModalDataItems,
@@ -389,6 +451,16 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
self._validate_mm_updates(mm_prompt_updates, mm_item_counts)
# Detect use_audio_in_video from mm_kwargs
use_audio_in_video = False
if "video" in mm_kwargs:
for item in mm_kwargs["video"]:
if item and item.get("use_audio_in_video"):
use_audio_in_video_tensor = item["use_audio_in_video"].data
if use_audio_in_video_tensor.numel() > 0:
use_audio_in_video = bool(use_audio_in_video_tensor.item())
break
if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
prompt_ids,
@@ -399,10 +471,25 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
mm_item_counts,
)
else:
prompt_ids, mm_placeholders = self._apply_prompt_updates(
prompt_ids,
mm_prompt_updates,
)
if use_audio_in_video and "audio" in mm_prompt_updates:
# Filter out audio updates - they are embedded in video
filtered_updates = {
k: v for k, v in mm_prompt_updates.items() if k != "audio"
}
prompt_ids, mm_placeholders = self._apply_prompt_updates(
prompt_ids,
filtered_updates,
)
# Derive audio placeholders from video placeholders
mm_placeholders = self._derive_audio_from_video_placeholders(
mm_placeholders, mm_prompt_updates
)
else:
prompt_ids, mm_placeholders = self._apply_prompt_updates(
prompt_ids,
mm_prompt_updates,
)
self._validate_mm_placeholders(
mm_placeholders,
mm_item_counts,
@@ -542,13 +629,19 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
else:
video_second_per_grid_t = 1.0
return self.omni_get_updates_use_audio_in_video(
updates = self.omni_get_updates_use_audio_in_video(
thinker_config=thinker_config,
audio_len=audio_num_features,
video_grid_thw=video_grid_thw,
video_second_per_grid_t=video_second_per_grid_t,
)
# Only video tokens should receive video embeddings
return PromptUpdateDetails.select_token_id(
seq=updates,
embed_token_id=video_token_id,
)
video_replacement_fn = (
get_replacement_qwen2_use_audio_in_video
if use_audio_in_video
@@ -889,216 +982,276 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
)
return mm_input_by_modality
def _get_audio_for_video_mapping(
self, mm_features: list[MultiModalFeatureSpec]
) -> tuple[dict[int, int], set[int]]:
"""
Map video offset -> paired audio_feature_length for use_audio_in_video.
When use_audio_in_video=True, audio is interleaved within video chunks.
The pairing is based on feature order in mm_features.
Returns:
Tuple of (video_offset -> audio_feature_length mapping,
set of paired audio offsets to skip)
"""
videos_with_audio = [
f
for f in mm_features
if f.modality == "video"
and f.data.get("use_audio_in_video")
and f.data["use_audio_in_video"].data.item()
]
audios = [f for f in mm_features if f.modality == "audio"]
# Pair videos with audio features (assumes matching order)
mapping: dict[int, int] = {}
paired_audio_offsets: set[int] = set()
for i, video_f in enumerate(videos_with_audio):
if i < len(audios):
audio_len = audios[i].data["audio_feature_lengths"].data.item()
mapping[video_f.mm_position.offset] = audio_len
paired_audio_offsets.add(audios[i].mm_position.offset)
return mapping, paired_audio_offsets
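# Illustrative result (invented numbers): for one video feature with
# use_audio_in_video set at offset 20, paired with one audio feature whose
# audio_feature_lengths is 240, this returns ({20: 240}, {20}). The paired
# audio shares the video's span once audio placeholders are derived from the
# video placeholders above, and its offset is recorded so iter_mm_features
# can skip yielding it separately.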
def _compute_audio_token_count(self, audio_feature_length: int) -> int:
"""Compute audio tokens from feature length."""
return ((audio_feature_length - 1) // 2 + 1 - 2) // 2 + 1
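# Worked example of the formula above (pure integer arithmetic): for
# audio_feature_length = 100,
#     ((100 - 1) // 2 + 1 - 2) // 2 + 1
#     = (49 + 1 - 2) // 2 + 1
#     = 48 // 2 + 1
#     = 25
# i.e. roughly two stride-2 reductions of the audio feature sequence, giving
# 25 audio placeholder tokens.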
def iter_mm_features(
self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, str, dict[str, Any]]]:
"""
Iterate over multimodal features sorted by position offset.
Yields: (offset, modality, feature_data) where feature_data contains:
- image: {"grid_t", "grid_h", "grid_w", "t_factor"}
- video: {"grid_t", "grid_h", "grid_w", "t_factor",
"use_audio_in_video", "audio_feature_length"}
- audio: {"audio_feature_length"}
"""
thinker_config = self.config
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(
thinker_config.vision_config, "tokens_per_second", 25
)
# Sort features by offset first, then pair audio with video
sorted_features = sorted(mm_features, key=lambda f: f.mm_position.offset)
audio_for_video, paired_audio_offsets = self._get_audio_for_video_mapping(
sorted_features
)
for mm_feature in sorted_features:
offset = mm_feature.mm_position.offset
modality = mm_feature.modality
if modality == "image":
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
yield (
offset,
"image",
{
"grid_t": t,
"grid_h": h // spatial_merge_size,
"grid_w": w // spatial_merge_size,
"t_factor": 1.0 * tokens_per_second,
},
)
elif modality == "video":
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
second_per_grid_ts = 1.0
if mm_feature.data.get("second_per_grid_ts"):
second_per_grid_ts = mm_feature.data[
"second_per_grid_ts"
].data.item()
use_audio_in_video = False
if mm_feature.data.get("use_audio_in_video"):
use_audio_in_video = bool(
mm_feature.data["use_audio_in_video"].data.item()
)
yield (
offset,
"video",
{
"grid_t": t,
"grid_h": h // spatial_merge_size,
"grid_w": w // spatial_merge_size,
"t_factor": second_per_grid_ts * tokens_per_second,
"use_audio_in_video": use_audio_in_video,
"audio_feature_length": audio_for_video.get(offset),
},
)
elif modality == "audio":
# Skip audio that's paired with video (handled in video case)
if offset not in paired_audio_offsets:
audio_len = mm_feature.data["audio_feature_lengths"].data.item()
yield offset, "audio", {"audio_feature_length": audio_len}
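# Illustrative yields from iter_mm_features (offsets and sizes invented):
#     (14, "image", {"grid_t": 1, "grid_h": 8, "grid_w": 8, "t_factor": 25.0})
#     (90, "audio", {"audio_feature_length": 300})
#     (170, "video", {"grid_t": 4, "grid_h": 6, "grid_w": 6, "t_factor": 25.0,
#                     "use_audio_in_video": True, "audio_feature_length": 240})
# Offsets come from mm_position, grid_h/grid_w are already divided by
# spatial_merge_size, and audio paired with a video is reported through the
# video entry rather than yielded on its own.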
def _compute_interleaved_positions(
self, start_idx: int, data: dict[str, Any]
) -> tuple[np.ndarray, int]:
"""
Compute positions for interleaved video+audio chunks.
Returns: (position_ids, total_token_count)
"""
grid_t = data["grid_t"]
grid_h = data["grid_h"]
grid_w = data["grid_w"]
t_factor = data["t_factor"]
audio_len = data["audio_feature_length"]
thinker_config = self.config
tokens_per_second = getattr(
thinker_config.vision_config, "tokens_per_second", 25
)
seconds_per_chunk = thinker_config.seconds_per_chunk
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
# Temporal indices with scaling
t_index = (np.arange(grid_t) * t_factor).astype(np.int64)
# Split temporal indices into chunks
t_index_split_chunk: list[list[int]] = [
[] for _ in range((int(t_index.max()) // t_ntoken_per_chunk) + 1)
]
for t_val in t_index:
idx = int(t_val) // t_ntoken_per_chunk
t_index_split_chunk[idx].append(int(t_val))
pure_audio_len = self._compute_audio_token_count(audio_len)
added_audio_len = 0
pos_ids_list: list[np.ndarray] = []
audio_start_idx = start_idx
for t_chunk in t_index_split_chunk:
if not t_chunk:
continue
chunk_t = len(t_chunk)
# Build vision positions for this chunk
h_indices = np.tile(
np.arange(grid_h).reshape(1, -1, 1), (chunk_t, 1, grid_w)
).flatten()
w_indices = np.tile(
np.arange(grid_w).reshape(1, 1, -1), (chunk_t, grid_h, 1)
).flatten()
t_indices = np.repeat(np.array(t_chunk), grid_h * grid_w)
vision_pos = np.stack([t_indices, h_indices, w_indices]) + start_idx
pos_ids_list.append(vision_pos)
# Audio tokens for this chunk
audio_chunk_size = min(t_ntoken_per_chunk, pure_audio_len - added_audio_len)
if audio_chunk_size > 0:
audio_pos = (
np.broadcast_to(np.arange(audio_chunk_size), (3, audio_chunk_size))
+ audio_start_idx
)
pos_ids_list.append(audio_pos)
audio_start_idx = audio_start_idx + audio_chunk_size
added_audio_len += audio_chunk_size
# Handle remaining audio that doesn't fit in chunks
if added_audio_len < pure_audio_len:
remaining = pure_audio_len - added_audio_len
remaining_audio_pos = (
np.broadcast_to(np.arange(remaining), (3, remaining)) + audio_start_idx
)
pos_ids_list.append(remaining_audio_pos)
# Calculate total token count
vision_tokens = grid_t * grid_h * grid_w
total_tokens = vision_tokens + pure_audio_len
return np.concatenate(pos_ids_list, axis=1), total_tokens
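# Rough walk-through of the chunking above, with illustrative numbers: with
# the default tokens_per_second=25 and, say, seconds_per_chunk=2 (read from
# the thinker config), t_ntoken_per_chunk=50. A video with grid_t=4 and
# t_factor=25.0 gives t_index=[0, 25, 50, 75], which splits into
# [[0, 25], [50, 75]]; each vision chunk is followed by up to 50 audio
# positions until pure_audio_len is exhausted, producing the
# |vision chunk|audio chunk|vision chunk|audio chunk| layout.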
def get_mrope_input_positions(
self,
input_tokens: list[int],
mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
"""
Example:
Compute M-RoPE input positions using mm_features directly.
Example for use_audio_in_video case:
(V_i are vision position ids, A_i are audio position ids)
|V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
"""
kwargs = MultiModalFeatureSpec.gather_kwargs(
mm_features,
{
"image_grid_thw",
"video_grid_thw",
"second_per_grid_ts",
"audio_feature_lengths",
"use_audio_in_video",
},
)
image_grid_thw = kwargs.get("image_grid_thw", [])
video_grid_thw = kwargs.get("video_grid_thw", [])
second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
audio_feature_lengths = kwargs.get("audio_feature_lengths", [])
use_audio_in_video = any(kwargs.get("use_audio_in_video", []))
llm_pos_ids_list: list[np.ndarray] = []
st = 0
image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
image_grid_thw
)
video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(
video_grid_thw
)
for offset, modality, data in self.iter_mm_features(mm_features):
# Add text segment before this feature
text_len = offset - st
st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
if text_len > 0:
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
st_idx += text_len
# TODO(fyabc): refactor and share more code with
# _vl_get_input_positions_tensor.
if modality == "audio":
# Standalone audio positions
audio_tokens = self._compute_audio_token_count(
data["audio_feature_length"]
)
llm_pos_ids_list.append(
np.broadcast_to(np.arange(audio_tokens), (3, audio_tokens)) + st_idx
)
st = offset + audio_tokens
thinker_config = self.config
audio_token_id = thinker_config.audio_token_index
image_token_id = thinker_config.image_token_index
video_token_id = thinker_config.video_token_index
audio_start_token_id = thinker_config.audio_start_token_id
audio_end_token_id = thinker_config.audio_end_token_id
vision_start_token_id = thinker_config.vision_start_token_id
vision_end_token_id = thinker_config.vision_end_token_id
seconds_per_chunk = thinker_config.seconds_per_chunk
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
tokens_per_second = getattr(
thinker_config.vision_config, "tokens_per_second", 25
)
elif modality == "image":
# Image uses np.indices like Qwen2-VL
grid_t = data["grid_t"]
grid_h = data["grid_h"]
grid_w = data["grid_w"]
t_factor = data["t_factor"]
src_item = input_tokens
audio_seqlens = audio_feature_lengths
if not second_per_grid_ts:
second_per_grid_ts = [1] * video_grid_thw.shape[0]
audio_idx = 0
video_idx = 0
image_idx = 0
new_src_item: list[int] = []
llm_pos_ids_list: list[torch.Tensor] = []
grid_indices = np.indices((grid_t, grid_h, grid_w))
if t_factor != 1.0:
grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
st = offset + grid_t * grid_h * grid_w
idx = 0
while idx < len(src_item):
new_src_item_len = len(new_src_item)
start_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
elif modality == "video":
grid_t = data["grid_t"]
grid_h = data["grid_h"]
grid_w = data["grid_w"]
t_factor = data["t_factor"]
if not data["use_audio_in_video"]:
# Simple video (same as Qwen2-VL)
grid_indices = np.indices((grid_t, grid_h, grid_w))
if t_factor != 1.0:
grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
st = offset + grid_t * grid_h * grid_w
else:
# Interleaved video+audio
pos_ids, token_count = self._compute_interleaved_positions(
st_idx, data
)
llm_pos_ids_list.append(pos_ids)
st = offset + token_count
# Add trailing text
if st < len(input_tokens):
st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
)
if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]:
if use_audio_in_video and idx > 0:
if (
src_item[idx] == vision_end_token_id
and src_item[idx - 1] == audio_end_token_id
):
# processing the <|audio_eos|> before <|vision_eos|>
start_idx -= 1
elif (
src_item[idx] == audio_start_token_id
and src_item[idx - 1] == vision_start_token_id
):
# processing the <|audio_bos|> after <|vision_eos|>
start_idx -= 1
new_src_item.append(src_item[idx])
llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
llm_pos_ids_list.append(llm_pos_ids)
elif src_item[idx] == audio_token_id:
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
new_src_item.extend([audio_token_id] * place_num)
llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
llm_pos_ids_list.append(llm_pos_ids)
audio_idx += 1
elif src_item[idx] == image_token_id:
grid_t = image_grid_thw[image_idx][0]
grid_hs = image_grid_thw[:, 1]
grid_ws = image_grid_thw[:, 2]
t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
llm_pos_ids = get_llm_pos_ids_for_vision(
start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = image_grid_thw[image_idx].prod() // (
spatial_merge_size**2
)
new_src_item.extend([image_token_id] * vision_seqlen)
image_idx += 1
elif src_item[idx] == video_token_id and not use_audio_in_video:
grid_t = video_grid_thw[video_idx][0]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_index = (
torch.arange(grid_t)
* second_per_grid_ts[video_idx]
* tokens_per_second
).long()
llm_pos_ids = get_llm_pos_ids_for_vision(
start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
)
llm_pos_ids_list.append(llm_pos_ids)
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2
)
new_src_item.extend([video_token_id] * vision_seqlen)
video_idx += 1
else:
# read audio from video
assert audio_seqlens is not None
audio_seqlen = audio_seqlens[audio_idx]
vision_seqlen = video_grid_thw[video_idx].prod() // (
spatial_merge_size**2
)
grid_t = video_grid_thw[video_idx][0]
grid_h = video_grid_thw[video_idx][1]
grid_w = video_grid_thw[video_idx][2]
grid_hs = video_grid_thw[:, 1]
grid_ws = video_grid_thw[:, 2]
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
t_index = (
torch.arange(grid_t)
* second_per_grid_ts[video_idx]
* tokens_per_second
).long()
t_index_split_chunk = split_list_into_ranges(
t_index, t_ntoken_per_chunk
)
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
pure_audio_len = place_num - 2
added_audio_len = 0
audio_llm_pos_ids_list: list[torch.Tensor] = []
for t_chunk in t_index_split_chunk:
vision_ntoken_per_chunk = (
len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
)
new_src_item.extend([video_token_id] * vision_ntoken_per_chunk)
vision_llm_pos_ids_list = get_llm_pos_ids_for_vision(
start_idx,
video_idx,
spatial_merge_size,
t_chunk,
grid_hs,
grid_ws,
).split(1, dim=1)
llm_pos_ids_list.extend(vision_llm_pos_ids_list)
new_src_item.extend(
min(t_ntoken_per_chunk, pure_audio_len - added_audio_len)
* [audio_token_id]
)
audio_start_idx = (
start_idx
if len(audio_llm_pos_ids_list) == 0
else audio_llm_pos_ids_list[-1][0].item() + 1
)
if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0:
audio_llm_pos_ids_list = (
torch.arange(
min(
t_ntoken_per_chunk, pure_audio_len - added_audio_len
)
).expand(3, -1)
+ audio_start_idx
).split(1, dim=1)
else:
audio_llm_pos_ids_list = []
added_audio_len += min(
t_ntoken_per_chunk, pure_audio_len - added_audio_len
)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
if added_audio_len < pure_audio_len:
new_src_item.extend(
(pure_audio_len - added_audio_len) * [audio_token_id]
)
audio_llm_pos_ids_list = (
torch.arange(pure_audio_len - added_audio_len).expand(3, -1)
+ llm_pos_ids_list[-1].max()
+ 1
).split(1, dim=1)
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
audio_idx += 1
video_idx += 1
# move to the next token
idx += len(new_src_item) - new_src_item_len
llm_positions = torch.cat(llm_pos_ids_list, dim=1)
mrope_position_delta = (
torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item)
)
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
mrope_position_delta = int(llm_positions.max()) + 1 - len(input_tokens)
return llm_positions, mrope_position_delta
return torch.from_numpy(llm_positions), mrope_position_delta
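# Worked example of the position layout computed above (toy numbers): a prompt
# with 5 text tokens, one image whose merged grid is (t=1, h=2, w=2), then 3
# text tokens, so len(input_tokens) = 12.
#     text       -> positions 0..4 (identical on all three axes)
#     image      -> t = [5, 5, 5, 5], h = [5, 5, 6, 6], w = [5, 6, 5, 6]
#     text tail  -> positions 7, 8, 9 (max image position 6, plus one)
# mrope_position_delta = 9 + 1 - 12 = -2, i.e. subsequent decode positions run
# two steps behind the raw token count because the image occupies fewer
# distinct temporal positions than tokens.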
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)


@@ -2474,9 +2474,15 @@ class GPUModelRunner(
mm_embeds_item = encoder_output[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
True if is_embed is None else is_embed
)
# OR mask for overlapping mm_features (use_audio_in_video)
if is_embed is None:
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
True
)
else:
is_mm_embed[
req_start_pos + start_idx : req_start_pos + end_idx
] |= is_embed
mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope:
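# Why the slice above is OR-ed rather than overwritten (illustrative sketch):
# with use_audio_in_video, the audio and video features of one clip cover the
# same token span with complementary is_embed masks, so the second feature
# must not clobber the first.
#
#     import torch
#
#     span = torch.zeros(6, dtype=torch.bool)  # is_mm_embed slice for the span
#     video_is_embed = torch.tensor([1, 1, 0, 0, 1, 1], dtype=torch.bool)
#     audio_is_embed = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)
#     span |= video_is_embed   # first overlapping feature
#     span |= audio_is_embed   # second feature ORs in, nothing is lost
#     assert span.all()        # every token in the span is now a mm embedding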