[Model] Use mm_position to compute mrope positions for Qwen2.5-Omni (#32772)
Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
@@ -112,10 +112,36 @@ def get_multi_audios_query() -> QueryResult:
    )


def get_multi_images_query() -> QueryResult:
    question = "What are the differences between these two images?"
    prompt = (
        f"<|im_start|>system\n{default_system}<|im_end|>\n"
        "<|im_start|>user\n<|vision_bos|><|IMAGE|><|vision_eos|>"
        "<|vision_bos|><|IMAGE|><|vision_eos|>"
        f"{question}<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
    return QueryResult(
        inputs={
            "prompt": prompt,
            "multi_modal_data": {
                "image": [
                    convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB"),
                    convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB"),
                ],
            },
        },
        limit_mm_per_prompt={
            "image": 2,
        },
    )


query_map = {
    "mixed_modalities": get_mixed_modalities_query,
    "use_audio_in_video": get_use_audio_in_video_query,
    "multi_audios": get_multi_audios_query,
    "multi_images": get_multi_images_query,
}
@@ -22,10 +22,11 @@
# limitations under the License.
"""Inference-only Qwen2.5-Omni model (thinker part)."""

from collections.abc import Callable, Iterable, Mapping, Sequence
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal

import numpy as np
import torch
import torch.nn as nn
from transformers import PretrainedConfig
@@ -85,6 +86,7 @@ from vllm.multimodal.processing.processor import (
    PlaceholderFeaturesInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -103,7 +105,6 @@ from .utils import (
    maybe_prefix,
    split_list_into_ranges,
)
from .vision import get_llm_pos_ids_for_vision

try:
    import flash_attn
@@ -374,6 +375,67 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
            self.info.get_hf_config().vision_config.spatial_merge_size
        )(hf_inputs)

    def _derive_audio_from_video_placeholders(
        self,
        placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Helper to derive audio placeholders from video placeholders when
        use_audio_in_video=True.
        """
        if "video" not in placeholders:
            return placeholders

        # Validate audio and video counts match
        num_videos = len(placeholders["video"])
        num_audios = len(mm_prompt_updates.get("audio", []))
        if num_audios != num_videos:
            raise ValueError(
                f"use_audio_in_video requires equal number of audio and video "
                f"items, got {num_audios=}, {num_videos=}"
            )

        tokenizer = self.info.get_tokenizer()
        processor = self.info.get_hf_processor()
        audio_token_id = tokenizer.get_vocab()[processor.audio_token]
        video_token_id = tokenizer.get_vocab()[processor.video_token]

        result_placeholders = dict(placeholders)
        audio_placeholders = []
        video_placeholders = []

        # Each video is paired with one audio
        for video_idx, video_placeholder in enumerate(placeholders["video"]):
            # Create is_embed mask selecting only audio tokens
            audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id

            # Create is_embed mask selecting only video tokens
            video_is_embed = torch.tensor(video_placeholder.tokens) == video_token_id

            audio_placeholder = PlaceholderFeaturesInfo(
                modality="audio",
                item_idx=video_idx,
                start_idx=video_placeholder.start_idx,
                tokens=video_placeholder.tokens,
                is_embed=audio_is_embed,
            )
            audio_placeholders.append(audio_placeholder)

            # Update video placeholder with is_embed mask
            video_placeholder_with_mask = PlaceholderFeaturesInfo(
                modality="video",
                item_idx=video_idx,
                start_idx=video_placeholder.start_idx,
                tokens=video_placeholder.tokens,
                is_embed=video_is_embed,
            )
            video_placeholders.append(video_placeholder_with_mask)

        result_placeholders["audio"] = audio_placeholders
        result_placeholders["video"] = video_placeholders
        return result_placeholders
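
# Standalone illustration (not part of the change): how the two is_embed masks
# built above split one shared placeholder span between modalities. The token
# ids are made-up stand-ins for processor.audio_token / processor.video_token.
import torch

AUDIO, VIDEO = 151646, 151647
tokens = [VIDEO, VIDEO, AUDIO, AUDIO, VIDEO]
audio_is_embed = torch.tensor(tokens) == AUDIO  # [False, False, True, True, False]
video_is_embed = torch.tensor(tokens) == VIDEO  # [True, True, False, False, True]
# The same span therefore feeds audio embeddings and video embeddings at
# disjoint positions, which is what use_audio_in_video requires.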

    def _maybe_apply_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
@@ -389,6 +451,16 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
        self._validate_mm_updates(mm_prompt_updates, mm_item_counts)

        # Detect use_audio_in_video from mm_kwargs
        use_audio_in_video = False
        if "video" in mm_kwargs:
            for item in mm_kwargs["video"]:
                if item and item.get("use_audio_in_video"):
                    use_audio_in_video_tensor = item["use_audio_in_video"].data
                    if use_audio_in_video_tensor.numel() > 0:
                        use_audio_in_video = bool(use_audio_in_video_tensor.item())
                    break

        if is_update_applied:
            mm_placeholders = self._find_mm_placeholders(
                prompt_ids,
@@ -399,10 +471,25 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
                mm_item_counts,
            )
        else:
            prompt_ids, mm_placeholders = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
            )
            if use_audio_in_video and "audio" in mm_prompt_updates:
                # Filter out audio updates - they are embedded in video
                filtered_updates = {
                    k: v for k, v in mm_prompt_updates.items() if k != "audio"
                }
                prompt_ids, mm_placeholders = self._apply_prompt_updates(
                    prompt_ids,
                    filtered_updates,
                )
                # Derive audio placeholders from video placeholders
                mm_placeholders = self._derive_audio_from_video_placeholders(
                    mm_placeholders, mm_prompt_updates
                )
            else:
                prompt_ids, mm_placeholders = self._apply_prompt_updates(
                    prompt_ids,
                    mm_prompt_updates,
                )

        self._validate_mm_placeholders(
            mm_placeholders,
            mm_item_counts,
@@ -542,13 +629,19 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
            else:
                video_second_per_grid_t = 1.0

            return self.omni_get_updates_use_audio_in_video(
            updates = self.omni_get_updates_use_audio_in_video(
                thinker_config=thinker_config,
                audio_len=audio_num_features,
                video_grid_thw=video_grid_thw,
                video_second_per_grid_t=video_second_per_grid_t,
            )

            # Only video tokens should receive video embeddings
            return PromptUpdateDetails.select_token_id(
                seq=updates,
                embed_token_id=video_token_id,
            )

        video_replacement_fn = (
            get_replacement_qwen2_use_audio_in_video
            if use_audio_in_video
@@ -889,216 +982,276 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
            )
        return mm_input_by_modality

    def _get_audio_for_video_mapping(
        self, mm_features: list[MultiModalFeatureSpec]
    ) -> tuple[dict[int, int], set[int]]:
        """
        Map video offset -> paired audio_feature_length for use_audio_in_video.

        When use_audio_in_video=True, audio is interleaved within video chunks.
        The pairing is based on feature order in mm_features.

        Returns:
            Tuple of (video_offset -> audio_feature_length mapping,
            set of paired audio offsets to skip)
        """
        videos_with_audio = [
            f
            for f in mm_features
            if f.modality == "video"
            and f.data.get("use_audio_in_video")
            and f.data["use_audio_in_video"].data.item()
        ]
        audios = [f for f in mm_features if f.modality == "audio"]

        # Pair videos with audio features (assumes matching order)
        mapping: dict[int, int] = {}
        paired_audio_offsets: set[int] = set()
        for i, video_f in enumerate(videos_with_audio):
            if i < len(audios):
                audio_len = audios[i].data["audio_feature_lengths"].data.item()
                mapping[video_f.mm_position.offset] = audio_len
                paired_audio_offsets.add(audios[i].mm_position.offset)
        return mapping, paired_audio_offsets
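
# Worked example (illustrative values): the pairing is purely positional, so the
# i-th video that uses audio-in-video is matched with the i-th audio feature.
# For videos at offsets [12, 87] and audios at offsets [30, 110] with
# audio_feature_lengths [240, 360], the helper returns
#     ({12: 240, 87: 360}, {30, 110})
# where the second element tells iter_mm_features which standalone audio
# features to skip because they are consumed by a video.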

    def _compute_audio_token_count(self, audio_feature_length: int) -> int:
        """Compute audio tokens from feature length."""
        return ((audio_feature_length - 1) // 2 + 1 - 2) // 2 + 1
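
# Worked example (illustrative): for audio_feature_length = 300,
#     (300 - 1) // 2 + 1 = 150, then (150 - 2) // 2 + 1 = 75,
# so a 300-frame audio feature maps to 75 audio placeholder tokens, roughly one
# token per four feature frames.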

    def iter_mm_features(
        self, mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, str, dict[str, Any]]]:
        """
        Iterate over multimodal features sorted by position offset.

        Yields: (offset, modality, feature_data) where feature_data contains:
        - image: {"grid_t", "grid_h", "grid_w", "t_factor"}
        - video: {"grid_t", "grid_h", "grid_w", "t_factor",
                  "use_audio_in_video", "audio_feature_length"}
        - audio: {"audio_feature_length"}
        """
        thinker_config = self.config
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(
            thinker_config.vision_config, "tokens_per_second", 25
        )

        # Sort features by offset first, then pair audio with video
        sorted_features = sorted(mm_features, key=lambda f: f.mm_position.offset)
        audio_for_video, paired_audio_offsets = self._get_audio_for_video_mapping(
            sorted_features
        )

        for mm_feature in sorted_features:
            offset = mm_feature.mm_position.offset
            modality = mm_feature.modality

            if modality == "image":
                t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
                yield (
                    offset,
                    "image",
                    {
                        "grid_t": t,
                        "grid_h": h // spatial_merge_size,
                        "grid_w": w // spatial_merge_size,
                        "t_factor": 1.0 * tokens_per_second,
                    },
                )
            elif modality == "video":
                t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
                second_per_grid_ts = 1.0
                if mm_feature.data.get("second_per_grid_ts"):
                    second_per_grid_ts = mm_feature.data[
                        "second_per_grid_ts"
                    ].data.item()
                use_audio_in_video = False
                if mm_feature.data.get("use_audio_in_video"):
                    use_audio_in_video = bool(
                        mm_feature.data["use_audio_in_video"].data.item()
                    )

                yield (
                    offset,
                    "video",
                    {
                        "grid_t": t,
                        "grid_h": h // spatial_merge_size,
                        "grid_w": w // spatial_merge_size,
                        "t_factor": second_per_grid_ts * tokens_per_second,
                        "use_audio_in_video": use_audio_in_video,
                        "audio_feature_length": audio_for_video.get(offset),
                    },
                )
            elif modality == "audio":
                # Skip audio that's paired with video (handled in video case)
                if offset not in paired_audio_offsets:
                    audio_len = mm_feature.data["audio_feature_lengths"].data.item()
                    yield offset, "audio", {"audio_feature_length": audio_len}
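
# Illustrative output (values assumed, not taken from the change): for a prompt
# with one image of image_grid_thw = (1, 36, 24), spatial_merge_size = 2, and one
# standalone audio with audio_feature_lengths = 300, the generator yields, in
# offset order:
#     (img_offset, "image", {"grid_t": 1, "grid_h": 18, "grid_w": 12, "t_factor": 25.0})
#     (aud_offset, "audio", {"audio_feature_length": 300})
# Videos additionally carry "use_audio_in_video" and, when paired, the matched
# "audio_feature_length" from _get_audio_for_video_mapping.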

    def _compute_interleaved_positions(
        self, start_idx: int, data: dict[str, Any]
    ) -> tuple[np.ndarray, int]:
        """
        Compute positions for interleaved video+audio chunks.

        Returns: (position_ids, total_token_count)
        """
        grid_t = data["grid_t"]
        grid_h = data["grid_h"]
        grid_w = data["grid_w"]
        t_factor = data["t_factor"]
        audio_len = data["audio_feature_length"]

        thinker_config = self.config
        tokens_per_second = getattr(
            thinker_config.vision_config, "tokens_per_second", 25
        )
        seconds_per_chunk = thinker_config.seconds_per_chunk
        t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)

        # Temporal indices with scaling
        t_index = (np.arange(grid_t) * t_factor).astype(np.int64)

        # Split temporal indices into chunks
        t_index_split_chunk: list[list[int]] = [
            [] for _ in range((int(t_index.max()) // t_ntoken_per_chunk) + 1)
        ]
        for t_val in t_index:
            idx = int(t_val) // t_ntoken_per_chunk
            t_index_split_chunk[idx].append(int(t_val))

        pure_audio_len = self._compute_audio_token_count(audio_len)
        added_audio_len = 0
        pos_ids_list: list[np.ndarray] = []
        audio_start_idx = start_idx

        for t_chunk in t_index_split_chunk:
            if not t_chunk:
                continue

            chunk_t = len(t_chunk)

            # Build vision positions for this chunk
            h_indices = np.tile(
                np.arange(grid_h).reshape(1, -1, 1), (chunk_t, 1, grid_w)
            ).flatten()
            w_indices = np.tile(
                np.arange(grid_w).reshape(1, 1, -1), (chunk_t, grid_h, 1)
            ).flatten()
            t_indices = np.repeat(np.array(t_chunk), grid_h * grid_w)

            vision_pos = np.stack([t_indices, h_indices, w_indices]) + start_idx
            pos_ids_list.append(vision_pos)

            # Audio tokens for this chunk
            audio_chunk_size = min(t_ntoken_per_chunk, pure_audio_len - added_audio_len)
            if audio_chunk_size > 0:
                audio_pos = (
                    np.broadcast_to(np.arange(audio_chunk_size), (3, audio_chunk_size))
                    + audio_start_idx
                )
                pos_ids_list.append(audio_pos)
                audio_start_idx = audio_start_idx + audio_chunk_size
                added_audio_len += audio_chunk_size

        # Handle remaining audio that doesn't fit in chunks
        if added_audio_len < pure_audio_len:
            remaining = pure_audio_len - added_audio_len
            remaining_audio_pos = (
                np.broadcast_to(np.arange(remaining), (3, remaining)) + audio_start_idx
            )
            pos_ids_list.append(remaining_audio_pos)

        # Calculate total token count
        vision_tokens = grid_t * grid_h * grid_w
        total_tokens = vision_tokens + pure_audio_len

        return np.concatenate(pos_ids_list, axis=1), total_tokens
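
# Standalone sketch (assumed values): how the temporal chunking above groups the
# scaled time indices. Assuming tokens_per_second = 25 and seconds_per_chunk = 2,
# t_ntoken_per_chunk is 50; a video with grid_t = 4 and t_factor = 25 gives:
import numpy as np

t_ntoken_per_chunk = 50
t_index = (np.arange(4) * 25).astype(np.int64)  # [0, 25, 50, 75]
chunks = [[] for _ in range(int(t_index.max()) // t_ntoken_per_chunk + 1)]
for t_val in t_index:
    chunks[int(t_val) // t_ntoken_per_chunk].append(int(t_val))
print(chunks)  # [[0, 25], [50, 75]]: two vision chunks, each followed by up to
               # 50 audio positions until the paired audio tokens run out.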

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """
        Example:
        Compute M-RoPE input positions using mm_features directly.

        Example for use_audio_in_video case:
        (V_i are vision position ids, A_i are audio position ids)

        |V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
        |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
        """
        kwargs = MultiModalFeatureSpec.gather_kwargs(
            mm_features,
            {
                "image_grid_thw",
                "video_grid_thw",
                "second_per_grid_ts",
                "audio_feature_lengths",
                "use_audio_in_video",
            },
        )
        image_grid_thw = kwargs.get("image_grid_thw", [])
        video_grid_thw = kwargs.get("video_grid_thw", [])
        second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
        audio_feature_lengths = kwargs.get("audio_feature_lengths", [])
        use_audio_in_video = any(kwargs.get("use_audio_in_video", []))
        llm_pos_ids_list: list[np.ndarray] = []
        st = 0

        image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
            image_grid_thw
        )
        video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(
            video_grid_thw
        )
        for offset, modality, data in self.iter_mm_features(mm_features):
            # Add text segment before this feature
            text_len = offset - st
            st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
            if text_len > 0:
                llm_pos_ids_list.append(
                    np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
                )
                st_idx += text_len

            # TODO(fyabc): refactor and share more code with
            # _vl_get_input_positions_tensor.
            if modality == "audio":
                # Standalone audio positions
                audio_tokens = self._compute_audio_token_count(
                    data["audio_feature_length"]
                )
                llm_pos_ids_list.append(
                    np.broadcast_to(np.arange(audio_tokens), (3, audio_tokens)) + st_idx
                )
                st = offset + audio_tokens

        thinker_config = self.config
        audio_token_id = thinker_config.audio_token_index
        image_token_id = thinker_config.image_token_index
        video_token_id = thinker_config.video_token_index
        audio_start_token_id = thinker_config.audio_start_token_id
        audio_end_token_id = thinker_config.audio_end_token_id
        vision_start_token_id = thinker_config.vision_start_token_id
        vision_end_token_id = thinker_config.vision_end_token_id
        seconds_per_chunk = thinker_config.seconds_per_chunk
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(
            thinker_config.vision_config, "tokens_per_second", 25
        )
            elif modality == "image":
                # Image uses np.indices like Qwen2-VL
                grid_t = data["grid_t"]
                grid_h = data["grid_h"]
                grid_w = data["grid_w"]
                t_factor = data["t_factor"]

        src_item = input_tokens
        audio_seqlens = audio_feature_lengths
        if not second_per_grid_ts:
            second_per_grid_ts = [1] * video_grid_thw.shape[0]
        audio_idx = 0
        video_idx = 0
        image_idx = 0
        new_src_item: list[int] = []
        llm_pos_ids_list: list[torch.Tensor] = []
                grid_indices = np.indices((grid_t, grid_h, grid_w))
                if t_factor != 1.0:
                    grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
                llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
                st = offset + grid_t * grid_h * grid_w

        idx = 0
        while idx < len(src_item):
            new_src_item_len = len(new_src_item)
            start_idx = (
                llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            elif modality == "video":
                grid_t = data["grid_t"]
                grid_h = data["grid_h"]
                grid_w = data["grid_w"]
                t_factor = data["t_factor"]

                if not data["use_audio_in_video"]:
                    # Simple video (same as Qwen2-VL)
                    grid_indices = np.indices((grid_t, grid_h, grid_w))
                    if t_factor != 1.0:
                        grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
                    llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
                    st = offset + grid_t * grid_h * grid_w
                else:
                    # Interleaved video+audio
                    pos_ids, token_count = self._compute_interleaved_positions(
                        st_idx, data
                    )
                    llm_pos_ids_list.append(pos_ids)
                    st = offset + token_count

        # Add trailing text
        if st < len(input_tokens):
            st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )
            if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]:
                if use_audio_in_video and idx > 0:
                    if (
                        src_item[idx] == vision_end_token_id
                        and src_item[idx - 1] == audio_end_token_id
                    ):
                        # processing the <|audio_eos|> before <|vision_eos|>
                        start_idx -= 1
                    elif (
                        src_item[idx] == audio_start_token_id
                        and src_item[idx - 1] == vision_start_token_id
                    ):
                        # processing the <|audio_bos|> after <|vision_bos|>
                        start_idx -= 1
                new_src_item.append(src_item[idx])
                llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
                llm_pos_ids_list.append(llm_pos_ids)
            elif src_item[idx] == audio_token_id:
                assert audio_seqlens is not None
                audio_seqlen = audio_seqlens[audio_idx]
                place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
                new_src_item.extend([audio_token_id] * place_num)
                llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
                llm_pos_ids_list.append(llm_pos_ids)
                audio_idx += 1
            elif src_item[idx] == image_token_id:
                grid_t = image_grid_thw[image_idx][0]
                grid_hs = image_grid_thw[:, 1]
                grid_ws = image_grid_thw[:, 2]
                t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
                llm_pos_ids = get_llm_pos_ids_for_vision(
                    start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                )
                llm_pos_ids_list.append(llm_pos_ids)
                vision_seqlen = image_grid_thw[image_idx].prod() // (
                    spatial_merge_size**2
                )
                new_src_item.extend([image_token_id] * vision_seqlen)
                image_idx += 1
            elif src_item[idx] == video_token_id and not use_audio_in_video:
                grid_t = video_grid_thw[video_idx][0]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_index = (
                    torch.arange(grid_t)
                    * second_per_grid_ts[video_idx]
                    * tokens_per_second
                ).long()
                llm_pos_ids = get_llm_pos_ids_for_vision(
                    start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                )
                llm_pos_ids_list.append(llm_pos_ids)
                vision_seqlen = video_grid_thw[video_idx].prod() // (
                    spatial_merge_size**2
                )
                new_src_item.extend([video_token_id] * vision_seqlen)
                video_idx += 1
            else:
                # read audio from video
                assert audio_seqlens is not None
                audio_seqlen = audio_seqlens[audio_idx]
                vision_seqlen = video_grid_thw[video_idx].prod() // (
                    spatial_merge_size**2
                )
                grid_t = video_grid_thw[video_idx][0]
                grid_h = video_grid_thw[video_idx][1]
                grid_w = video_grid_thw[video_idx][2]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
                t_index = (
                    torch.arange(grid_t)
                    * second_per_grid_ts[video_idx]
                    * tokens_per_second
                ).long()
                t_index_split_chunk = split_list_into_ranges(
                    t_index, t_ntoken_per_chunk
                )
                place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
                pure_audio_len = place_num - 2
                added_audio_len = 0
                audio_llm_pos_ids_list: list[torch.Tensor] = []
                for t_chunk in t_index_split_chunk:
                    vision_ntoken_per_chunk = (
                        len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
                    )
                    new_src_item.extend([video_token_id] * vision_ntoken_per_chunk)
                    vision_llm_pos_ids_list = get_llm_pos_ids_for_vision(
                        start_idx,
                        video_idx,
                        spatial_merge_size,
                        t_chunk,
                        grid_hs,
                        grid_ws,
                    ).split(1, dim=1)
                    llm_pos_ids_list.extend(vision_llm_pos_ids_list)
                    new_src_item.extend(
                        min(t_ntoken_per_chunk, pure_audio_len - added_audio_len)
                        * [audio_token_id]
                    )
                    audio_start_idx = (
                        start_idx
                        if len(audio_llm_pos_ids_list) == 0
                        else audio_llm_pos_ids_list[-1][0].item() + 1
                    )
                    if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0:
                        audio_llm_pos_ids_list = (
                            torch.arange(
                                min(
                                    t_ntoken_per_chunk, pure_audio_len - added_audio_len
                                )
                            ).expand(3, -1)
                            + audio_start_idx
                        ).split(1, dim=1)
                    else:
                        audio_llm_pos_ids_list = []
                    added_audio_len += min(
                        t_ntoken_per_chunk, pure_audio_len - added_audio_len
                    )
                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
                if added_audio_len < pure_audio_len:
                    new_src_item.extend(
                        (pure_audio_len - added_audio_len) * [audio_token_id]
                    )
                    audio_llm_pos_ids_list = (
                        torch.arange(pure_audio_len - added_audio_len).expand(3, -1)
                        + llm_pos_ids_list[-1].max()
                        + 1
                    ).split(1, dim=1)
                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
                audio_idx += 1
                video_idx += 1
            # move to the next token
            idx += len(new_src_item) - new_src_item_len

        llm_positions = torch.cat(llm_pos_ids_list, dim=1)
        mrope_position_delta = (
            torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item)
        )
        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = int(llm_positions.max()) + 1 - len(input_tokens)

        return llm_positions, mrope_position_delta
        return torch.from_numpy(llm_positions), mrope_position_delta
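
# Worked example (illustrative): llm_positions has shape (3, len(input_tokens)),
# and mrope_position_delta records how far the largest assigned position falls
# behind the prompt length. If a 1000-token prompt ends up with a maximum
# position of 640, the delta is 640 + 1 - 1000 = -359, and positions for
# generated tokens continue from len(prompt) + delta.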

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
@@ -2474,9 +2474,15 @@ class GPUModelRunner(
                mm_embeds_item = encoder_output[start_idx:end_idx]

                req_start_pos = req_start_idx + start_pos - num_computed_tokens
                is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
                    True if is_embed is None else is_embed
                )
                # OR mask for overlapping mm_features (use_audio_in_video)
                if is_embed is None:
                    is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
                        True
                    )
                else:
                    is_mm_embed[
                        req_start_pos + start_idx : req_start_pos + end_idx
                    ] |= is_embed
                mm_embeds_req.append(mm_embeds_item)

                if self.is_multimodal_pruning_enabled and self.uses_mrope:
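
# Standalone sketch (illustrative): why the mask is OR-ed instead of assigned.
# With use_audio_in_video, the audio and video features of one clip cover the
# same placeholder span, each contributing a partial is_embed mask.
import torch

is_mm_embed = torch.zeros(6, dtype=torch.bool)
video_is_embed = torch.tensor([True, True, False, False, True, False])
audio_is_embed = torch.tensor([False, False, True, True, False, False])
is_mm_embed[0:6] |= video_is_embed
is_mm_embed[0:6] |= audio_is_embed
print(is_mm_embed)  # tensor([True, True, True, True, True, False])
# A plain assignment would let the second feature's mask overwrite the first.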