132 lines
4.4 KiB
Python
132 lines
4.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from functools import partial
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
import pybase64
|
|
from PIL import Image
|
|
|
|
from vllm import envs
|
|
|
|
from ..video import VIDEO_LOADER_REGISTRY
|
|
from .base import MediaIO
|
|
from .image import ImageMediaIO
|
|
|
|
|
|
class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
|
|
"""Configuration values can be user-provided either by --media-io-kwargs or
|
|
by the runtime API field "media_io_kwargs". Ensure proper validation and
|
|
error handling.
|
|
"""
|
|
|
|
@classmethod
|
|
def merge_kwargs(
|
|
cls,
|
|
default_kwargs: dict[str, Any] | None,
|
|
runtime_kwargs: dict[str, Any] | None,
|
|
) -> dict[str, Any]:
|
|
merged = super().merge_kwargs(default_kwargs, runtime_kwargs)
|
|
# fps and num_frames interact with each other, so if either is
|
|
# overridden at request time, wipe the other from defaults to
|
|
# avoid unintuitive cross-field interactions.
|
|
if runtime_kwargs:
|
|
if "num_frames" in runtime_kwargs and "fps" not in runtime_kwargs:
|
|
merged.pop("fps", None)
|
|
elif "fps" in runtime_kwargs and "num_frames" not in runtime_kwargs:
|
|
merged.pop("num_frames", None)
|
|
return merged
|
|
|
|
def __init__(
|
|
self,
|
|
image_io: ImageMediaIO,
|
|
num_frames: int = 32,
|
|
**kwargs,
|
|
) -> None:
|
|
super().__init__()
|
|
|
|
self.image_io = image_io
|
|
self.num_frames = num_frames
|
|
# `kwargs` contains custom arguments from
|
|
# --media-io-kwargs for this modality, merged with
|
|
# per-request runtime media_io_kwargs via merge_kwargs().
|
|
# They can be passed to the underlying
|
|
# media loaders (e.g. custom implementations)
|
|
# for flexible control.
|
|
|
|
# Allow per-request override of video backend via kwargs.
|
|
# This enables users to specify a different backend than the
|
|
# global VLLM_VIDEO_LOADER_BACKEND env var, e.g.:
|
|
# --media-io-kwargs '{"video": {"video_backend": "torchcodec"}}'
|
|
video_loader_backend = (
|
|
kwargs.pop("video_backend", None) or envs.VLLM_VIDEO_LOADER_BACKEND
|
|
)
|
|
self.kwargs = kwargs
|
|
self.video_loader = VIDEO_LOADER_REGISTRY.load(video_loader_backend)
|
|
|
|
def load_bytes(self, data: bytes) -> tuple[npt.NDArray, dict[str, Any]]:
|
|
return self.video_loader.load_bytes(
|
|
data, num_frames=self.num_frames, **self.kwargs
|
|
)
|
|
|
|
def load_base64(
|
|
self, media_type: str, data: str
|
|
) -> tuple[npt.NDArray, dict[str, Any]]:
|
|
if media_type.lower() == "video/jpeg":
|
|
load_frame = partial(
|
|
self.image_io.load_base64,
|
|
"image/jpeg",
|
|
)
|
|
|
|
if self.num_frames > 0:
|
|
frame_parts = data.split(",", self.num_frames)[: self.num_frames]
|
|
elif self.num_frames == 0:
|
|
raise ValueError("num_frames must be greater than 0 or -1")
|
|
else:
|
|
frame_parts = data.split(",")
|
|
|
|
frames = np.stack(
|
|
[np.asarray(load_frame(frame_data)) for frame_data in frame_parts]
|
|
)
|
|
total = int(frames.shape[0])
|
|
fps = float(self.kwargs.get("fps", 1))
|
|
duration = total / fps if fps > 0 else 0.0
|
|
metadata = {
|
|
"total_num_frames": total,
|
|
"fps": fps,
|
|
"duration": duration,
|
|
"video_backend": "jpeg_sequence",
|
|
"frames_indices": list(range(total)),
|
|
"do_sample_frames": False,
|
|
}
|
|
return frames, metadata
|
|
|
|
return self.load_bytes(pybase64.b64decode(data))
|
|
|
|
def load_file(self, filepath: Path) -> tuple[npt.NDArray, dict[str, Any]]:
|
|
with filepath.open("rb") as f:
|
|
data = f.read()
|
|
|
|
return self.load_bytes(data)
|
|
|
|
def encode_base64(
|
|
self,
|
|
media: npt.NDArray,
|
|
*,
|
|
video_format: str = "JPEG",
|
|
) -> str:
|
|
video = media
|
|
|
|
if video_format == "JPEG":
|
|
encode_frame = partial(
|
|
self.image_io.encode_base64,
|
|
image_format=video_format,
|
|
)
|
|
|
|
return ",".join(encode_frame(Image.fromarray(frame)) for frame in video)
|
|
|
|
msg = "Only JPEG format is supported for now."
|
|
raise NotImplementedError(msg)
|