[Bugfix][Frontend] Return 400 for corrupt/truncated image inputs instead of 500 (#38253)

Signed-off-by: aliialsaeedii <ali.al-saeedi@nscale.com>
This commit is contained in:
aliialsaeedii
2026-03-30 12:26:46 +02:00
committed by GitHub
parent 3683fe6c06
commit 7e76af14fa
2 changed files with 82 additions and 6 deletions

View File

@@ -131,3 +131,77 @@ def test_image_media_io_rgba_background_color_validation():
ImageMediaIO(rgba_background_color=(0, 0, 0)) # Should not raise
ImageMediaIO(rgba_background_color=[255, 255, 255]) # Should not raise
ImageMediaIO(rgba_background_color=(128, 128, 128)) # Should not raise
def test_image_media_io_load_bytes(tmp_path):
"""Test load_bytes with valid and invalid image data."""
# Save a valid RGB image to use as source bytes
valid_image = Image.new("RGB", (8, 8), (100, 150, 200))
valid_path = tmp_path / "valid.png"
valid_image.save(valid_path)
valid_data = valid_path.read_bytes()
# Test 1: Valid image bytes load successfully and are fully decoded
image_io = ImageMediaIO()
result = image_io.load_bytes(valid_data)
# Check the returned media is a properly loaded image
assert isinstance(result.media, Image.Image)
assert result.media.size == (8, 8)
assert result.media.getpixel((0, 0)) == (100, 150, 200)
# Test 2: Garbage bytes raise ValueError
with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_bytes(b"not an image")
# Test 3: Truncated PNG header raises ValueError
with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 10)
# Test 4: Real PNG truncated mid-stream raises ValueError
with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_bytes(valid_data[: len(valid_data) // 2])
# Test 5: Empty bytes raise ValueError
with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_bytes(b"")
def test_image_media_io_load_file(tmp_path):
"""Test load_file with valid and invalid image files."""
# Save a valid RGB image to disk
valid_image = Image.new("RGB", (4, 4), (10, 20, 30))
valid_path = tmp_path / "valid.png"
valid_image.save(valid_path)
# Test 1: Valid image file loads successfully and is fully decoded
image_io = ImageMediaIO()
result = image_io.load_file(valid_path)
# Check the returned media is a properly loaded image
assert isinstance(result.media, Image.Image)
assert result.media.size == (4, 4)
assert result.media.getpixel((0, 0)) == (10, 20, 30)
# Test 2: File with garbage content raises ValueError
bad_file = tmp_path / "bad.png"
bad_file.write_bytes(b"this is not an image")
with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_file(bad_file)
# Test 3: File with truncated PNG header raises ValueError
truncated_file = tmp_path / "truncated.png"
truncated_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 10)
with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_file(truncated_file)
# Test 4: Real PNG file truncated mid-stream raises ValueError
valid_data = valid_path.read_bytes()
truncated_real_file = tmp_path / "truncated_real.png"
truncated_real_file.write_bytes(valid_data[: len(valid_data) // 2])
with pytest.raises(ValueError, match="Failed to load image"):
image_io.load_file(truncated_real_file)

View File

@@ -68,17 +68,19 @@ class ImageMediaIO(MediaIO[Image.Image]):
return convert_image_mode(image, self.image_mode)
def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]:
image = Image.open(BytesIO(data))
return MediaWithBytes(self._convert_image_mode(image), data)
try:
image = Image.open(BytesIO(data))
image.load()
image = self._convert_image_mode(image)
except (OSError, Image.UnidentifiedImageError) as e:
raise ValueError(f"Failed to load image: {e}") from e
return MediaWithBytes(image, data)
def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]:
return self.load_bytes(pybase64.b64decode(data, validate=True))
def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]:
with open(filepath, "rb") as f:
data = f.read()
image = Image.open(BytesIO(data))
return MediaWithBytes(self._convert_image_mode(image), data)
return self.load_bytes(filepath.read_bytes())
def encode_base64(
self,