[MultiModal] add support for numpy array embeddings (#38119)
Signed-off-by: guillaume_guy <guillaume.guy@airbnb.com> Signed-off-by: Guillaume Guy <guillaume.c.guy@gmail.com> Co-authored-by: guillaume_guy <guillaume.guy@airbnb.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -7,6 +7,7 @@ out-of-bounds memory writes during to_dense() operations.
|
||||
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
import pybase64 as base64
|
||||
import pytest
|
||||
import torch
|
||||
@@ -190,6 +191,51 @@ class TestImageEmbedsValidation:
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
io_handler.load_bytes(buffer.read())
|
||||
|
||||
def test_valid_numpy_tensor_accepted(self):
|
||||
"""numpy .npy format should load and return correct tensor."""
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
|
||||
arr = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
|
||||
buf = io.BytesIO()
|
||||
np.save(buf, arr)
|
||||
encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
|
||||
|
||||
result = io_handler.load_base64("", encoded)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.shape == torch.Size([2, 3])
|
||||
assert result.dtype == torch.float32
|
||||
assert torch.allclose(result, torch.from_numpy(arr))
|
||||
|
||||
def test_numpy_int32_tensor_accepted(self):
|
||||
"""numpy int32 arrays should round-trip correctly."""
|
||||
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
|
||||
arr = np.arange(280, dtype=np.int32)
|
||||
buf = io.BytesIO()
|
||||
np.save(buf, arr)
|
||||
encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
|
||||
|
||||
result = io_handler.load_base64("", encoded)
|
||||
assert result.dtype == torch.int32
|
||||
assert result.shape == torch.Size([280])
|
||||
assert (result == torch.from_numpy(arr)).all()
|
||||
|
||||
def test_load_file_numpy_tensor_accepted(self, tmp_path):
|
||||
"""numpy .npy files should load correctly via load_file."""
|
||||
|
||||
io_handler = ImageEmbeddingMediaIO()
|
||||
|
||||
arr = np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float32)
|
||||
npy_path = tmp_path / "image_embeds.npy"
|
||||
np.save(npy_path, arr)
|
||||
|
||||
result = io_handler.load_file(npy_path)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.shape == torch.Size([2, 2])
|
||||
assert result.dtype == torch.float32
|
||||
assert torch.allclose(result, torch.from_numpy(arr))
|
||||
|
||||
|
||||
class TestAudioEmbedsValidation:
|
||||
"""Test sparse tensor validation in audio embeddings (Chat API)."""
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pybase64
|
||||
import torch
|
||||
from PIL import Image
|
||||
@@ -13,6 +14,8 @@ from vllm.utils.serial_utils import tensor2base64
|
||||
from ..image import convert_image_mode, rgba_to_rgb
|
||||
from .base import MediaIO, MediaWithBytes
|
||||
|
||||
MAGIC_NUMPY_PREFIX = b"\x93NUMPY" # https://numpy.org/devdocs/reference/generated/numpy.lib.format.html#format-version-1-0
|
||||
|
||||
|
||||
class ImageMediaIO(MediaIO[Image.Image]):
|
||||
"""Configuration values can be user-provided either by --media-io-kwargs or
|
||||
@@ -104,7 +107,7 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def load_bytes(self, data: bytes) -> torch.Tensor:
|
||||
def _load_pickled_torch(self, data: bytes) -> torch.Tensor:
|
||||
buffer = BytesIO(data)
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
@@ -112,12 +115,23 @@ class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
|
||||
tensor = torch.load(buffer, weights_only=True)
|
||||
return tensor.to_dense()
|
||||
|
||||
def _load_numpy(self, data: bytes) -> torch.Tensor:
|
||||
with BytesIO(data) as buffer:
|
||||
return torch.from_numpy(np.load(buffer))
|
||||
|
||||
def load_bytes(self, data: bytes) -> torch.Tensor:
|
||||
if data[:6] == MAGIC_NUMPY_PREFIX:
|
||||
return self._load_numpy(data)
|
||||
|
||||
return self._load_pickled_torch(data)
|
||||
|
||||
def load_base64(self, media_type: str, data: str) -> torch.Tensor:
|
||||
return self.load_bytes(pybase64.b64decode(data, validate=True))
|
||||
|
||||
def load_file(self, filepath: Path) -> torch.Tensor:
|
||||
# Enable sparse tensor integrity checks to prevent out-of-bounds
|
||||
# writes from maliciously crafted tensors
|
||||
if filepath.suffix == ".npy":
|
||||
return torch.from_numpy(np.load(filepath))
|
||||
|
||||
with torch.sparse.check_sparse_tensor_invariants():
|
||||
tensor = torch.load(filepath, weights_only=True)
|
||||
return tensor.to_dense()
|
||||
|
||||
Reference in New Issue
Block a user