From 29982d48b3640bb527bffd37fb02e06deb934849 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20P=C3=A9rez=20de=20Algaba?= <124347725+jperezdealgaba@users.noreply.github.com>
Date: Wed, 1 Apr 2026 12:23:45 +0200
Subject: [PATCH] (security) Enforce frame limit in VideoMediaIO (#38636)

Signed-off-by: jperezde
(cherry picked from commit 58ee61422169ce17e08248f8efa1e9df434fe395)
---
Reviewer notes (this area is discarded by `git am`; it is not part of the
commit message):

  * In the "video/jpeg" branch shown in the video.py hunk, a positive
    num_frames caps how many comma-separated base64 frames are decoded,
    num_frames == 0 raises ValueError, and any negative num_frames decodes
    every frame.
  * The line structure and indentation of this patch were restored from a
    whitespace-collapsed copy. Hunk line counts match the @@ headers;
    indentation of the video.py context lines is inferred and should be
    confirmed against the target tree before applying.

 tests/multimodal/media/test_video.py | 70 ++++++++++++++++++++++++----
 vllm/multimodal/media/video.py       |  9 +++-
 2 files changed, 69 insertions(+), 10 deletions(-)

diff --git a/tests/multimodal/media/test_video.py b/tests/multimodal/media/test_video.py
index a1223ebc0..73283ba8c 100644
--- a/tests/multimodal/media/test_video.py
+++ b/tests/multimodal/media/test_video.py
@@ -239,6 +239,17 @@ def test_video_media_io_backend_env_var_fallback(monkeypatch: pytest.MonkeyPatch
     assert metadata_missing["video_backend"] == "test_video_backend_override_2"
 
 
+def _make_jpeg_b64_frames(n: int, width: int = 8, height: int = 8) -> list[str]:
+    """Return *n* tiny base64-encoded JPEG frames."""
+    frames: list[str] = []
+    for i in range(n):
+        img = Image.new("RGB", (width, height), color=(i % 256, 0, 0))
+        buf = io.BytesIO()
+        img.save(buf, format="JPEG")
+        frames.append(pybase64.b64encode(buf.getvalue()).decode("ascii"))
+    return frames
+
+
 def test_load_base64_jpeg_returns_metadata():
     """Regression test: load_base64 with video/jpeg must return metadata.
 
@@ -248,16 +259,8 @@ def test_load_base64_jpeg_returns_metadata():
     """
 
     num_test_frames = 3
-    frame_width, frame_height = 8, 8
-
-    # Build a few tiny JPEG frames and base64-encode them
-    b64_frames = []
-    for i in range(num_test_frames):
-        img = Image.new("RGB", (frame_width, frame_height), color=(i * 80, 0, 0))
-        buf = io.BytesIO()
-        img.save(buf, format="JPEG")
-        b64_frames.append(pybase64.b64encode(buf.getvalue()).decode("ascii"))
 
+    b64_frames = _make_jpeg_b64_frames(num_test_frames)
     data = ",".join(b64_frames)
 
     imageio = ImageMediaIO()
@@ -287,3 +290,52 @@ def test_load_base64_jpeg_returns_metadata():
     # Default fps=1 → duration == num_frames
     assert metadata["fps"] == 1.0
     assert metadata["duration"] == float(num_test_frames)
+
+
+def test_load_base64_jpeg_enforces_num_frames_limit():
+    """Frames beyond num_frames must be truncated in the video/jpeg path.
+
+    Without the limit an attacker can send thousands of base64 JPEG frames
+    in a single request and exhaust server memory (OOM).
+    """
+    num_frames_limit = 4
+    sent_frames = 20
+
+    b64_frames = _make_jpeg_b64_frames(sent_frames)
+    data = ",".join(b64_frames)
+
+    imageio = ImageMediaIO()
+    videoio = VideoMediaIO(imageio, num_frames=num_frames_limit)
+    frames, metadata = videoio.load_base64("video/jpeg", data)
+
+    assert frames.shape[0] == num_frames_limit
+    assert metadata["total_num_frames"] == num_frames_limit
+    assert metadata["frames_indices"] == list(range(num_frames_limit))
+
+
+def test_load_base64_jpeg_no_limit_when_num_frames_negative():
+    """When num_frames is -1, all frames should be loaded without truncation."""
+    sent_frames = 10
+
+    b64_frames = _make_jpeg_b64_frames(sent_frames)
+    data = ",".join(b64_frames)
+
+    imageio = ImageMediaIO()
+    videoio = VideoMediaIO(imageio, num_frames=-1)
+    frames, metadata = videoio.load_base64("video/jpeg", data)
+
+    assert frames.shape[0] == sent_frames
+    assert metadata["total_num_frames"] == sent_frames
+    assert metadata["frames_indices"] == list(range(sent_frames))
+
+
+def test_load_base64_jpeg_raises_on_zero_num_frames():
+    """num_frames=0 is invalid and should raise ValueError."""
+    b64_frames = _make_jpeg_b64_frames(3)
+    data = ",".join(b64_frames)
+
+    imageio = ImageMediaIO()
+    videoio = VideoMediaIO(imageio, num_frames=0)
+
+    with pytest.raises(ValueError, match="num_frames must be greater than 0 or -1"):
+        videoio.load_base64("video/jpeg", data)
diff --git a/vllm/multimodal/media/video.py b/vllm/multimodal/media/video.py
index 2790d714d..691d9444b 100644
--- a/vllm/multimodal/media/video.py
+++ b/vllm/multimodal/media/video.py
@@ -80,8 +80,15 @@ class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
                 "image/jpeg",
             )
 
+            if self.num_frames > 0:
+                frame_parts = data.split(",", self.num_frames)[: self.num_frames]
+            elif self.num_frames == 0:
+                raise ValueError("num_frames must be greater than 0 or -1")
+            else:
+                frame_parts = data.split(",")
+
             frames = np.stack(
-                [np.asarray(load_frame(frame_data)) for frame_data in data.split(",")]
+                [np.asarray(load_frame(frame_data)) for frame_data in frame_parts]
             )
             total = int(frames.shape[0])
             fps = float(self.kwargs.get("fps", 1))