[MultiModal] add support for numpy array embeddings (#38119)
Signed-off-by: guillaume_guy <guillaume.guy@airbnb.com> Signed-off-by: Guillaume Guy <guillaume.c.guy@gmail.com> Co-authored-by: guillaume_guy <guillaume.guy@airbnb.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -7,6 +7,7 @@ out-of-bounds memory writes during to_dense() operations.
|
||||
|
||||
import io
|
||||
|
||||
import numpy as np
|
||||
import pybase64 as base64
|
||||
import pytest
|
||||
import torch
|
||||
@@ -190,6 +191,51 @@ class TestImageEmbedsValidation:
|
||||
with pytest.raises((RuntimeError, ValueError)):
|
||||
io_handler.load_bytes(buffer.read())
|
||||
|
||||
def test_valid_numpy_tensor_accepted(self):
    """A base64-encoded float32 ``.npy`` payload loads as a matching tensor.

    A 2x3 float32 array is serialized with ``np.save`` into memory,
    base64-encoded, and passed to ``ImageEmbeddingMediaIO.load_base64``
    (with an empty string as the first argument). The result must be a
    ``torch.Tensor`` with the same shape, dtype, and values as the source.
    """
    expected = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)

    # Produce the raw .npy bytes in memory; no file on disk is needed.
    with io.BytesIO() as stream:
        np.save(stream, expected)
        npy_bytes = stream.getvalue()
    payload = base64.b64encode(npy_bytes).decode("utf-8")

    loaded = ImageEmbeddingMediaIO().load_base64("", payload)

    assert isinstance(loaded, torch.Tensor)
    assert loaded.shape == torch.Size([2, 3])
    assert loaded.dtype == torch.float32
    assert torch.allclose(loaded, torch.from_numpy(expected))
|
||||
|
||||
def test_numpy_int32_tensor_accepted(self):
    """An int32 ``.npy`` payload keeps its dtype, shape, and values.

    A 1-D int32 array of 280 consecutive integers is saved with
    ``np.save``, base64-encoded, and loaded through
    ``ImageEmbeddingMediaIO.load_base64``; the resulting tensor must be
    int32, have 280 elements, and compare equal element-wise.
    """
    source = np.arange(280, dtype=np.int32)

    # Serialize to .npy bytes in memory, then wrap them as base64 text.
    with io.BytesIO() as stream:
        np.save(stream, source)
        npy_bytes = stream.getvalue()
    payload = base64.b64encode(npy_bytes).decode("utf-8")

    loaded = ImageEmbeddingMediaIO().load_base64("", payload)

    assert loaded.dtype == torch.int32
    assert loaded.shape == torch.Size([280])
    # Integer data: require exact element-wise equality.
    reference = torch.from_numpy(source)
    assert (loaded == reference).all()
|
||||
|
||||
def test_load_file_numpy_tensor_accepted(self, tmp_path):
    """A ``.npy`` file on disk loads correctly through ``load_file``.

    A 2x2 float32 array is written to ``image_embeds.npy`` under
    ``tmp_path`` and read back via ``ImageEmbeddingMediaIO.load_file``;
    the result must be a ``torch.Tensor`` with matching shape, dtype,
    and values.
    """
    expected = np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float32)

    # Write the fixture file inside the per-test temporary directory.
    target = tmp_path / "image_embeds.npy"
    np.save(target, expected)

    loaded = ImageEmbeddingMediaIO().load_file(target)

    assert isinstance(loaded, torch.Tensor)
    assert loaded.shape == torch.Size([2, 2])
    assert loaded.dtype == torch.float32
    assert torch.allclose(loaded, torch.from_numpy(expected))
|
||||
|
||||
|
||||
class TestAudioEmbedsValidation:
|
||||
"""Test sparse tensor validation in audio embeddings (Chat API)."""
|
||||
|
||||
Reference in New Issue
Block a user