[Refactor] Modular video loader backend refactoring (#35202)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -7,7 +7,13 @@ import numpy as np
|
||||
import numpy.typing as npt
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader
|
||||
from vllm.assets.base import get_vllm_public_assets
|
||||
from vllm.multimodal.video import (
|
||||
VIDEO_LOADER_REGISTRY,
|
||||
VideoLoader,
|
||||
)
|
||||
|
||||
from .utils import create_video_from_image
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
@@ -291,3 +297,76 @@ def test_video_recovery_dynamic_backend(monkeypatch: pytest.MonkeyPatch):
|
||||
f"Got {frames_with_recovery.shape[0]} with recovery vs "
|
||||
f"{frames_no_recovery.shape[0]} without"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_video_path(tmp_path):
    """Create a 1800-frame, 30 fps MP4 from the public stop-sign image.

    The video is written below pytest's per-test ``tmp_path`` and its path
    is returned to the test.
    """
    source_image = get_vllm_public_assets(
        filename="stop_sign.jpg", s3_prefix="vision_model_images"
    )
    output_file = tmp_path / "test_RGB_video.mp4"
    create_video_from_image(
        str(source_image),
        str(output_file),
        num_frames=1800,
        fps=30,
    )
    return output_file
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "backend, kwargs, expected_num_frames",
    [
        # The fixture video has 1800 frames at 30 fps, i.e. 60 seconds.
        # opencv: num_frames directly controls count
        pytest.param("opencv", {"num_frames": 32}, 32, id="opencv-num_frames"),
        pytest.param("opencv", {"fps": 2}, 120, id="opencv-fps"),
        pytest.param(
            "opencv",
            {"num_frames": 500, "fps": 2},
            120,
            id="opencv-num_frames_wins_fps",
        ),
        pytest.param(
            "opencv_dynamic",
            {"fps": 1, "max_duration": 60},
            60,
            id="opencv_dynamic-within_max_duration",
        ),
        pytest.param(
            "opencv_dynamic",
            {"fps": 2, "max_duration": 30},
            60,
            id="opencv_dynamic-exceeds_max_duration",
        ),
        pytest.param(
            "openpangu",
            {"num_frames": 32, "fps": -1},
            32,
            id="openpangu-num_frames",
        ),
        pytest.param(
            "molmo2",
            {"num_frames": 32, "frame_sample_mode": "uniform_last_frame"},
            32,
            id="molmo2-uniform_last_frame",
        ),
        pytest.param(
            "molmo2",
            {"fps": 2, "frame_sample_mode": "fps"},
            119,
            id="molmo2-fps",
        ),
    ],
)
def test_video_loader_frames_sampling(
    dummy_video_path,
    monkeypatch: pytest.MonkeyPatch,
    backend: str,
    kwargs: dict,
    expected_num_frames: int,
):
    """Check that each registered backend samples the expected frame count."""
    monkeypatch.setenv("VLLM_VIDEO_LOADER_BACKEND", backend)
    video_loader = VIDEO_LOADER_REGISTRY.load(backend)

    # Backends consume raw bytes, so read the generated file in binary mode.
    raw_video = dummy_video_path.read_bytes()
    frames, _metadata = video_loader.load_bytes(raw_video, **kwargs)

    assert frames.ndim == 4
    assert frames.shape[3] == 3  # RGB
    assert frames.shape[0] == expected_num_frames
|
||||
|
||||
@@ -3,17 +3,23 @@
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import Any, NamedTuple, cast
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import cv2
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
from vllm.utils.registry import ExtensionManager
|
||||
|
||||
try:
|
||||
import cv2
|
||||
import cv2.videoio_registry as vr
|
||||
except ImportError:
|
||||
cv2 = PlaceholderModule("cv2")
|
||||
vr = PlaceholderModule("cv2").placeholder_attr("videoio_registry")
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -23,8 +29,6 @@ def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
|
||||
resized_frames = np.empty(
|
||||
(num_frames, new_height, new_width, channels), dtype=frames.dtype
|
||||
)
|
||||
# lazy import cv2 to avoid bothering users who only use text models
|
||||
import cv2
|
||||
|
||||
for i, frame in enumerate(frames):
|
||||
resized_frame = cv2.resize(frame, (new_width, new_height))
|
||||
@@ -50,16 +54,100 @@ def sample_frames_from_video(frames: npt.NDArray, num_frames: int) -> npt.NDArra
|
||||
return sampled_frames
|
||||
|
||||
|
||||
class VideoTargetMetadata(NamedTuple):
    """Requested sampling parameters for the video a loader should produce.

    Instances are handed to ``compute_frames_index_to_sample`` together with
    the source metadata; each backend decides which fields it honors.
    """

    # Requested number of frames (the OpenCV backend ignores values <= 0).
    num_frames: int
    # Requested sampling rate in frames per second (the OpenCV backend
    # ignores values <= 0).
    fps: float
    # Upper bound on the video duration to process.
    max_duration: float
|
||||
|
||||
|
||||
class VideoSourceMetadata(NamedTuple):
    """Properties of the source video as reported by the decoder."""

    # Total number of frames in the source video.
    total_frames_num: int
    # Native frame rate of the source video.
    original_fps: float
    # Length of the source video (frame count divided by frame rate when the
    # metadata is built by the OpenCV mixin).
    duration: float
|
||||
|
||||
|
||||
class VideoLoader:
    """Base interface shared by all video loader backends.

    Subclasses implement ``load_bytes`` and, when they sample frames,
    ``compute_frames_index_to_sample``. ``create_hf_metadata`` is a shared
    helper that formats the metadata dictionary returned to callers.
    """

    @classmethod
    def compute_frames_index_to_sample(
        cls,
        source: "VideoSourceMetadata",
        target: "VideoTargetMetadata",
        **kwargs,
    ) -> list[int]:
        """Return the list of frame indices to sample from the video.

        Args:
            source: Properties of the video being read.
            target: Requested sampling parameters.
            **kwargs: Backend-specific sampling options.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def load_bytes(
        cls,
        data: bytes,
        **kwargs,
    ) -> tuple[npt.NDArray, dict[str, Any]]:
        """Load video frames from bytes and return (frames_array, metadata_dict).

        Args:
            data: Raw video bytes.
            **kwargs: Backend-specific loading options.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @classmethod
    def create_hf_metadata(
        cls,
        source: "VideoSourceMetadata",
        valid_frame_indices: list[int],
        video_backend: str,
    ) -> dict[str, Any]:
        """Build the metadata dict in transformers' ``VideoMetadata`` layout.

        Args:
            source: Properties of the source video.
            valid_frame_indices: Indices of the frames that were loaded.
            video_backend: Backend label stored under ``"video_backend"``.

        Returns:
            A dict with the source frame count, fps and duration, the backend
            label, the loaded frame indices and a ``do_sample_frames`` flag.
        """
        # Extra field used to control the HF processor's video sampling
        # behavior: it is True only when every source frame was loaded.
        loaded_every_frame = len(valid_frame_indices) == source.total_frames_num
        return {
            "total_num_frames": source.total_frames_num,
            "fps": source.original_fps,
            "duration": source.duration,
            "video_backend": video_backend,
            "frames_indices": valid_frame_indices,
            "do_sample_frames": loaded_every_frame,
        }
|
||||
|
||||
|
||||
VIDEO_LOADER_REGISTRY = ExtensionManager()
|
||||
|
||||
|
||||
class OpenCVVideoBackendMixin:
|
||||
@staticmethod
def get_cv2_video_api():
    """Return the first usable OpenCV stream-buffered videoio backend.

    A backend qualifies when OpenCV reports it as available and it is either
    built in or a plugin whose (ABI, API) version is at least (1, 2).

    Returns:
        The backend identifier, or ``None`` when no backend qualifies.
    """
    for candidate in vr.getStreamBufferedBackends():
        if not vr.hasBackend(candidate):
            continue
        if vr.isBackendBuiltIn(candidate):
            return candidate
        # Plugin backends are only accepted from ABI 1 / API 2 onwards.
        _, abi_version, api_version = vr.getStreamBufferedBackendPluginVersion(
            candidate
        )
        if abi_version > 1 or (abi_version == 1 and api_version >= 2):
            return candidate
    return None
|
||||
|
||||
@classmethod
def open_video_capture(cls, data: bytes) -> "cv2.VideoCapture":
    """Open an in-memory video with OpenCV.

    Args:
        data: Raw video bytes.

    Returns:
        An opened ``cv2.VideoCapture`` reading from ``data``.

    Raises:
        ValueError: If OpenCV cannot open the stream.
    """
    api_preference = cls.get_cv2_video_api()
    capture = cv2.VideoCapture(BytesIO(data), api_preference, [])
    if capture.isOpened():
        return capture
    raise ValueError("Could not open video stream")
|
||||
|
||||
@staticmethod
def get_video_metadata(cap: "cv2.VideoCapture") -> VideoSourceMetadata:
    """Read frame count and fps from an opened capture.

    Args:
        cap: An opened OpenCV capture.

    Returns:
        Source metadata whose duration is ``frame count / fps``, or ``0``
        when the reported fps is not positive.
    """
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    native_fps = cap.get(cv2.CAP_PROP_FPS)
    if native_fps > 0:
        length = frame_count / native_fps
    else:
        # Avoid dividing by a missing or invalid frame rate.
        length = 0
    return VideoSourceMetadata(
        total_frames_num=frame_count,
        original_fps=native_fps,
        duration=length,
    )
|
||||
|
||||
@classmethod
|
||||
def _can_use_for_recovery(
|
||||
cls,
|
||||
idx: int,
|
||||
failed_frames: list[int],
|
||||
next_target_map: dict[int, int],
|
||||
@@ -72,8 +160,9 @@ class VideoLoader:
|
||||
limit = next_target_map.get(oldest_failed, total_frames)
|
||||
return idx < limit
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def _read_frames_with_recovery(
|
||||
cls,
|
||||
cap: "cv2.VideoCapture",
|
||||
frame_indices: list[int],
|
||||
total_frames: int,
|
||||
@@ -95,8 +184,6 @@ class VideoLoader:
|
||||
- valid_frame_indices: List of frame indices that were loaded
|
||||
- recovered_map: Dict mapping recovered_idx -> source_idx
|
||||
"""
|
||||
import cv2
|
||||
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
|
||||
@@ -135,7 +222,7 @@ class VideoLoader:
|
||||
continue
|
||||
|
||||
# Check if we should retrieve: target frame OR can recover a failed one
|
||||
can_recover = VideoLoader._can_use_for_recovery(
|
||||
can_recover = cls._can_use_for_recovery(
|
||||
idx, failed_frames_idx, next_target_map, total_frames
|
||||
)
|
||||
|
||||
@@ -179,15 +266,14 @@ class VideoLoader:
|
||||
|
||||
return frames, valid_frame_indices, recovered_map
|
||||
|
||||
@staticmethod
|
||||
def _read_frames(
|
||||
@classmethod
|
||||
def _read_frames_no_recovery(
|
||||
cls,
|
||||
cap,
|
||||
frame_indices: set[int],
|
||||
num_expected_frames: int,
|
||||
max_frame_idx: int,
|
||||
) -> tuple[npt.NDArray, int, list[int]]:
|
||||
import cv2
|
||||
|
||||
) -> tuple[npt.NDArray, list[int]]:
|
||||
num_expected_frames = len(frame_indices)
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
frames = np.empty((num_expected_frames, height, width, 3), dtype=np.uint8)
|
||||
@@ -229,28 +315,77 @@ class VideoLoader:
|
||||
valid_num_frames,
|
||||
)
|
||||
|
||||
return frames[:valid_num_frames], valid_num_frames, valid_frame_indices
|
||||
return frames[:valid_num_frames], valid_frame_indices
|
||||
|
||||
@classmethod
def read_frames(
    cls,
    cap: "cv2.VideoCapture",
    frame_idx: list[int],
    total_frames_num: int,
    *,
    frame_recovery: bool = False,
) -> tuple[npt.NDArray, list[int]]:
    """Read the requested frames from an opened capture.

    Args:
        cap: An opened OpenCV capture.
        frame_idx: Indices of the frames to load.
        total_frames_num: Total number of frames in the source video; only
            forwarded to the recovery reader.
        frame_recovery: When True, use the forward-scan recovery reader;
            otherwise use the plain reader.

    Returns:
        Tuple of (frames_array, valid_frame_indices), where
        ``valid_frame_indices`` lists the indices that were actually loaded.
    """
    if frame_recovery:
        num_frames_to_sample = len(frame_idx)
        frames, valid_frame_indices, recovered_map = cls._read_frames_with_recovery(
            cap, frame_idx, total_frames_num
        )
        if recovered_map:
            logger.info(
                "Frame recovery: %d frames recovered using forward scan.",
                len(recovered_map),
            )
    else:
        # The plain reader works on a set, so duplicate indices count once.
        frame_idx_set = set(frame_idx)
        num_frames_to_sample = len(frame_idx_set)
        # NOTE(review): an empty frame_idx makes max() raise ValueError here.
        frames, valid_frame_indices = cls._read_frames_no_recovery(
            cap, frame_idx_set, max(frame_idx)
        )

    valid_num_frames = len(valid_frame_indices)
    if valid_num_frames < num_frames_to_sample:
        logger.warning(
            "Video loading completed with %d broken/unreadable frames. "
            "Expected to sample %d frames but only loaded %d frames.",
            num_frames_to_sample - valid_num_frames,
            num_frames_to_sample,
            valid_num_frames,
        )
    return frames, valid_frame_indices
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv")
|
||||
class OpenCVVideoBackend(VideoLoader):
|
||||
def get_cv2_video_api(self):
|
||||
import cv2.videoio_registry as vr
|
||||
class OpenCVVideoBackend(VideoLoader, OpenCVVideoBackendMixin):
|
||||
@classmethod
def compute_frames_index_to_sample(
    cls,
    source: "VideoSourceMetadata",
    target: "VideoTargetMetadata",
    **kwargs,
) -> list[int]:
    """Pick uniformly spaced frame indices for the requested count and fps.

    Args:
        source: Properties of the source video.
        target: Requested sampling parameters; ``num_frames`` and ``fps``
            are each ignored when not positive.
        **kwargs: Unused.

    Returns:
        Every frame index when no limit applies, otherwise indices spread
        uniformly from the first to the last frame. At least one index is
        always returned.
    """
    total_frames_num = source.total_frames_num
    duration = source.duration
    num_frames = target.num_frames
    fps = target.fps

    # resample video to target num_frames and fps
    # - the minimum of the two will be used
    num_frames_to_sample = total_frames_num
    if num_frames > 0:
        num_frames_to_sample = min(num_frames, total_frames_num)
    if fps > 0:
        num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))
    num_frames_to_sample = max(1, num_frames_to_sample)  # at least one sample

    if num_frames_to_sample == total_frames_num:
        return list(range(0, num_frames_to_sample))

    uniform_sampled_frames = np.linspace(
        0, total_frames_num - 1, num_frames_to_sample, dtype=int
    )
    return uniform_sampled_frames.tolist()
|
||||
|
||||
@classmethod
|
||||
def load_bytes(
|
||||
@@ -275,108 +410,54 @@ class OpenCVVideoBackend(VideoLoader):
|
||||
Returns:
|
||||
Tuple of (frames_array, metadata_dict)
|
||||
"""
|
||||
import cv2
|
||||
cap = cls.open_video_capture(data)
|
||||
|
||||
backend = cls().get_cv2_video_api()
|
||||
cap = cv2.VideoCapture(BytesIO(data), backend, [])
|
||||
if not cap.isOpened():
|
||||
raise ValueError("Could not open video stream")
|
||||
|
||||
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
duration = total_frames_num / original_fps if original_fps > 0 else 0
|
||||
source = OpenCVVideoBackendMixin.get_video_metadata(cap)
|
||||
target = VideoTargetMetadata(
|
||||
num_frames=num_frames,
|
||||
fps=fps,
|
||||
max_duration=max_duration,
|
||||
)
|
||||
|
||||
# resample video to target num_frames and fps
|
||||
# - the minimum of the two will be used
|
||||
num_frames_to_sample = total_frames_num
|
||||
if num_frames > 0:
|
||||
num_frames_to_sample = min(num_frames, total_frames_num)
|
||||
if fps > 0:
|
||||
num_frames_to_sample = min(num_frames_to_sample, math.floor(duration * fps))
|
||||
num_frames_to_sample = max(1, num_frames_to_sample) # at least one sample
|
||||
frame_idx = cls.compute_frames_index_to_sample(
|
||||
source=source,
|
||||
target=target,
|
||||
)
|
||||
|
||||
if num_frames_to_sample == total_frames_num:
|
||||
frame_idx = list(range(0, num_frames_to_sample))
|
||||
else:
|
||||
uniform_sampled_frames = np.linspace(
|
||||
0, total_frames_num - 1, num_frames_to_sample, dtype=int
|
||||
)
|
||||
frame_idx = uniform_sampled_frames.tolist()
|
||||
frames, valid_frame_indices = cls.read_frames(
|
||||
cap,
|
||||
frame_idx,
|
||||
total_frames_num=source.total_frames_num,
|
||||
frame_recovery=frame_recovery,
|
||||
)
|
||||
|
||||
if frame_recovery:
|
||||
frames, valid_frame_indices, recovered_map = cls._read_frames_with_recovery(
|
||||
cap, frame_idx, total_frames_num
|
||||
)
|
||||
valid_num_frames = len(valid_frame_indices)
|
||||
|
||||
if recovered_map:
|
||||
logger.info(
|
||||
"Frame recovery: %d frames recovered using forward scan.",
|
||||
len(recovered_map),
|
||||
)
|
||||
else:
|
||||
frame_idx_set = set(frame_idx)
|
||||
frames, valid_num_frames, valid_frame_indices = cls._read_frames(
|
||||
cap, frame_idx_set, num_frames_to_sample, max(frame_idx)
|
||||
)
|
||||
|
||||
# Use transformers transformers.video_utils.VideoMetadata format
|
||||
# NOTE(Isotr0py): For models like Qwen3-VL/GLM4.5V, this metadata
|
||||
# can cause incorrect timestamp calculation without num_frames=-1.
|
||||
metadata = {
|
||||
"total_num_frames": total_frames_num,
|
||||
"fps": original_fps,
|
||||
"duration": duration,
|
||||
"video_backend": "opencv",
|
||||
"frames_indices": valid_frame_indices,
|
||||
# extra field used to control hf processor's video
|
||||
# sampling behavior
|
||||
"do_sample_frames": valid_num_frames == total_frames_num,
|
||||
}
|
||||
metadata = cls.create_hf_metadata(
|
||||
source=source,
|
||||
video_backend="opencv",
|
||||
valid_frame_indices=valid_frame_indices,
|
||||
)
|
||||
|
||||
return frames, metadata
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("opencv_dynamic")
|
||||
class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
|
||||
class OpenCVDynamicVideoBackend(VideoLoader, OpenCVVideoBackendMixin):
|
||||
@classmethod
|
||||
def load_bytes(
|
||||
def compute_frames_index_to_sample(
|
||||
cls,
|
||||
data: bytes,
|
||||
num_frames: int = -1,
|
||||
fps: int = 2,
|
||||
max_duration: int = 300,
|
||||
frame_recovery: bool = False,
|
||||
source: VideoSourceMetadata,
|
||||
target: VideoTargetMetadata,
|
||||
**kwargs,
|
||||
) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
"""
|
||||
Load video frames with dynamic sampling based on duration.
|
||||
|
||||
Args:
|
||||
data: Raw video bytes
|
||||
num_frames: Not used in dynamic backend
|
||||
fps: Target FPS for sampling (default: 2)
|
||||
max_duration: Maximum video duration to process (default: 300s)
|
||||
frame_recovery: Enable forward-scan recovery for failed frames
|
||||
|
||||
Returns:
|
||||
Tuple of (frames_array, metadata_dict)
|
||||
"""
|
||||
import cv2
|
||||
|
||||
backend = cls().get_cv2_video_api()
|
||||
cap = cv2.VideoCapture(BytesIO(data), backend, [])
|
||||
if not cap.isOpened():
|
||||
raise ValueError("Could not open video stream")
|
||||
|
||||
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
duration = total_frames_num / original_fps if original_fps > 0 else 0
|
||||
|
||||
# resample video to target num_frames
|
||||
max_frame_idx = total_frames_num - 1
|
||||
duration = duration or round(max_frame_idx / original_fps) + 1
|
||||
) -> list[int]:
|
||||
total_frames_num = source.total_frames_num
|
||||
duration = source.duration
|
||||
original_fps = source.original_fps
|
||||
max_duration = target.max_duration
|
||||
fps = target.fps
|
||||
|
||||
max_frame_idx = source.total_frames_num - 1
|
||||
# Refer to:
|
||||
# https://github.com/huggingface/transformers/blob/v4.55.4/src/transformers/models/glm4v/video_processing_glm4v.py#L103-L140
|
||||
frame_indices_list: list[int]
|
||||
@@ -400,54 +481,75 @@ class OpenCVDynamicVideoBackend(OpenCVVideoBackend):
|
||||
for t in target_seconds
|
||||
}
|
||||
)
|
||||
return frame_indices_list
|
||||
|
||||
if frame_recovery:
|
||||
frames, valid_frame_indices, recovered_map = cls._read_frames_with_recovery(
|
||||
cap, frame_indices_list, total_frames_num
|
||||
)
|
||||
valid_num_frames = len(valid_frame_indices)
|
||||
@classmethod
def load_bytes(
    cls,
    data: bytes,
    num_frames: int = -1,
    fps: int = 2,
    max_duration: int = 300,
    frame_recovery: bool = False,
    **kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Load video frames with dynamic sampling based on duration.

    Args:
        data: Raw video bytes
        num_frames: Not used in dynamic backend
        fps: Target FPS for sampling (default: 2)
        max_duration: Maximum video duration to process (default: 300s)
        frame_recovery: Enable forward-scan recovery for failed frames

    Returns:
        Tuple of (frames_array, metadata_dict)

    Raises:
        ValueError: If the video stream cannot be opened.
    """
    cap = cls.open_video_capture(data)

    orig_source = OpenCVVideoBackendMixin.get_video_metadata(cap)
    max_frame_idx = orig_source.total_frames_num - 1
    # Fall back to a duration estimated from the last frame index when the
    # decoder-derived duration is 0.
    # NOTE(review): the reported duration is also 0 when original_fps is not
    # positive; with original_fps == 0 this fallback divides by zero.
    duration = (
        orig_source.duration or round(max_frame_idx / orig_source.original_fps) + 1
    )

    # recompute source metadata with adjusted duration to ensure correct
    # sampling indices computation
    source = VideoSourceMetadata(
        total_frames_num=orig_source.total_frames_num,
        original_fps=orig_source.original_fps,
        duration=duration,
    )
    target = VideoTargetMetadata(
        num_frames=num_frames,
        fps=fps,
        max_duration=max_duration,
    )

    frame_indices_list = cls.compute_frames_index_to_sample(
        source=source,
        target=target,
    )

    frames, valid_frame_indices = cls.read_frames(
        cap,
        frame_indices_list,
        total_frames_num=source.total_frames_num,
        frame_recovery=frame_recovery,
    )

    # NOTE(review): create_hf_metadata derives "do_sample_frames" from the
    # number of loaded frames, so it is True when every frame was loaded;
    # confirm that is intended for this backend.
    metadata = cls.create_hf_metadata(
        source=source,
        video_backend="opencv_dynamic",
        valid_frame_indices=valid_frame_indices,
    )

    return frames, metadata
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("molmo2")
|
||||
class Molmo2VideoBackend(VideoLoader):
|
||||
def get_cv2_video_api(self):
|
||||
import cv2.videoio_registry as vr
|
||||
|
||||
api_pref = None
|
||||
for backend in vr.getStreamBufferedBackends():
|
||||
if not vr.hasBackend(backend):
|
||||
continue
|
||||
if not vr.isBackendBuiltIn(backend):
|
||||
_, abi, api = vr.getStreamBufferedBackendPluginVersion(backend)
|
||||
if abi < 1 or (abi == 1 and api < 2):
|
||||
continue
|
||||
api_pref = backend
|
||||
break
|
||||
return api_pref
|
||||
|
||||
class Molmo2VideoBackend(VideoLoader, OpenCVVideoBackendMixin):
|
||||
@classmethod
|
||||
def get_candidate_target_fps(
|
||||
cls,
|
||||
@@ -599,16 +701,28 @@ class Molmo2VideoBackend(VideoLoader):
|
||||
raise NotImplementedError(frame_sample_mode)
|
||||
|
||||
@classmethod
|
||||
def _sample_frames(
|
||||
def compute_frames_index_to_sample(
|
||||
cls,
|
||||
total_num_frames: int,
|
||||
video_fps: float,
|
||||
duration: float,
|
||||
frame_sample_mode: str,
|
||||
num_frames: int,
|
||||
max_fps: int,
|
||||
sampling_fps: int,
|
||||
) -> npt.NDArray:
|
||||
source: VideoSourceMetadata,
|
||||
target: VideoTargetMetadata,
|
||||
**kwargs,
|
||||
):
|
||||
max_fps = kwargs.get("max_fps")
|
||||
frame_sample_mode = kwargs.get("frame_sample_mode")
|
||||
if frame_sample_mode is None:
|
||||
return list(range(0, source.total_frames_num))
|
||||
|
||||
if frame_sample_mode not in {"uniform_last_frame", "fps"}:
|
||||
raise NotImplementedError(
|
||||
f"Unsupported frame_sample_mode: {frame_sample_mode}"
|
||||
)
|
||||
|
||||
duration = source.duration
|
||||
video_fps = source.original_fps
|
||||
total_num_frames = source.total_frames_num
|
||||
num_frames = target.num_frames
|
||||
sampling_fps = target.fps
|
||||
|
||||
if frame_sample_mode == "uniform_last_frame" and max_fps is not None:
|
||||
if total_num_frames <= 2:
|
||||
indices = np.arange(total_num_frames).astype(int)
|
||||
@@ -655,10 +769,7 @@ class Molmo2VideoBackend(VideoLoader):
|
||||
num_frames,
|
||||
video_fps,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(frame_sample_mode)
|
||||
|
||||
return indices
|
||||
return indices.tolist()
|
||||
|
||||
@classmethod
|
||||
def load_bytes_opencv(
|
||||
@@ -668,63 +779,37 @@ class Molmo2VideoBackend(VideoLoader):
|
||||
num_frames: int = -1,
|
||||
max_fps: int = 2,
|
||||
sampling_fps: int = 2,
|
||||
frame_recovery: bool = False,
|
||||
**kwargs,
|
||||
) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
import cv2
|
||||
cap = cls.open_video_capture(data)
|
||||
|
||||
backend = cls().get_cv2_video_api()
|
||||
cap = cv2.VideoCapture(BytesIO(data), backend, [])
|
||||
if not cap.isOpened():
|
||||
raise ValueError("Could not open video stream")
|
||||
|
||||
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
original_fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
duration = total_frames_num / original_fps if original_fps > 0 else 0
|
||||
|
||||
if frame_sample_mode is None:
|
||||
# Use transformers transformers.video_utils.VideoMetadata format
|
||||
frame_idx = list(range(0, total_frames_num))
|
||||
frame_idx_set = set(frame_idx)
|
||||
frames, valid_num_frames, valid_frame_indices = cls._read_frames(
|
||||
cap, frame_idx_set, total_frames_num, max(frame_idx)
|
||||
)
|
||||
do_sample_frames = valid_num_frames == total_frames_num
|
||||
metadata = {
|
||||
"total_num_frames": total_frames_num,
|
||||
"fps": original_fps,
|
||||
"duration": duration,
|
||||
"video_backend": "opencv",
|
||||
"do_sample_frames": do_sample_frames,
|
||||
}
|
||||
if not do_sample_frames:
|
||||
metadata["frames_indices"] = valid_frame_indices
|
||||
return frames, metadata
|
||||
|
||||
frame_idx = cls._sample_frames(
|
||||
total_frames_num,
|
||||
original_fps,
|
||||
duration,
|
||||
frame_sample_mode,
|
||||
num_frames,
|
||||
max_fps,
|
||||
sampling_fps,
|
||||
).tolist()
|
||||
|
||||
frames, valid_num_frames, valid_frame_indices = cls._read_frames(
|
||||
cap,
|
||||
set(frame_idx),
|
||||
len(frame_idx),
|
||||
total_frames_num - 1,
|
||||
source = OpenCVVideoBackendMixin.get_video_metadata(cap)
|
||||
target = VideoTargetMetadata(
|
||||
num_frames=num_frames,
|
||||
fps=sampling_fps,
|
||||
max_duration=source.duration,
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"total_num_frames": total_frames_num,
|
||||
"fps": original_fps,
|
||||
"duration": duration,
|
||||
"video_backend": "opencv",
|
||||
"frames_indices": valid_frame_indices,
|
||||
"do_sample_frames": False,
|
||||
}
|
||||
frame_idx = cls.compute_frames_index_to_sample(
|
||||
source=source,
|
||||
target=target,
|
||||
frame_sample_mode=frame_sample_mode,
|
||||
max_fps=max_fps,
|
||||
)
|
||||
|
||||
frames, valid_frame_indices = cls.read_frames(
|
||||
cap,
|
||||
frame_idx,
|
||||
total_frames_num=source.total_frames_num,
|
||||
frame_recovery=frame_recovery,
|
||||
)
|
||||
|
||||
metadata = cls.create_hf_metadata(
|
||||
source=source,
|
||||
video_backend="opencv",
|
||||
valid_frame_indices=valid_frame_indices,
|
||||
)
|
||||
|
||||
return frames, metadata
|
||||
|
||||
@@ -777,42 +862,19 @@ class NemotronVLVideoBackend(OpenCVVideoBackend):
|
||||
|
||||
|
||||
@VIDEO_LOADER_REGISTRY.register("openpangu")
|
||||
class OpenCVDynamicOpenPanguVideoBackend(OpenCVVideoBackend):
|
||||
class OpenCVDynamicOpenPanguVideoBackend(VideoLoader, OpenCVVideoBackendMixin):
|
||||
@classmethod
|
||||
def load_bytes(
|
||||
def compute_frames_index_to_sample(
|
||||
cls,
|
||||
data: bytes,
|
||||
num_frames: int = 32,
|
||||
fps: int = 1,
|
||||
max_duration: int = 300,
|
||||
frame_recovery: bool = False,
|
||||
source: VideoSourceMetadata,
|
||||
target: VideoTargetMetadata,
|
||||
**kwargs,
|
||||
) -> tuple[npt.NDArray, dict[str, Any]]:
|
||||
"""
|
||||
Load video frames with dynamic sampling based on duration.
|
||||
Assume that total_num_frames = 10 and fps = 1.
|
||||
The timestamp of frame 0 is 0.0.
|
||||
The timestamp of frame 1 is 1.0.…
|
||||
The timestamp of frame 9 (the last frame) should be 9.0, that is,
|
||||
(total_frames_num – 1) / original_fps.
|
||||
) -> list[int]:
|
||||
total_frames_num = source.total_frames_num
|
||||
original_fps = source.original_fps
|
||||
num_frames = target.num_frames
|
||||
fps = target.fps
|
||||
|
||||
Args:
|
||||
data: Raw video bytes
|
||||
num_frames: Not used in dynamic backend
|
||||
fps: Target FPS for sampling (default: 1)
|
||||
|
||||
Returns:
|
||||
Tuple of (frames_array, metadata_dict)
|
||||
"""
|
||||
import cv2
|
||||
|
||||
backend = cls().get_cv2_video_api()
|
||||
cap = cv2.VideoCapture(BytesIO(data), backend, [])
|
||||
if not cap.isOpened():
|
||||
raise ValueError("Could not open video stream")
|
||||
|
||||
total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
original_fps = float(cap.get(cv2.CAP_PROP_FPS))
|
||||
# The timestamp of the rightmost frame, cannot be used to calculate frame 0.
|
||||
if total_frames_num >= 1 and original_fps > 0:
|
||||
total_duration = (total_frames_num - 1) / original_fps
|
||||
@@ -841,23 +903,59 @@ class OpenCVDynamicOpenPanguVideoBackend(OpenCVVideoBackend):
|
||||
min(total_frames_num - 1, round(t * original_fps))
|
||||
for t in sample_frame_timestamps
|
||||
]
|
||||
return frames_indices
|
||||
|
||||
frames, valid_frame_indices, recovered_map = cls._read_frames_with_recovery(
|
||||
cap, frames_indices, total_frames_num
|
||||
@classmethod
def load_bytes(
    cls,
    data: bytes,
    num_frames: int = -1,
    fps: int = 2,
    max_duration: int = 300,
    frame_recovery: bool = False,
    **kwargs,
) -> tuple[npt.NDArray, dict[str, Any]]:
    """
    Load video frames with dynamic sampling based on duration.

    Args:
        data: Raw video bytes
        num_frames: Target number of frames, forwarded to the sampler
        fps: Target FPS for sampling (default: 2)
        max_duration: Maximum video duration to process (default: 300s)
        frame_recovery: Enable forward-scan recovery for failed frames

    Returns:
        Tuple of (frames_array, metadata_dict)

    Raises:
        ValueError: If the video stream cannot be opened.
    """
    # NOTE(review): before the loader refactor this backend defaulted to
    # num_frames=32 and fps=1 and always read frames with forward-scan
    # recovery; confirm the new defaults are intended.
    cap = cls.open_video_capture(data)

    # The source metadata is used as reported by the capture; no duration
    # adjustment is applied here.
    # NOTE(review): the pre-refactor metadata reported "duration" as
    # (total_frames_num - 1) / original_fps; confirm the new value is intended.
    source = OpenCVVideoBackendMixin.get_video_metadata(cap)
    target = VideoTargetMetadata(
        num_frames=num_frames,
        fps=fps,
        max_duration=max_duration,
    )

    frame_indices_list = cls.compute_frames_index_to_sample(
        source=source,
        target=target,
    )

    frames, valid_frame_indices = cls.read_frames(
        cap,
        frame_indices_list,
        total_frames_num=source.total_frames_num,
        frame_recovery=frame_recovery,
    )

    # Use transformers transformers.video_utils.VideoMetadata format.
    # Keep this backend's own label, as emitted before the refactor, so the
    # metadata stays distinguishable from the generic dynamic backend.
    metadata = cls.create_hf_metadata(
        source=source,
        video_backend="opencv_dynamic_openpangu",
        valid_frame_indices=valid_frame_indices,
    )
    return frames, metadata
|
||||
|
||||
Reference in New Issue
Block a user