[Misc] Abstract the logic for reading and writing media content (#11527)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,23 +1,32 @@
|
||||
from functools import lru_cache
|
||||
import base64
|
||||
from functools import lru_cache, partial
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from PIL import Image
|
||||
|
||||
from vllm.inputs.registry import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.processor import get_video_processor
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils import PlaceholderModule, is_list_of
|
||||
|
||||
from .base import MultiModalData
|
||||
from .image import ImagePlugin
|
||||
from .base import MediaIO, MultiModalData
|
||||
from .image import ImageMediaIO, ImagePlugin
|
||||
from .inputs import MultiModalKwargs, VideoItem
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
try:
|
||||
import decord
|
||||
except ImportError:
|
||||
decord = PlaceholderModule("decord") # type: ignore[assignment]
|
||||
|
||||
# Module-level logger, keyed by this module's import name.
logger = init_logger(__name__)

# Memoized wrapper around ``get_video_processor``: calls with the same
# (hashable) arguments reuse the previously returned result. The cache size is
# spelled out here; 128 is also what a bare ``lru_cache(func)`` would use.
cached_get_video_processor = lru_cache(maxsize=128)(get_video_processor)
|
||||
@@ -107,3 +116,73 @@ def sample_frames_from_video(frames: npt.NDArray,
|
||||
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
|
||||
sampled_frames = frames[frame_indices, ...]
|
||||
return sampled_frames
|
||||
|
||||
|
||||
class VideoMediaIO(MediaIO[npt.NDArray]):
    """Media IO for videos represented as NumPy arrays of frames.

    Container bytes are decoded with ``decord``. The JPEG-sequence form
    (comma-separated base64 frames) is handled frame by frame through the
    ``ImageMediaIO`` instance supplied at construction time.
    """

    def __init__(
        self,
        image_io: ImageMediaIO,
        *,
        num_frames: int = 32,
    ) -> None:
        """Store the per-frame image IO and the frame budget.

        Args:
            image_io: Image IO used to decode/encode individual frames.
            num_frames: Upper bound on the number of frames kept by
                :meth:`load_bytes`.
        """
        super().__init__()

        self.num_frames = num_frames
        self.image_io = image_io

    def load_bytes(self, data: bytes) -> npt.NDArray:
        """Decode a video from raw bytes, keeping at most ``num_frames`` frames.

        When the video holds more frames than the budget, frames are picked
        at evenly spaced indices between the first and the last frame;
        otherwise every frame is returned.
        """
        reader = decord.VideoReader(BytesIO(data), num_threads=1)
        available = len(reader)
        budget = self.num_frames

        if available > budget:
            # Evenly spaced positions over [0, available - 1], cast to int.
            indices = np.linspace(0, available - 1, budget, dtype=int).tolist()
        else:
            indices = list(range(available))

        batch = reader.get_batch(indices)
        return batch.asnumpy()

    def load_base64(self, media_type: str, data: str) -> npt.NDArray:
        """Decode a base64 payload into an array of frames.

        A media type of ``video/jpeg`` (case-insensitive) is treated as a
        comma-separated list of base64 JPEG images, one per frame. Any other
        media type is base64-decoded as a whole and passed to
        :meth:`load_bytes`.
        """
        if media_type.lower() != "video/jpeg":
            return self.load_bytes(base64.b64decode(data))

        decode_frame = self.image_io.load_base64
        frames = [
            np.array(decode_frame("image/jpeg", encoded_frame))
            for encoded_frame in data.split(",")
        ]
        return np.stack(frames)

    def load_file(self, filepath: Path) -> npt.NDArray:
        """Read the file at ``filepath`` and decode it via :meth:`load_bytes`."""
        raw = filepath.read_bytes()
        return self.load_bytes(raw)

    def encode_base64(
        self,
        media: npt.NDArray,
        *,
        video_format: str = "JPEG",
    ) -> str:
        """Encode frames as comma-separated base64 images.

        Args:
            media: Iterable of frames; each frame is converted with
                ``Image.fromarray`` before being encoded.
            video_format: Output format. Only ``"JPEG"`` is accepted.

        Raises:
            NotImplementedError: If ``video_format`` is not ``"JPEG"``.
        """
        if video_format != "JPEG":
            msg = "Only JPEG format is supported for now."
            raise NotImplementedError(msg)

        encode_frame = self.image_io.encode_base64
        return ",".join(
            encode_frame(Image.fromarray(frame), image_format=video_format)
            for frame in media)
|
||||
|
||||
Reference in New Issue
Block a user