[BugFix] Fix the issue where image embeddings were incorrectly split.… (#23366)
Signed-off-by: bppps <bpppsaka@gmail.com>
Co-authored-by: zouyu.zzx <zouyu.zzx@alibaba-inc.com>
Co-authored-by: bppps <bpppsaka@gmail.com>
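The core of the fix: pixel_values is flattened to one row per vision patch, so an image with grid (t, h, w) contributes t * h * w rows, while precomputed image_embeds are produced after the vision tower's spatial merge and contribute only t * h * w / spatial_merge_size**2 rows. The old field config split both tensors by the pixel-patch counts, so batched image (and video) embeddings were sliced at the wrong boundaries. A standalone sketch of the size arithmetic, reusing the diff's variable names (illustrative values, not repo code):

import torch

spatial_merge_size = 2  # assumed value for illustration
image_grid_thw = torch.tensor([[1, 28, 28], [1, 56, 56]])  # two images

# Rows of pixel_values contributed by each image: t * h * w.
image_pixel_grid_sizes = image_grid_thw.prod(-1)   # tensor([ 784, 3136])

# Rows of image_embeds after the spatial merge: t * h * w / merge^2.
image_embed_grid_sizes = (image_pixel_grid_sizes //
                          spatial_merge_size // spatial_merge_size)
# tensor([196, 784]) -- these, not the pixel-patch counts, are the split
# sizes that match the embedding tensor.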
@@ -25,7 +25,7 @@
 from collections.abc import Iterable, Mapping, Sequence
 from copy import copy
 from functools import partial
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.nn as nn
@@ -79,40 +79,57 @@ except (ImportError, ModuleNotFoundError):
 logger = init_logger(__name__)


-def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
-                                          torch.empty((0, )))
+def create_qwen2_5_omni_thinker_field_factory(
+    spatial_merge_size: int
+) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
+                                                    MultiModalFieldConfig]]:

-    image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
-    image_grid_sizes = image_grid_thw.prod(-1)
+    def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str,
+                                                              torch.Tensor]):
+        audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
+                                              torch.empty((0, )))

-    video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
-    video_grid_sizes = video_grid_thw.prod(-1)
+        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+        image_pixel_grid_sizes = image_grid_thw.prod(-1)
+        image_embed_grid_sizes = (image_pixel_grid_sizes //
+                                  spatial_merge_size // spatial_merge_size)

-    num_videos = len(video_grid_sizes)
+        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
+        video_grid_sizes = video_grid_thw.prod(-1)
+        video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size //
+                                  spatial_merge_size)

-    return dict(
-        input_audio_features=MultiModalFieldConfig.flat_from_sizes(
-            "audio", audio_feature_lengths, dim=1),
-        feature_attention_mask=MultiModalFieldConfig.batched("audio"),
-        audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
-        pixel_values=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_grid_sizes),
-        image_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_grid_sizes),
-        image_grid_thw=MultiModalFieldConfig.batched("image"),
-        pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_grid_sizes),
-        video_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_grid_sizes),
-        video_grid_thw=MultiModalFieldConfig.batched("video"),
-        second_per_grid_ts=MultiModalFieldConfig.batched("video"),
-        use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos),
-    )
+        num_videos = len(video_grid_sizes)
+
+        return dict(
+            input_audio_features=MultiModalFieldConfig.flat_from_sizes(
+                "audio", audio_feature_lengths, dim=1),
+            feature_attention_mask=MultiModalFieldConfig.batched("audio"),
+            audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", image_pixel_grid_sizes),
+            image_embeds=MultiModalFieldConfig.flat_from_sizes(
+                "image", image_embed_grid_sizes),
+            image_grid_thw=MultiModalFieldConfig.batched("image"),
+            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
+                "video", video_grid_sizes),
+            video_embeds=MultiModalFieldConfig.flat_from_sizes(
+                "video", video_embed_grid_sizes),
+            video_grid_thw=MultiModalFieldConfig.batched("video"),
+            second_per_grid_ts=MultiModalFieldConfig.batched("video"),
+            use_audio_in_video=MultiModalFieldConfig.shared(
+                "video", num_videos),
+        )
+
+    return _qwen2_5_omni_thinker_field_config


 class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):

+    def __init__(self, spatial_merge_size: int, *args, **kwargs):
+        self._spatial_merge_size = spatial_merge_size
+        super().__init__(self._spatial_merge_size, *args, **kwargs)
+
     def _parse_audio_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
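Because the corrected sizes depend on spatial_merge_size, the module-level config function becomes a factory whose returned closure keeps the one-argument hf_inputs signature that callers expect. A rough sketch of the same closure pattern with simplified names (hypothetical helpers, not the vLLM API):

from collections.abc import Callable, Mapping

import torch

def make_field_sizer(
    spatial_merge_size: int
) -> Callable[[Mapping[str, torch.Tensor]], dict[str, torch.Tensor]]:
    # spatial_merge_size is captured by the closure, so the returned
    # function still takes only hf_inputs.
    def field_sizer(hf_inputs: Mapping[str, torch.Tensor]):
        grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        pixel_sizes = grid_thw.prod(-1)
        embed_sizes = pixel_sizes // spatial_merge_size // spatial_merge_size
        return {"pixel_values": pixel_sizes, "image_embeds": embed_sizes}

    return field_sizer

sizer = make_field_sizer(spatial_merge_size=2)
sizes = sizer({"image_grid_thw": torch.tensor([[1, 28, 28]])})
# sizes["pixel_values"] -> tensor([784]); sizes["image_embeds"] -> tensor([196])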
@@ -124,7 +141,8 @@ class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
             required_fields={
                 "input_audio_features", "audio_feature_lengths"
             },
-            fields_factory=_qwen2_5_omni_thinker_field_config,
+            fields_factory=create_qwen2_5_omni_thinker_field_factory(
+                self._spatial_merge_size),
         )

         return super()._parse_audio_data(data)
@@ -214,6 +232,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
     def _get_data_parser(self) -> MultiModalDataParser:
         feature_extractor = self.info.get_feature_extractor()
         return Qwen2_5OmniThinkerMultiModalDataParser(
+            spatial_merge_size=self.info.get_hf_config(
+            ).vision_config.spatial_merge_size,
             target_sr=feature_extractor.sampling_rate)

     def _call_hf_processor(
@@ -265,7 +285,9 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return _qwen2_5_omni_thinker_field_config(hf_inputs)
+        return create_qwen2_5_omni_thinker_field_factory(
+            self.info.get_hf_config().vision_config.spatial_merge_size)(
+                hf_inputs)

     def _maybe_apply_prompt_updates(
         self,
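To see the failure mode the fix addresses, the split can be reproduced with plain torch.split (a simplified stand-in for how flat_from_sizes carves up the batched tensor; shapes are illustrative):

import torch

merge = 2  # assumed spatial_merge_size
grid_thw = torch.tensor([[1, 4, 4], [1, 8, 8]])     # two small images
pixel_sizes = grid_thw.prod(-1)                     # tensor([16, 64])
embed_sizes = pixel_sizes // merge // merge         # tensor([ 4, 16])

image_embeds = torch.randn(int(embed_sizes.sum()), 8)  # 20 post-merge rows

# Fixed behavior: boundaries land exactly between the two images.
first, second = torch.split(image_embeds, embed_sizes.tolist())
assert first.shape[0] == 4 and second.shape[0] == 16

# The old behavior asked for 16 + 64 = 80 rows from a 20-row tensor, so
# the per-image slices could not line up with the embeddings provided.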