[Refactor]: Use M-RoPE interface directly while defining model class instead of maintaining model specific M-RoPE implementation in mrope.py (#24172)
Signed-off-by: Divyansh Singhvi <divyanshsinghvi@gmail.com> Signed-off-by: dsinghvi <divyanshsinghvi@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
@@ -29,6 +29,7 @@ from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
|
||||
Qwen2_5OmniConfig,
|
||||
@@ -45,7 +46,6 @@ from transformers.models.whisper import WhisperFeatureExtractor
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.qwen2_5_vl import (
|
||||
Qwen2_5_VisionTransformer,
|
||||
@@ -93,6 +93,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
@@ -101,7 +102,9 @@ from .utils import (
|
||||
WeightsMapper,
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
split_list_into_ranges,
|
||||
)
|
||||
from .vision import get_llm_pos_ids_for_vision
|
||||
|
||||
try:
|
||||
import flash_attn
|
||||
@@ -412,6 +415,59 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
|
||||
|
||||
return prompt_ids, mm_placeholders
|
||||
|
||||
@classmethod
|
||||
def omni_get_updates_use_audio_in_video(
|
||||
cls,
|
||||
thinker_config: PretrainedConfig,
|
||||
audio_len: int,
|
||||
video_grid_thw: Union[list[int], torch.Tensor],
|
||||
video_second_per_grid_t: float,
|
||||
) -> list[int]:
|
||||
"""Get video prompt updates when `use_audio_in_video` is True.
|
||||
|
||||
In this case, audio and vision update ids will be split into
|
||||
chunks and interleaved (details in `_omni_get_input_positions_tensor`).
|
||||
|
||||
<|video_bos|><|VIDEO|><|video_eos|> =>
|
||||
<|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
|
||||
"""
|
||||
|
||||
audio_token_id = thinker_config.audio_token_index
|
||||
video_token_id = thinker_config.video_token_index
|
||||
audio_start_token_id = thinker_config.audio_start_token_id
|
||||
audio_end_token_id = thinker_config.audio_end_token_id
|
||||
seconds_per_chunk = thinker_config.seconds_per_chunk
|
||||
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
|
||||
tokens_per_second = getattr(
|
||||
thinker_config.vision_config, "tokens_per_second", 25
|
||||
)
|
||||
|
||||
grid_t = video_grid_thw[0]
|
||||
grid_h = video_grid_thw[1]
|
||||
grid_w = video_grid_thw[2]
|
||||
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
|
||||
t_index = (
|
||||
torch.arange(grid_t) * video_second_per_grid_t * tokens_per_second
|
||||
).long()
|
||||
t_index_split_chunk = split_list_into_ranges(t_index, t_ntoken_per_chunk)
|
||||
|
||||
updates = [audio_start_token_id]
|
||||
added_audio_len = 0
|
||||
for t_chunk in t_index_split_chunk:
|
||||
vision_ntoken_per_chunk = (
|
||||
len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
|
||||
)
|
||||
updates.extend([video_token_id] * vision_ntoken_per_chunk)
|
||||
|
||||
audio_chunk_size = min(t_ntoken_per_chunk, audio_len - added_audio_len)
|
||||
updates.extend(audio_chunk_size * [audio_token_id])
|
||||
added_audio_len += audio_chunk_size
|
||||
if added_audio_len < audio_len:
|
||||
updates.extend((audio_len - added_audio_len) * [audio_token_id])
|
||||
updates.extend([audio_end_token_id])
|
||||
|
||||
return updates
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
@@ -491,7 +547,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
|
||||
else:
|
||||
video_second_per_grid_t = 1.0
|
||||
|
||||
return MRotaryEmbedding.omni_get_updates_use_audio_in_video(
|
||||
return self.omni_get_updates_use_audio_in_video(
|
||||
thinker_config=thinker_config,
|
||||
audio_len=audio_num_features,
|
||||
video_grid_thw=video_grid_thw,
|
||||
@@ -808,6 +864,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsLoRA,
|
||||
SupportsMRoPE,
|
||||
Qwen2_5OmniConditionalGenerationMixin,
|
||||
):
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
@@ -929,6 +986,216 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
@classmethod
|
||||
def get_mrope_input_positions(
|
||||
cls,
|
||||
input_tokens: list[int],
|
||||
hf_config: PretrainedConfig,
|
||||
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||
second_per_grid_ts: Optional[list[float]] = None,
|
||||
context_len: int = 0,
|
||||
seq_len: Optional[int] = None,
|
||||
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||
use_audio_in_video: bool = False,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
|
||||
|
||||
Differences from MRotaryEmbedding:
|
||||
1. Add audio support (and related `audio_feature_lengths`).
|
||||
2. Add `use_audio_in_video` option to read audio from video inputs.
|
||||
In this case, audio and vision position ids will be split into
|
||||
chunks and interleaved.
|
||||
|
||||
Example:
|
||||
|
||||
(V_i are vision position ids, A_i are audio position ids)
|
||||
|
||||
|V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|
||||
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
|
||||
"""
|
||||
|
||||
# TODO(fyabc): refactor and share more code with
|
||||
# _vl_get_input_positions_tensor.
|
||||
|
||||
thinker_config = hf_config.thinker_config
|
||||
audio_token_id = thinker_config.audio_token_index
|
||||
image_token_id = thinker_config.image_token_index
|
||||
video_token_id = thinker_config.video_token_index
|
||||
audio_start_token_id = thinker_config.audio_start_token_id
|
||||
audio_end_token_id = thinker_config.audio_end_token_id
|
||||
vision_start_token_id = thinker_config.vision_start_token_id
|
||||
vision_end_token_id = thinker_config.vision_end_token_id
|
||||
seconds_per_chunk = thinker_config.seconds_per_chunk
|
||||
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
|
||||
tokens_per_second = getattr(
|
||||
thinker_config.vision_config, "tokens_per_second", 25
|
||||
)
|
||||
|
||||
if isinstance(image_grid_thw, list):
|
||||
image_grid_thw = torch.tensor(image_grid_thw)
|
||||
if isinstance(video_grid_thw, list):
|
||||
video_grid_thw = torch.tensor(video_grid_thw)
|
||||
|
||||
src_item = input_tokens
|
||||
audio_seqlens = audio_feature_lengths
|
||||
if not second_per_grid_ts:
|
||||
second_per_grid_ts = [1] * video_grid_thw.shape[0]
|
||||
audio_idx = 0
|
||||
video_idx = 0
|
||||
image_idx = 0
|
||||
new_src_item: list[int] = []
|
||||
llm_pos_ids_list: list[torch.Tensor] = []
|
||||
|
||||
idx = 0
|
||||
while idx < len(src_item):
|
||||
new_src_item_len = len(new_src_item)
|
||||
start_idx = (
|
||||
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||
)
|
||||
if src_item[idx] not in [audio_token_id, video_token_id, image_token_id]:
|
||||
if use_audio_in_video and idx > 0:
|
||||
if (
|
||||
src_item[idx] == vision_end_token_id
|
||||
and src_item[idx - 1] == audio_end_token_id
|
||||
):
|
||||
# processing the <|audio_eos|> before <|vision_eos|>
|
||||
start_idx -= 1
|
||||
elif (
|
||||
src_item[idx] == audio_start_token_id
|
||||
and src_item[idx - 1] == vision_start_token_id
|
||||
):
|
||||
# processing the <|audio_bos|> after <|vision_eos|>
|
||||
start_idx -= 1
|
||||
new_src_item.append(src_item[idx])
|
||||
llm_pos_ids = torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
elif src_item[idx] == audio_token_id:
|
||||
assert audio_seqlens is not None
|
||||
audio_seqlen = audio_seqlens[audio_idx]
|
||||
place_num = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
|
||||
new_src_item.extend([audio_token_id] * place_num)
|
||||
llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
audio_idx += 1
|
||||
elif src_item[idx] == image_token_id:
|
||||
grid_t = image_grid_thw[image_idx][0]
|
||||
grid_hs = image_grid_thw[:, 1]
|
||||
grid_ws = image_grid_thw[:, 2]
|
||||
t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
|
||||
llm_pos_ids = get_llm_pos_ids_for_vision(
|
||||
start_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
vision_seqlen = image_grid_thw[image_idx].prod() // (
|
||||
spatial_merge_size**2
|
||||
)
|
||||
new_src_item.extend([image_token_id] * vision_seqlen)
|
||||
image_idx += 1
|
||||
elif src_item[idx] == video_token_id and not use_audio_in_video:
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_index = (
|
||||
torch.arange(grid_t)
|
||||
* second_per_grid_ts[video_idx]
|
||||
* tokens_per_second
|
||||
).long()
|
||||
llm_pos_ids = get_llm_pos_ids_for_vision(
|
||||
start_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
|
||||
)
|
||||
llm_pos_ids_list.append(llm_pos_ids)
|
||||
vision_seqlen = video_grid_thw[video_idx].prod() // (
|
||||
spatial_merge_size**2
|
||||
)
|
||||
new_src_item.extend([video_token_id] * vision_seqlen)
|
||||
video_idx += 1
|
||||
else:
|
||||
# read audio from video
|
||||
assert audio_seqlens is not None
|
||||
audio_seqlen = audio_seqlens[audio_idx]
|
||||
vision_seqlen = video_grid_thw[video_idx].prod() // (
|
||||
spatial_merge_size**2
|
||||
)
|
||||
grid_t = video_grid_thw[video_idx][0]
|
||||
grid_h = video_grid_thw[video_idx][1]
|
||||
grid_w = video_grid_thw[video_idx][2]
|
||||
grid_hs = video_grid_thw[:, 1]
|
||||
grid_ws = video_grid_thw[:, 2]
|
||||
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
|
||||
t_index = (
|
||||
torch.arange(grid_t)
|
||||
* second_per_grid_ts[video_idx]
|
||||
* tokens_per_second
|
||||
).long()
|
||||
t_index_split_chunk = split_list_into_ranges(
|
||||
t_index, t_ntoken_per_chunk
|
||||
)
|
||||
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
|
||||
pure_audio_len = place_num - 2
|
||||
added_audio_len = 0
|
||||
audio_llm_pos_ids_list: list[torch.Tensor] = []
|
||||
for t_chunk in t_index_split_chunk:
|
||||
vision_ntoken_per_chunk = (
|
||||
len(t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
|
||||
)
|
||||
new_src_item.extend([video_token_id] * vision_ntoken_per_chunk)
|
||||
vision_llm_pos_ids_list = get_llm_pos_ids_for_vision(
|
||||
start_idx,
|
||||
video_idx,
|
||||
spatial_merge_size,
|
||||
t_chunk,
|
||||
grid_hs,
|
||||
grid_ws,
|
||||
).split(1, dim=1)
|
||||
llm_pos_ids_list.extend(vision_llm_pos_ids_list)
|
||||
new_src_item.extend(
|
||||
min(t_ntoken_per_chunk, pure_audio_len - added_audio_len)
|
||||
* [audio_token_id]
|
||||
)
|
||||
audio_start_idx = (
|
||||
start_idx
|
||||
if len(audio_llm_pos_ids_list) == 0
|
||||
else audio_llm_pos_ids_list[-1][0].item() + 1
|
||||
)
|
||||
if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0:
|
||||
audio_llm_pos_ids_list = (
|
||||
torch.arange(
|
||||
min(
|
||||
t_ntoken_per_chunk, pure_audio_len - added_audio_len
|
||||
)
|
||||
).expand(3, -1)
|
||||
+ audio_start_idx
|
||||
).split(1, dim=1)
|
||||
else:
|
||||
audio_llm_pos_ids_list = []
|
||||
added_audio_len += min(
|
||||
t_ntoken_per_chunk, pure_audio_len - added_audio_len
|
||||
)
|
||||
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
|
||||
if added_audio_len < pure_audio_len:
|
||||
new_src_item.extend(
|
||||
(pure_audio_len - added_audio_len) * [audio_token_id]
|
||||
)
|
||||
audio_llm_pos_ids_list = (
|
||||
torch.arange(pure_audio_len - added_audio_len).expand(3, -1)
|
||||
+ llm_pos_ids_list[-1].max()
|
||||
+ 1
|
||||
).split(1, dim=1)
|
||||
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
|
||||
audio_idx += 1
|
||||
video_idx += 1
|
||||
# move to the next token
|
||||
idx += len(new_src_item) - new_src_item_len
|
||||
|
||||
llm_positions = torch.cat(llm_pos_ids_list, dim=1)
|
||||
mrope_position_delta = (
|
||||
torch.cat(llm_pos_ids_list, dim=1).max() + 1 - len(src_item)
|
||||
)
|
||||
llm_positions = llm_positions[:, context_len:seq_len]
|
||||
|
||||
return llm_positions, mrope_position_delta
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
|
||||
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not mm_input_by_modality:
|
||||
|
||||
Reference in New Issue
Block a user