[Refactor]: Use M-RoPE interface directly while defining model class instead of maintaining model specific M-RoPE implementation in mrope.py (#24172)

Signed-off-by: Divyansh Singhvi <divyanshsinghvi@gmail.com>
Signed-off-by: dsinghvi <divyanshsinghvi@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: wwl2755 <wangwenlong2755@gmail.com>
dsinghvi
2025-10-11 12:51:04 +05:30
committed by GitHub
parent 55392bc879
commit 727144bed1
9 changed files with 974 additions and 1051 deletions
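
The pattern this refactor moves to, in short: rather than mrope.py carrying a model-specific branch for Qwen2.5-Omni, the model class itself declares the SupportsMRoPE interface and exposes get_mrope_input_positions, which vLLM then calls through that interface. The sketch below is illustrative only: the interface name and method signature are taken from the diff that follows, but the class name and the trivial text-only body are hypothetical and not part of this commit.

from typing import Optional, Union

import torch
from transformers import PretrainedConfig

from vllm.model_executor.models.interfaces import SupportsMRoPE


class MyOmniModel(torch.nn.Module, SupportsMRoPE):  # hypothetical example class
    @classmethod
    def get_mrope_input_positions(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: Optional[list[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        # Toy body: give every token the same position on all three M-RoPE axes,
        # as if the prompt were text-only. A real model (such as the Qwen2.5-Omni
        # thinker in the diff below) computes per-modality positions instead.
        positions = torch.arange(len(input_tokens)).expand(3, -1)
        mrope_position_delta = 0
        return positions[:, context_len:seq_len], mrope_position_delta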


@@ -29,6 +29,7 @@ from typing import Annotated, Any, Callable, Literal, Optional, Union
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from transformers.feature_extraction_utils import BatchFeature
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
    Qwen2_5OmniConfig,
@@ -45,7 +46,6 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VisionTransformer,
@@ -93,6 +93,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
@@ -101,7 +102,9 @@ from .utils import (
    WeightsMapper,
    init_vllm_registered_model,
    maybe_prefix,
    split_list_into_ranges,
)
from .vision import get_llm_pos_ids_for_vision
try:
    import flash_attn
@@ -412,6 +415,59 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
        return prompt_ids, mm_placeholders

    @classmethod
    def omni_get_updates_use_audio_in_video(
        cls,
        thinker_config: PretrainedConfig,
        audio_len: int,
        video_grid_thw: Union[list[int], torch.Tensor],
        video_second_per_grid_t: float,
    ) -> list[int]:
        """Get video prompt updates when `use_audio_in_video` is True.

        In this case, audio and vision update ids will be split into
        chunks and interleaved (details in `_omni_get_input_positions_tensor`).

        <|video_bos|><|VIDEO|><|video_eos|> =>
        <|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
        """
        audio_token_id = thinker_config.audio_token_index
        video_token_id = thinker_config.video_token_index
        audio_start_token_id = thinker_config.audio_start_token_id
        audio_end_token_id = thinker_config.audio_end_token_id
        seconds_per_chunk = thinker_config.seconds_per_chunk
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(
            thinker_config.vision_config, "tokens_per_second", 25
        )

        grid_t = video_grid_thw[0]
        grid_h = video_grid_thw[1]
        grid_w = video_grid_thw[2]

        t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
        t_index = (
            torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second
        ).long()
        t_index_split_chunk = split_list_into_ranges(t_index, t_ntoken_per_chunk)

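        # Interleave the placeholders chunk by chunk: each temporal chunk contributes
        # its vision placeholder ids, followed by up to t_ntoken_per_chunk audio
        # placeholder ids; leftover audio ids are appended before <|audio_eos|>.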
        updates = [audio_start_token_id]
        added_audio_len = 0
        for t_chunk in t_index_split_chunk:
            vision_ntoken_per_chunk = (
                len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
            )
            updates.extend([video_token_id] * vision_ntoken_per_chunk)

            audio_chunk_size = min(t_ntoken_per_chunk, audio_len - added_audio_len)
            updates.extend(audio_chunk_size * [audio_token_id])
            added_audio_len += audio_chunk_size
        if added_audio_len < audio_len:
            updates.extend((audio_len - added_audio_len) * [audio_token_id])

        updates.extend([audio_end_token_id])

        return updates

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
@@ -491,7 +547,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
                else:
                    video_second_per_grid_t = 1.0

                return MRotaryEmbedding.omni_get_updates_use_audio_in_video(
                return self.omni_get_updates_use_audio_in_video(
                    thinker_config=thinker_config,
                    audio_len=audio_num_features,
                    video_grid_thw=video_grid_thw,
@@ -808,6 +864,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
    SupportsMultiModal,
    SupportsPP,
    SupportsLoRA,
    SupportsMRoPE,
    Qwen2_5OmniConditionalGenerationMixin,
):
    hf_to_vllm_mapper = WeightsMapper(
@@ -929,6 +986,216 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    @classmethod
    def get_mrope_input_positions(
        cls,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Union[list[list[int]], torch.Tensor],
        video_grid_thw: Union[list[list[int]], torch.Tensor],
        second_per_grid_ts: Optional[list[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value (Qwen2.5-Omni version).

        Differences from MRotaryEmbedding:
            1. Add audio support (and related `audio_feature_lengths`).
            2. Add `use_audio_in_video` option to read audio from video inputs.
               In this case, audio and vision position ids will be split into
               chunks and interleaved.

        Example:
            (V_i are vision position ids, A_i are audio position ids)

            |V_1 ...    V_n|A_1 ...   A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
            |vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
        """
        # TODO(fyabc): refactor and share more code with
        # _vl_get_input_positions_tensor.
        thinker_config = hf_config.thinker_config
        audio_token_id = thinker_config.audio_token_index
        image_token_id = thinker_config.image_token_index
        video_token_id = thinker_config.video_token_index
        audio_start_token_id = thinker_config.audio_start_token_id
        audio_end_token_id = thinker_config.audio_end_token_id
        vision_start_token_id = thinker_config.vision_start_token_id
        vision_end_token_id = thinker_config.vision_end_token_id
        seconds_per_chunk = thinker_config.seconds_per_chunk
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        tokens_per_second = getattr(
            thinker_config.vision_config, "tokens_per_second", 25
        )

        if isinstance(image_grid_thw, list):
            image_grid_thw = torch.tensor(image_grid_thw)
        if isinstance(video_grid_thw, list):
            video_grid_thw = torch.tensor(video_grid_thw)

        src_item = input_tokens
        audio_seqlens = audio_feature_lengths
        if not second_per_grid_ts:
            second_per_grid_ts = [1] * video_grid_thw.shape[0]
        audio_idx = 0
        video_idx = 0
        image_idx = 0
        new_src_item: list[int] = []
        llm_pos_ids_list: list[torch.Tensor] = []

        idx = 0
        while idx < len(src_item):
            new_src_item_len = len(new_src_item)
            start_idx = (
                llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            )
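            # Plain text / special tokens: a single position that advances by one,
            # identical on all three M-RoPE axes; when use_audio_in_video is set,
            # adjacent audio/vision bos-eos tokens share a position via start_idx -= 1.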
            if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]:
                if use_audio_in_video and idx > 0:
                    if (
                        src_item[idx] == vision_end_token_id
                        and src_item[idx - 1] == audio_end_token_id
                    ):
                        # processing the <|audio_eos|> before <|vision_eos|>
                        start_idx -= 1
                    elif (
                        src_item[idx] == audio_start_token_id
                        and src_item[idx - 1] == vision_start_token_id
                    ):
                        # processing the <|audio_bos|> after <|vision_bos|>
                        start_idx -= 1
                new_src_item.append(src_item[idx])
                llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
                llm_pos_ids_list.append(llm_pos_ids)
            elif src_item[idx] == audio_token_id:
                assert audio_seqlens is not None
                audio_seqlen = audio_seqlens[audio_idx]
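                # Map the raw audio feature length to the number of audio placeholder
                # tokens; the two nested "// 2" halve the sequence twice.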
                place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
                new_src_item.extend([audio_token_id] * place_num)
                llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
                llm_pos_ids_list.append(llm_pos_ids)
                audio_idx += 1
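            # Image tokens: 3D positions come from get_llm_pos_ids_for_vision over
            # the image's (t, h, w) grid after spatial merging.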
            elif src_item[idx] == image_token_id:
                grid_t = image_grid_thw[image_idx][0]
                grid_hs = image_grid_thw[:, 1]
                grid_ws = image_grid_thw[:, 2]
                t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
                llm_pos_ids = get_llm_pos_ids_for_vision(
                    start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                )
                llm_pos_ids_list.append(llm_pos_ids)
                vision_seqlen = image_grid_thw[image_idx].prod() // (
                    spatial_merge_size**2
                )
                new_src_item.extend([image_token_id] * vision_seqlen)
                image_idx += 1
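            # Video tokens without in-video audio: same as the image case, except the
            # temporal index is scaled by this video's per-grid duration.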
            elif src_item[idx] == video_token_id and not use_audio_in_video:
                grid_t = video_grid_thw[video_idx][0]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_index = (
                    torch.arange(grid_t)
                    * second_per_grid_ts[video_idx]
                    * tokens_per_second
                ).long()
                llm_pos_ids = get_llm_pos_ids_for_vision(
                    start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                )
                llm_pos_ids_list.append(llm_pos_ids)
                vision_seqlen = video_grid_thw[video_idx].prod() // (
                    spatial_merge_size**2
                )
                new_src_item.extend([video_token_id] * vision_seqlen)
                video_idx += 1
            else:
                # read audio from video
                assert audio_seqlens is not None
                audio_seqlen = audio_seqlens[audio_idx]
                vision_seqlen = video_grid_thw[video_idx].prod() // (
                    spatial_merge_size**2
                )
                grid_t = video_grid_thw[video_idx][0]
                grid_h = video_grid_thw[video_idx][1]
                grid_w = video_grid_thw[video_idx][2]
                grid_hs = video_grid_thw[:, 1]
                grid_ws = video_grid_thw[:, 2]
                t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
                t_index = (
                    torch.arange(grid_t)
                    * second_per_grid_ts[video_idx]
                    * tokens_per_second
                ).long()
                t_index_split_chunk = split_list_into_ranges(
                    t_index, t_ntoken_per_chunk
                )
                place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
                pure_audio_len = place_num - 2
                added_audio_len = 0
                audio_llm_pos_ids_list: list[torch.Tensor] = []
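                # Interleave the two modalities: for every temporal chunk, append the
                # vision position ids for that chunk, then up to t_ntoken_per_chunk
                # audio position ids continuing from the previous audio chunk's end.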
                for t_chunk in t_index_split_chunk:
                    vision_ntoken_per_chunk = (
                        len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
                    )
                    new_src_item.extend([video_token_id] * vision_ntoken_per_chunk)
                    vision_llm_pos_ids_list = get_llm_pos_ids_for_vision(
                        start_idx,
                        video_idx,
                        spatial_merge_size,
                        t_chunk,
                        grid_hs,
                        grid_ws,
                    ).split(1, dim=1)
                    llm_pos_ids_list.extend(vision_llm_pos_ids_list)
                    new_src_item.extend(
                        min(t_ntoken_per_chunk, pure_audio_len - added_audio_len)
                        * [audio_token_id]
                    )
                    audio_start_idx = (
                        start_idx
                        if len(audio_llm_pos_ids_list) == 0
                        else audio_llm_pos_ids_list[-1][0].item() + 1
                    )
                    if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0:
                        audio_llm_pos_ids_list = (
                            torch.arange(
                                min(
                                    t_ntoken_per_chunk, pure_audio_len - added_audio_len
                                )
                            ).expand(3, -1)
                            + audio_start_idx
                        ).split(1, dim=1)
                    else:
                        audio_llm_pos_ids_list = []
                    added_audio_len += min(
                        t_ntoken_per_chunk, pure_audio_len - added_audio_len
                    )
                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
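                # Any audio that did not fit into the temporal chunks is appended as a
                # trailing run of audio placeholders after the last vision chunk.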
                if added_audio_len < pure_audio_len:
                    new_src_item.extend(
                        (pure_audio_len - added_audio_len) * [audio_token_id]
                    )
                    audio_llm_pos_ids_list = (
                        torch.arange(pure_audio_len - added_audio_len).expand(3, -1)
                        + llm_pos_ids_list[-1].max()
                        + 1
                    ).split(1, dim=1)
                    llm_pos_ids_list.extend(audio_llm_pos_ids_list)
                audio_idx += 1
                video_idx += 1

            # move to the next token
            idx += len(new_src_item) - new_src_item_len
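
        # Concatenate the per-segment ids into a (3, num_tokens) tensor; the delta is
        # the offset between the largest M-RoPE position and the raw token count.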
        llm_positions = torch.cat(llm_pos_ids_list, dim=1)
        mrope_position_delta = (
            torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item)
        )
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality: