[BugFix] Fix the issue where image embeddings were incorrectly split (#23366)

Signed-off-by: bppps <bpppsaka@gmail.com>
Co-authored-by: zouyu.zzx <zouyu.zzx@alibaba-inc.com>
Co-authored-by: bppps <bpppsaka@gmail.com>
Author: bppps
Date: 2025-08-23 00:56:46 +08:00
Committed by: GitHub
Parent: 88491c1b6b
Commit: 424fb7a5d2
3 changed files with 99 additions and 52 deletions
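Context for the diff below: `pixel_values` carries t*h*w patch rows per image, but `image_embeds` (and `video_embeds`) are produced after the vision tower's spatial merge, so each item contributes only t*h*w / spatial_merge_size² rows. The old field config split both tensors by the raw pixel grid sizes, so pre-computed embeddings for multi-image inputs were split at the wrong boundaries. A minimal standalone sketch of the arithmetic (made-up grid shapes and hidden size; not vLLM code):

import torch

spatial_merge_size = 2  # example value, Qwen2.5-VL-style 2x2 merge
image_grid_thw = torch.tensor([[1, 16, 24],   # image 0: (t, h, w)
                               [1, 32, 32]])  # image 1

pixel_sizes = image_grid_thw.prod(-1)  # tensor([384, 1024]): patch rows
embed_sizes = pixel_sizes // spatial_merge_size // spatial_merge_size
# tensor([96, 256]): rows remaining after the 2x2 spatial merge

# Pre-computed embeddings arrive flattened across all images
# (hidden size 1152 is just an example):
image_embeds = torch.randn(int(embed_sizes.sum()), 1152)  # 352 rows total

# Buggy split: pixel_sizes asks for 384 + 1024 rows from a 352-row
# tensor, so the per-image slices cannot be right.
# Fixed split: embed_sizes matches the flattened tensor exactly.
chunks = torch.split(image_embeds, embed_sizes.tolist())
assert [c.shape[0] for c in chunks] == [96, 256]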


@@ -25,7 +25,7 @@
 from collections.abc import Iterable, Mapping, Sequence
 from copy import copy
 from functools import partial
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -79,40 +79,57 @@ except (ImportError, ModuleNotFoundError):
 
 logger = init_logger(__name__)
 
-def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
-                                          torch.empty((0, )))
-
+def create_qwen2_5_omni_thinker_field_factory(
+    spatial_merge_size: int
+) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
+                                                    MultiModalFieldConfig]]:
+
-    image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
-    image_grid_sizes = image_grid_thw.prod(-1)
-
+    def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str,
+                                                              torch.Tensor]):
+        audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
+                                              torch.empty((0, )))
+
-    video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
-    video_grid_sizes = video_grid_thw.prod(-1)
-
+        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
+        image_pixel_grid_sizes = image_grid_thw.prod(-1)
+        image_embed_grid_sizes = (image_pixel_grid_sizes //
+                                  spatial_merge_size // spatial_merge_size)
+
-    num_videos = len(video_grid_sizes)
-
+        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
+        video_grid_sizes = video_grid_thw.prod(-1)
+        video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size //
+                                  spatial_merge_size)
+
-    return dict(
-        input_audio_features=MultiModalFieldConfig.flat_from_sizes(
-            "audio", audio_feature_lengths, dim=1),
-        feature_attention_mask=MultiModalFieldConfig.batched("audio"),
-        audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
-        pixel_values=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_grid_sizes),
-        image_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "image", image_grid_sizes),
-        image_grid_thw=MultiModalFieldConfig.batched("image"),
-        pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_grid_sizes),
-        video_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "video", video_grid_sizes),
-        video_grid_thw=MultiModalFieldConfig.batched("video"),
-        second_per_grid_ts=MultiModalFieldConfig.batched("video"),
-        use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos),
-    )
+        num_videos = len(video_grid_sizes)
+
+        return dict(
+            input_audio_features=MultiModalFieldConfig.flat_from_sizes(
+                "audio", audio_feature_lengths, dim=1),
+            feature_attention_mask=MultiModalFieldConfig.batched("audio"),
+            audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", image_pixel_grid_sizes),
+            image_embeds=MultiModalFieldConfig.flat_from_sizes(
+                "image", image_embed_grid_sizes),
+            image_grid_thw=MultiModalFieldConfig.batched("image"),
+            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
+                "video", video_grid_sizes),
+            video_embeds=MultiModalFieldConfig.flat_from_sizes(
+                "video", video_embed_grid_sizes),
+            video_grid_thw=MultiModalFieldConfig.batched("video"),
+            second_per_grid_ts=MultiModalFieldConfig.batched("video"),
+            use_audio_in_video=MultiModalFieldConfig.shared(
+                "video", num_videos),
+        )
+
+    return _qwen2_5_omni_thinker_field_config
 
 
 class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
 
+    def __init__(self, spatial_merge_size: int, *args, **kwargs):
+        self._spatial_merge_size = spatial_merge_size
+        super().__init__(self._spatial_merge_size, *args, **kwargs)
+
     def _parse_audio_data(
         self,
         data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
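A note on the factory shape introduced above: a `fields_factory` callback is invoked with only `hf_inputs`, so `spatial_merge_size` has to be bound via a closure rather than passed as an extra argument. A standalone sketch of the pattern (plain Python with illustrative names, no vLLM imports):

from collections.abc import Callable, Mapping

import torch

def make_image_field_sizes(
    spatial_merge_size: int
) -> Callable[[Mapping[str, torch.Tensor]], dict[str, torch.Tensor]]:
    # The returned function has the single-argument signature the caller
    # expects; spatial_merge_size is captured by the closure.
    def image_field_sizes(
            hf_inputs: Mapping[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        pixel_sizes = image_grid_thw.prod(-1)
        return {
            "pixel_values": pixel_sizes,
            "image_embeds":
            pixel_sizes // spatial_merge_size // spatial_merge_size,
        }

    return image_field_sizes

sizes = make_image_field_sizes(2)(
    {"image_grid_thw": torch.tensor([[1, 16, 24]])})
assert sizes["image_embeds"].tolist() == [96]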
@@ -124,7 +141,8 @@ class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
                 required_fields={
                     "input_audio_features", "audio_feature_lengths"
                 },
-                fields_factory=_qwen2_5_omni_thinker_field_config,
+                fields_factory=create_qwen2_5_omni_thinker_field_factory(
+                    self._spatial_merge_size),
             )
 
         return super()._parse_audio_data(data)
@@ -214,6 +232,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
     def _get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self.info.get_feature_extractor()
         return Qwen2_5OmniThinkerMultiModalDataParser(
+            spatial_merge_size=self.info.get_hf_config(
+            ).vision_config.spatial_merge_size,
             target_sr=feature_extractor.sampling_rate)
 
     def _call_hf_processor(
@@ -265,7 +285,9 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return _qwen2_5_omni_thinker_field_config(hf_inputs)
+        return create_qwen2_5_omni_thinker_field_factory(
+            self.info.get_hf_config().vision_config.spatial_merge_size)(
+                hf_inputs)
 
     def _maybe_apply_prompt_updates(
         self,
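As a concrete check of the two call sites above (example numbers, not from the diff): with vision_config.spatial_merge_size = 2 and one image whose grid is (1, 36, 58), pixel_values is split with 36 * 58 = 2088 rows per image, while image_embeds is split with 2088 // 2 // 2 = 522 rows. Before this change both splits used 2088, which was only correct for the pixel path.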