[Model] Use mm_position to compute mrope positions for Qwen3-Omni (#33010)
Signed-off-by: Itay Etelis <itay.etelis@ibm.com> Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
This example shows how to use vLLM for running offline inference
|
||||
with the correct prompt format on Qwen2.5-Omni (thinker only).
|
||||
with the correct prompt format on Qwen3-Omni (thinker only).
|
||||
"""
|
||||
|
||||
from typing import NamedTuple
|
||||
@@ -112,23 +112,51 @@ def get_multi_audios_query() -> QueryResult:
|
||||
)
|
||||
|
||||
|
||||
def get_multi_images_query() -> QueryResult:
|
||||
question = "What are the differences between these two images?"
|
||||
prompt = (
|
||||
f"<|im_start|>system\n{default_system}<|im_end|>\n"
|
||||
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
|
||||
"<|vision_start|><|image_pad|><|vision_end|>"
|
||||
f"{question}<|im_end|>\n"
|
||||
f"<|im_start|>assistant\n"
|
||||
)
|
||||
return QueryResult(
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": {
|
||||
"image": [
|
||||
convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB"),
|
||||
convert_image_mode(ImageAsset("stop_sign").pil_image, "RGB"),
|
||||
],
|
||||
},
|
||||
},
|
||||
limit_mm_per_prompt={
|
||||
"image": 2,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
query_map = {
|
||||
"mixed_modalities": get_mixed_modalities_query,
|
||||
"use_audio_in_video": get_use_audio_in_video_query,
|
||||
"multi_audios": get_multi_audios_query,
|
||||
"multi_images": get_multi_images_query,
|
||||
}
|
||||
|
||||
|
||||
def main(args):
|
||||
model_name = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
|
||||
model_name = args.model
|
||||
query_result = query_map[args.query_type]()
|
||||
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
max_model_len=12800,
|
||||
max_model_len=args.max_model_len,
|
||||
max_num_seqs=5,
|
||||
limit_mm_per_prompt=query_result.limit_mm_per_prompt,
|
||||
seed=args.seed,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
gpu_memory_utilization=args.gpu_memory_utilization,
|
||||
)
|
||||
|
||||
# We set temperature to 0.2 so that outputs can be different
|
||||
@@ -161,6 +189,31 @@ def parse_args():
|
||||
default=0,
|
||||
help="Set the seed when initializing `vllm.LLM`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default="Qwen/Qwen3-Omni-30B-A3B-Instruct",
|
||||
help="Model name or path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tensor-parallel-size",
|
||||
"-tp",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Tensor parallel size for distributed inference.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-memory-utilization",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="GPU memory utilization (0.0 to 1.0).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-model-len",
|
||||
type=int,
|
||||
default=12800,
|
||||
help="Maximum model context length.",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen3-Omni-Moe model (thinker part)."""
|
||||
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
@@ -104,10 +104,7 @@ from .utils import (
|
||||
_merge_multimodal_embeddings,
|
||||
maybe_prefix,
|
||||
)
|
||||
from .vision import (
|
||||
get_llm_pos_ids_for_vision,
|
||||
get_vit_attn_backend,
|
||||
)
|
||||
from .vision import get_vit_attn_backend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -1867,323 +1864,268 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
|
||||
|
||||
return loaded_weights
|
||||
|
||||
def _compute_audio_token_count(self, audio_feature_length: int) -> int:
|
||||
"""Compute audio tokens from feature length using Qwen3-Omni formula."""
|
||||
return _get_feat_extract_output_lengths(
|
||||
torch.tensor([audio_feature_length])
|
||||
).item()
|
||||
|
||||
def _get_audio_for_video_mapping(
|
||||
self, mm_features: list[MultiModalFeatureSpec]
|
||||
) -> tuple[dict[int, int], set[int]]:
|
||||
"""
|
||||
Map video offset -> paired audio_feature_length for use_audio_in_video.
|
||||
|
||||
When use_audio_in_video=True, audio is interleaved within video.
|
||||
The pairing is based on feature order in mm_features.
|
||||
|
||||
Returns:
|
||||
Tuple of (video_offset -> audio_feature_length mapping,
|
||||
set of paired audio offsets to skip)
|
||||
"""
|
||||
videos_with_audio = [
|
||||
f
|
||||
for f in mm_features
|
||||
if f.modality == "video"
|
||||
and f.data.get("use_audio_in_video")
|
||||
and f.data["use_audio_in_video"].data.item()
|
||||
]
|
||||
audios = [f for f in mm_features if f.modality == "audio"]
|
||||
|
||||
mapping: dict[int, int] = {}
|
||||
paired_audio_offsets: set[int] = set()
|
||||
for i, video_f in enumerate(videos_with_audio):
|
||||
if i < len(audios):
|
||||
audio_len = audios[i].data["audio_feature_lengths"].data.item()
|
||||
mapping[video_f.mm_position.offset] = audio_len
|
||||
paired_audio_offsets.add(audios[i].mm_position.offset)
|
||||
return mapping, paired_audio_offsets
|
||||
|
||||
def iter_mm_features(
|
||||
self, mm_features: list[MultiModalFeatureSpec]
|
||||
) -> Iterator[tuple[int, str, dict[str, Any]]]:
|
||||
"""
|
||||
Iterate over multimodal features sorted by position offset.
|
||||
|
||||
Yields: (offset, modality, feature_data) where feature_data contains:
|
||||
- image: {"grid_t", "grid_h", "grid_w", "t_factor"}
|
||||
- video: {"grid_t", "grid_h", "grid_w", "t_factor",
|
||||
"use_audio_in_video", "audio_feature_length"}
|
||||
- audio: {"audio_feature_length"}
|
||||
"""
|
||||
config = self.config
|
||||
spatial_merge_size = config.vision_config.spatial_merge_size
|
||||
position_id_per_seconds = config.position_id_per_seconds
|
||||
|
||||
sorted_features = sorted(mm_features, key=lambda f: f.mm_position.offset)
|
||||
audio_for_video, paired_audio_offsets = self._get_audio_for_video_mapping(
|
||||
sorted_features
|
||||
)
|
||||
|
||||
for mm_feature in sorted_features:
|
||||
offset = mm_feature.mm_position.offset
|
||||
modality = mm_feature.modality
|
||||
|
||||
if modality == "image":
|
||||
t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
|
||||
yield (
|
||||
offset,
|
||||
"image",
|
||||
{
|
||||
"grid_t": t,
|
||||
"grid_h": h // spatial_merge_size,
|
||||
"grid_w": w // spatial_merge_size,
|
||||
"t_factor": position_id_per_seconds,
|
||||
},
|
||||
)
|
||||
elif modality == "video":
|
||||
t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
|
||||
second_per_grid_ts = 2.0
|
||||
if mm_feature.data.get("second_per_grid_ts"):
|
||||
second_per_grid_ts = mm_feature.data[
|
||||
"second_per_grid_ts"
|
||||
].data.item()
|
||||
use_audio_in_video = bool(
|
||||
mm_feature.data.get("use_audio_in_video")
|
||||
and mm_feature.data["use_audio_in_video"].data.item()
|
||||
)
|
||||
|
||||
yield (
|
||||
offset,
|
||||
"video",
|
||||
{
|
||||
"grid_t": t,
|
||||
"grid_h": h // spatial_merge_size,
|
||||
"grid_w": w // spatial_merge_size,
|
||||
"t_factor": second_per_grid_ts * position_id_per_seconds,
|
||||
"use_audio_in_video": use_audio_in_video,
|
||||
"audio_feature_length": audio_for_video.get(offset),
|
||||
},
|
||||
)
|
||||
elif modality == "audio":
|
||||
if offset not in paired_audio_offsets:
|
||||
audio_len = mm_feature.data["audio_feature_lengths"].data.item()
|
||||
yield offset, "audio", {"audio_feature_length": audio_len}
|
||||
|
||||
def _compute_interleaved_positions(
|
||||
self, start_idx: int, data: dict[str, Any]
|
||||
) -> tuple[np.ndarray, int]:
|
||||
"""
|
||||
Compute positions for interleaved video+audio using Qwen3 token-by-token
|
||||
interleaving logic.
|
||||
|
||||
Returns: (position_ids [3, N], total_token_count)
|
||||
"""
|
||||
grid_t = data["grid_t"]
|
||||
grid_h = data["grid_h"]
|
||||
grid_w = data["grid_w"]
|
||||
t_factor = data["t_factor"]
|
||||
audio_feature_length = data["audio_feature_length"]
|
||||
|
||||
audio_len = self._compute_audio_token_count(audio_feature_length)
|
||||
|
||||
h_index = np.tile(
|
||||
np.arange(grid_h).reshape(1, -1, 1), (grid_t, 1, grid_w)
|
||||
).flatten()
|
||||
w_index = np.tile(
|
||||
np.arange(grid_w).reshape(1, 1, -1), (grid_t, grid_h, 1)
|
||||
).flatten()
|
||||
t_index_raw = np.arange(grid_t)
|
||||
t_index_scaled = (t_index_raw * t_factor).astype(np.int64)
|
||||
t_index = np.repeat(t_index_scaled, grid_h * grid_w)
|
||||
|
||||
video_pos = np.stack([t_index, h_index, w_index]) + start_idx
|
||||
audio_pos = np.broadcast_to(np.arange(audio_len), (3, audio_len)) + start_idx
|
||||
|
||||
video_t_values = video_pos[0]
|
||||
audio_t_values = audio_pos[0]
|
||||
|
||||
pos_ids_list: list[np.ndarray] = []
|
||||
video_idx, audio_idx = 0, 0
|
||||
num_video = grid_t * grid_h * grid_w
|
||||
|
||||
while video_idx < num_video and audio_idx < audio_len:
|
||||
if video_t_values[video_idx] <= audio_t_values[audio_idx]:
|
||||
pos_ids_list.append(video_pos[:, video_idx : video_idx + 1])
|
||||
video_idx += 1
|
||||
else:
|
||||
pos_ids_list.append(audio_pos[:, audio_idx : audio_idx + 1])
|
||||
audio_idx += 1
|
||||
|
||||
if video_idx < num_video:
|
||||
pos_ids_list.append(video_pos[:, video_idx:])
|
||||
if audio_idx < audio_len:
|
||||
pos_ids_list.append(audio_pos[:, audio_idx:])
|
||||
|
||||
total_tokens = num_video + audio_len
|
||||
return np.concatenate(pos_ids_list, axis=1), total_tokens
|
||||
|
||||
def get_mrope_input_positions(
|
||||
self,
|
||||
input_tokens: list[int],
|
||||
mm_features: list[MultiModalFeatureSpec],
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
kwargs = MultiModalFeatureSpec.gather_kwargs(
|
||||
mm_features,
|
||||
{
|
||||
"image_grid_thw",
|
||||
"video_grid_thw",
|
||||
"second_per_grid_ts",
|
||||
"audio_feature_lengths",
|
||||
"use_audio_in_video",
|
||||
},
|
||||
)
|
||||
image_grid_thw = kwargs.get("image_grid_thw", [])
|
||||
video_grid_thw = kwargs.get("video_grid_thw", [])
|
||||
second_per_grid_ts = kwargs.get("second_per_grid_ts", [])
|
||||
audio_feature_lengths = kwargs.get("audio_feature_lengths", [])
|
||||
use_audio_in_video = any(kwargs.get("use_audio_in_video", []))
|
||||
"""Compute M-RoPE input positions using mm_features directly."""
|
||||
seq_len = len(input_tokens)
|
||||
|
||||
image_grid_thw = (torch.stack if image_grid_thw else torch.tensor)(
|
||||
image_grid_thw
|
||||
)
|
||||
video_grid_thw = (torch.stack if video_grid_thw else torch.tensor)(
|
||||
video_grid_thw
|
||||
)
|
||||
|
||||
input_ids = torch.tensor(input_tokens)
|
||||
if input_ids is None or input_ids.ndim != 1:
|
||||
raise ValueError("_omni3_get_input_positions_tensor expects 1D input_ids")
|
||||
|
||||
seq_len = input_ids.shape[0]
|
||||
|
||||
if isinstance(audio_feature_lengths, list):
|
||||
audio_feature_lengths = torch.tensor(
|
||||
audio_feature_lengths, dtype=torch.long
|
||||
)
|
||||
|
||||
if not len(second_per_grid_ts) and len(video_grid_thw):
|
||||
second_per_grid_ts = 2.0
|
||||
second_per_grids = (
|
||||
torch.ones(len(video_grid_thw), dtype=torch.float32)
|
||||
* second_per_grid_ts
|
||||
)
|
||||
else:
|
||||
second_per_grids = torch.tensor(second_per_grid_ts, dtype=torch.float32)
|
||||
|
||||
config = self.config
|
||||
spatial_merge_size = config.vision_config.spatial_merge_size
|
||||
image_token_id = config.image_token_id
|
||||
video_token_id = config.video_token_id
|
||||
audio_token_id = config.audio_token_id
|
||||
vision_start_token_id = config.vision_start_token_id
|
||||
audio_start_token_id = config.audio_start_token_id
|
||||
position_id_per_seconds = config.position_id_per_seconds
|
||||
|
||||
vision_start_indices = torch.argwhere(
|
||||
input_ids == vision_start_token_id
|
||||
).squeeze(1)
|
||||
if vision_start_indices.numel() > 0:
|
||||
vision_tokens = input_ids[vision_start_indices + 1]
|
||||
else:
|
||||
vision_tokens = input_ids.new_empty((0,), dtype=input_ids.dtype)
|
||||
audio_nums = torch.sum(input_ids == audio_start_token_id)
|
||||
image_nums = (vision_tokens == image_token_id).sum()
|
||||
video_nums = (
|
||||
(vision_tokens == audio_start_token_id).sum()
|
||||
if use_audio_in_video
|
||||
else (vision_tokens == video_token_id).sum()
|
||||
)
|
||||
|
||||
llm_pos_ids_list: list[torch.Tensor] = []
|
||||
llm_pos_ids_list: list[np.ndarray] = []
|
||||
st = 0
|
||||
image_idx = 0
|
||||
video_idx = 0
|
||||
audio_idx = 0
|
||||
remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums # noqa: E501
|
||||
multimodal_nums = (
|
||||
image_nums + audio_nums
|
||||
if use_audio_in_video
|
||||
else image_nums + video_nums + audio_nums
|
||||
) # noqa: E501
|
||||
|
||||
for _ in range(multimodal_nums):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
if (image_token_id in input_tokens or video_token_id in input_tokens) and (
|
||||
remain_videos > 0 or remain_images > 0
|
||||
):
|
||||
ed_vision_start = input_tokens.index(vision_start_token_id, st)
|
||||
else:
|
||||
ed_vision_start = len(input_tokens) + 1
|
||||
if audio_token_id in input_tokens and remain_audios > 0:
|
||||
ed_audio_start = input_tokens.index(audio_start_token_id, st)
|
||||
else:
|
||||
ed_audio_start = len(input_tokens) + 1
|
||||
min_ed = min(ed_vision_start, ed_audio_start)
|
||||
for offset, modality, data in self.iter_mm_features(mm_features):
|
||||
text_len = offset - st
|
||||
st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
|
||||
|
||||
if min_ed == ed_audio_start:
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
bos_len = 1
|
||||
if text_len > 0:
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
audio_len = _get_feat_extract_output_lengths(
|
||||
audio_feature_lengths[audio_idx]
|
||||
)
|
||||
llm_pos_ids = (
|
||||
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st += text_len + bos_len + audio_len + eos_len
|
||||
audio_idx += 1
|
||||
remain_audios -= 1
|
||||
elif (
|
||||
min_ed == ed_vision_start
|
||||
and input_ids[ed_vision_start + 1] == image_token_id
|
||||
):
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
bos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
grid_t = image_grid_thw[image_idx][0]
|
||||
grid_hs = image_grid_thw[:, 1]
|
||||
grid_ws = image_grid_thw[:, 2]
|
||||
t_index = torch.arange(grid_t) * position_id_per_seconds
|
||||
llm_pos_ids = get_llm_pos_ids_for_vision(
|
||||
st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st += text_len + bos_len + image_len + eos_len
|
||||
image_idx += 1
|
||||
remain_images -= 1
|
||||
elif (
|
||||
min_ed == ed_vision_start
|
||||
and input_ids[ed_vision_start + 1] == video_token_id
|
||||
and not use_audio_in_video
|
||||
):
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
bos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_index = (
|
||||
torch.arange(grid_t)
|
||||
* float(second_per_grids[video_idx].item())
|
||||
* position_id_per_seconds
|
||||
)
|
||||
llm_pos_ids = get_llm_pos_ids_for_vision(
|
||||
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
eos_len = 1
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st += text_len + bos_len + video_len + eos_len
|
||||
video_idx += 1
|
||||
remain_videos -= 1
|
||||
elif (
|
||||
min_ed == ed_vision_start
|
||||
and ed_vision_start + 1 == ed_audio_start
|
||||
and use_audio_in_video
|
||||
):
|
||||
text_len = min_ed - st
|
||||
if text_len != 0:
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, dtype=torch.long)
|
||||
.view(1, -1)
|
||||
.expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
bos_len = 1
|
||||
bos_block = (
|
||||
torch.arange(bos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(bos_block)
|
||||
llm_pos_ids_list.append(bos_block)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
audio_len = _get_feat_extract_output_lengths(
|
||||
audio_feature_lengths[audio_idx]
|
||||
)
|
||||
audio_llm_pos_ids = (
|
||||
torch.arange(audio_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_index = (
|
||||
torch.arange(grid_t)
|
||||
* float(second_per_grids[video_idx].item())
|
||||
* position_id_per_seconds
|
||||
)
|
||||
video_llm_pos_ids = get_llm_pos_ids_for_vision(
|
||||
st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
video_data_index, audio_data_index = 0, 0
|
||||
while (
|
||||
video_data_index < video_llm_pos_ids.shape[-1]
|
||||
and audio_data_index < audio_llm_pos_ids.shape[-1]
|
||||
):
|
||||
if (
|
||||
video_llm_pos_ids[0][video_data_index]
|
||||
<= audio_llm_pos_ids[0][audio_data_index]
|
||||
):
|
||||
llm_pos_ids_list.append(
|
||||
video_llm_pos_ids[
|
||||
:, video_data_index : video_data_index + 1
|
||||
]
|
||||
)
|
||||
video_data_index += 1
|
||||
else:
|
||||
llm_pos_ids_list.append(
|
||||
audio_llm_pos_ids[
|
||||
:, audio_data_index : audio_data_index + 1
|
||||
]
|
||||
)
|
||||
audio_data_index += 1
|
||||
if video_data_index < video_llm_pos_ids.shape[-1]:
|
||||
llm_pos_ids_list.append(
|
||||
video_llm_pos_ids[
|
||||
:, video_data_index : video_llm_pos_ids.shape[-1]
|
||||
]
|
||||
)
|
||||
if audio_data_index < audio_llm_pos_ids.shape[-1]:
|
||||
llm_pos_ids_list.append(
|
||||
audio_llm_pos_ids[
|
||||
:, audio_data_index : audio_llm_pos_ids.shape[-1]
|
||||
]
|
||||
)
|
||||
video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
eos_len = 1
|
||||
eos_block = (
|
||||
torch.arange(eos_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(eos_block)
|
||||
llm_pos_ids_list.append(eos_block)
|
||||
st += text_len + bos_len * 2 + audio_len + video_len + eos_len * 2 # noqa: E501
|
||||
audio_idx += 1
|
||||
video_idx += 1
|
||||
remain_videos -= 1
|
||||
remain_audios -= 1
|
||||
st_idx += text_len
|
||||
|
||||
if st < len(input_tokens):
|
||||
st_idx = llm_pos_ids_list[-1].max() + 1 if llm_pos_ids_list else 0
|
||||
text_len = len(input_tokens) - st
|
||||
bos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
|
||||
llm_pos_ids_list.append(bos_pos)
|
||||
st_idx += 1
|
||||
|
||||
if modality == "audio":
|
||||
audio_tokens = self._compute_audio_token_count(
|
||||
data["audio_feature_length"]
|
||||
)
|
||||
audio_pos = (
|
||||
np.broadcast_to(np.arange(audio_tokens), (3, audio_tokens)) + st_idx
|
||||
)
|
||||
llm_pos_ids_list.append(audio_pos)
|
||||
st_idx = int(audio_pos.max()) + 1
|
||||
|
||||
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
|
||||
llm_pos_ids_list.append(eos_pos)
|
||||
st = offset + 1 + audio_tokens + 1
|
||||
|
||||
elif modality == "image":
|
||||
grid_t = data["grid_t"]
|
||||
grid_h = data["grid_h"]
|
||||
grid_w = data["grid_w"]
|
||||
t_factor = data["t_factor"]
|
||||
|
||||
grid_indices = np.indices((grid_t, grid_h, grid_w))
|
||||
if t_factor != 1.0:
|
||||
grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
|
||||
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
|
||||
|
||||
image_len = grid_t * grid_h * grid_w
|
||||
st_idx = int(llm_pos_ids_list[-1].max()) + 1
|
||||
|
||||
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
|
||||
llm_pos_ids_list.append(eos_pos)
|
||||
st = offset + 1 + image_len + 1
|
||||
|
||||
elif modality == "video":
|
||||
grid_t = data["grid_t"]
|
||||
grid_h = data["grid_h"]
|
||||
grid_w = data["grid_w"]
|
||||
t_factor = data["t_factor"]
|
||||
|
||||
if not data["use_audio_in_video"]:
|
||||
grid_indices = np.indices((grid_t, grid_h, grid_w))
|
||||
if t_factor != 1.0:
|
||||
grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
|
||||
llm_pos_ids_list.append(grid_indices.reshape(3, -1) + st_idx)
|
||||
|
||||
video_len = grid_t * grid_h * grid_w
|
||||
st_idx = int(llm_pos_ids_list[-1].max()) + 1
|
||||
|
||||
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
|
||||
llm_pos_ids_list.append(eos_pos)
|
||||
st = offset + 1 + video_len + 1
|
||||
else:
|
||||
audio_bos_pos = np.broadcast_to(np.array([st_idx - 1]), (3, 1))
|
||||
llm_pos_ids_list.append(audio_bos_pos)
|
||||
|
||||
pos_ids, _ = self._compute_interleaved_positions(st_idx, data)
|
||||
llm_pos_ids_list.append(pos_ids)
|
||||
st_idx = int(pos_ids.max()) + 1
|
||||
|
||||
eos_pos = np.broadcast_to(np.array([st_idx]), (3, 1))
|
||||
llm_pos_ids_list.append(eos_pos)
|
||||
llm_pos_ids_list.append(eos_pos)
|
||||
|
||||
video_len = grid_t * grid_h * grid_w
|
||||
audio_len = self._compute_audio_token_count(
|
||||
data["audio_feature_length"]
|
||||
)
|
||||
st = offset + 2 + video_len + audio_len + 2
|
||||
|
||||
if st < seq_len:
|
||||
st_idx = int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0
|
||||
text_len = seq_len - st
|
||||
llm_pos_ids_list.append(
|
||||
torch.arange(text_len, dtype=torch.long).view(1, -1).expand(3, -1)
|
||||
+ st_idx
|
||||
np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
|
||||
)
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||
llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
|
||||
if llm_positions.shape[1] != seq_len:
|
||||
raise RuntimeError("Position ids length mismatch with input ids length")
|
||||
|
||||
mrope_position_delta = llm_positions.max() + 1 - seq_len
|
||||
return llm_positions, mrope_position_delta
|
||||
mrope_position_delta = int(llm_positions.max()) + 1 - seq_len
|
||||
return torch.from_numpy(llm_positions), mrope_position_delta
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user