diff --git a/tests/multimodal/media/test_video.py b/tests/multimodal/media/test_video.py index a1223ebc0..73283ba8c 100644 --- a/tests/multimodal/media/test_video.py +++ b/tests/multimodal/media/test_video.py @@ -239,6 +239,17 @@ def test_video_media_io_backend_env_var_fallback(monkeypatch: pytest.MonkeyPatch assert metadata_missing["video_backend"] == "test_video_backend_override_2" +def _make_jpeg_b64_frames(n: int, width: int = 8, height: int = 8) -> list[str]: + """Return *n* tiny base64-encoded JPEG frames.""" + frames: list[str] = [] + for i in range(n): + img = Image.new("RGB", (width, height), color=(i % 256, 0, 0)) + buf = io.BytesIO() + img.save(buf, format="JPEG") + frames.append(pybase64.b64encode(buf.getvalue()).decode("ascii")) + return frames + + def test_load_base64_jpeg_returns_metadata(): """Regression test: load_base64 with video/jpeg must return metadata. @@ -248,16 +259,8 @@ def test_load_base64_jpeg_returns_metadata(): """ num_test_frames = 3 - frame_width, frame_height = 8, 8 - - # Build a few tiny JPEG frames and base64-encode them - b64_frames = [] - for i in range(num_test_frames): - img = Image.new("RGB", (frame_width, frame_height), color=(i * 80, 0, 0)) - buf = io.BytesIO() - img.save(buf, format="JPEG") - b64_frames.append(pybase64.b64encode(buf.getvalue()).decode("ascii")) + b64_frames = _make_jpeg_b64_frames(num_test_frames) data = ",".join(b64_frames) imageio = ImageMediaIO() @@ -287,3 +290,52 @@ def test_load_base64_jpeg_returns_metadata(): # Default fps=1 → duration == num_frames assert metadata["fps"] == 1.0 assert metadata["duration"] == float(num_test_frames) + + +def test_load_base64_jpeg_enforces_num_frames_limit(): + """Frames beyond num_frames must be truncated in the video/jpeg path. + + Without the limit an attacker can send thousands of base64 JPEG frames + in a single request and exhaust server memory (OOM). + """ + num_frames_limit = 4 + sent_frames = 20 + + b64_frames = _make_jpeg_b64_frames(sent_frames) + data = ",".join(b64_frames) + + imageio = ImageMediaIO() + videoio = VideoMediaIO(imageio, num_frames=num_frames_limit) + frames, metadata = videoio.load_base64("video/jpeg", data) + + assert frames.shape[0] == num_frames_limit + assert metadata["total_num_frames"] == num_frames_limit + assert metadata["frames_indices"] == list(range(num_frames_limit)) + + +def test_load_base64_jpeg_no_limit_when_num_frames_negative(): + """When num_frames is -1, all frames should be loaded without truncation.""" + sent_frames = 10 + + b64_frames = _make_jpeg_b64_frames(sent_frames) + data = ",".join(b64_frames) + + imageio = ImageMediaIO() + videoio = VideoMediaIO(imageio, num_frames=-1) + frames, metadata = videoio.load_base64("video/jpeg", data) + + assert frames.shape[0] == sent_frames + assert metadata["total_num_frames"] == sent_frames + assert metadata["frames_indices"] == list(range(sent_frames)) + + +def test_load_base64_jpeg_raises_on_zero_num_frames(): + """num_frames=0 is invalid and should raise ValueError.""" + b64_frames = _make_jpeg_b64_frames(3) + data = ",".join(b64_frames) + + imageio = ImageMediaIO() + videoio = VideoMediaIO(imageio, num_frames=0) + + with pytest.raises(ValueError, match="num_frames must be greater than 0 or -1"): + videoio.load_base64("video/jpeg", data) diff --git a/vllm/multimodal/media/video.py b/vllm/multimodal/media/video.py index 2790d714d..691d9444b 100644 --- a/vllm/multimodal/media/video.py +++ b/vllm/multimodal/media/video.py @@ -80,8 +80,15 @@ class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]): "image/jpeg", ) + if self.num_frames > 0: + frame_parts = data.split(",", self.num_frames)[: self.num_frames] + elif self.num_frames == 0: + raise ValueError("num_frames must be greater than 0 or -1") + else: + frame_parts = data.split(",") + frames = np.stack( - [np.asarray(load_frame(frame_data)) for frame_data in data.split(",")] + [np.asarray(load_frame(frame_data)) for frame_data in frame_parts] ) total = int(frames.shape[0]) fps = float(self.kwargs.get("fps", 1))