[Bugfix] Fix GLM4.1V multimodal processor with compatibility for Transformers v4.56 (#24822)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-09-15 20:45:06 +08:00
committed by GitHub
parent 72c99f2a75
commit 0e219cd50b
6 changed files with 118 additions and 70 deletions

View File

@@ -36,7 +36,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from packaging.version import Version
from transformers import BatchFeature
from transformers import __version__ as TRANSFORMERS_VERSION
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig
from transformers.models.glm4v.image_processing_glm4v import (
Glm4vImageProcessor, smart_resize)
@@ -1001,28 +1003,32 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
max_frame_idx = meta_frames - 1
duration = metadata.get("duration",
round(max_frame_idx / video_fps) + 1)
if duration <= video_processor.max_duration:
n = int(math.floor(duration * video_processor.fps))
frame_indices = [
min(
max_frame_idx,
int(math.ceil(i * video_fps / video_processor.fps)),
) for i in range(n)
]
do_sample_frames = metadata["do_sample_frames"]
if not do_sample_frames:
frame_indices = metadata["frames_indices"]
else:
num_samples = int(video_processor.max_duration *
video_processor.fps)
if num_samples >= meta_frames:
frame_indices = list(range(meta_frames))
else:
target_seconds = np.linspace(0,
duration,
num_samples,
endpoint=True)
if duration <= video_processor.max_duration:
n = int(math.floor(duration * video_processor.fps))
frame_indices = [
min(max_frame_idx, int(math.ceil(t * video_fps)))
for t in target_seconds
min(
max_frame_idx,
int(math.ceil(i * video_fps / video_processor.fps)),
) for i in range(n)
]
else:
num_samples = int(video_processor.max_duration *
video_processor.fps)
if num_samples >= meta_frames:
frame_indices = list(range(meta_frames))
else:
target_seconds = np.linspace(0,
duration,
num_samples,
endpoint=True)
frame_indices = [
min(max_frame_idx, int(math.ceil(t * video_fps)))
for t in target_seconds
]
seen, uniq = set(), []
for idx in frame_indices:
@@ -1139,7 +1145,9 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
"fps": 2.0,
"duration": num_frames / 2.0,
"total_num_frames": num_frames,
"frames_indices": [i for i in range(num_frames)],
"video_backend": "opencv",
"do_sample_frames": False,
}
video_item = (video.copy(), video_metadata)
video_items.append(video_item)
@@ -1172,34 +1180,37 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
for item in mm_data.pop("videos", []):
video_array, metadata = item
if metadata["video_backend"] == "opencv_dynamic":
mm_kwargs["do_sample_frames"] = False
elif metadata["total_num_frames"] != len(video_array):
logger.warning(
"Total frames in metadata "
"(%s) does not match the length of "
"video array %s. This can "
"be because the video is resampled "
"in advance. This may cause "
"a divergence with HF implementation.",
metadata["total_num_frames"],
len(video_array),
)
metadata["total_num_frames"] = len(video_array)
# don't update mm_kwargs inplace
video_mm_kwargs = dict(**mm_kwargs)
video_mm_kwargs["do_sample_frames"] = metadata.get(
"do_sample_frames", True)
video_mm_data = dict()
video_mm_data["videos"] = [[video_array]]
video_mm_data["video_metadata"] = [[VideoMetadata(**metadata)]]
# backward compatibility for Transformers 4.55
unuse_metadata = ["do_sample_frames"]
if not hasattr(
VideoMetadata,
"frames_indices") and "frames_indices" in metadata:
unuse_metadata.append("frames_indices")
video_mm_data["video_metadata"] = [[
VideoMetadata(
**{
k: metadata[k]
for k in metadata if k not in unuse_metadata
})
]]
video_outputs = super()._call_hf_processor(
prompt="<|begin_of_video|><|video|><|end_of_video|>",
mm_data=video_mm_data,
mm_kwargs=mm_kwargs,
mm_kwargs=video_mm_kwargs,
tok_kwargs=tok_kwargs,
)
if "do_sample_frames" in mm_kwargs and not mm_kwargs[
"do_sample_frames"]:
if not video_mm_kwargs["do_sample_frames"] and Version(
TRANSFORMERS_VERSION) < Version("4.56.0"):
# Transformers v4.55 has incorrect timestamps issue for
# skip sampling. We construct the placeholder manually to
# get placeholders with correct timestamps.
@@ -1218,6 +1229,7 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
prompt = prompt.replace(
"<|begin_of_video|><|video|><|end_of_video|>",
video_placeholder,
1,
)
video_grid_thw_lst.append(video_outputs["video_grid_thw"])