[MultiModal] add support for numpy array embeddings (#38119)

Signed-off-by: guillaume_guy <guillaume.guy@airbnb.com>
Signed-off-by: Guillaume Guy <guillaume.c.guy@gmail.com>
Co-authored-by: guillaume_guy <guillaume.guy@airbnb.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Guillaume Guy
2026-03-25 15:13:04 -05:00
committed by GitHub
parent 978fc18bf0
commit 70a2152830
2 changed files with 63 additions and 3 deletions

View File

@@ -7,6 +7,7 @@ out-of-bounds memory writes during to_dense() operations.
import io
import numpy as np
import pybase64 as base64
import pytest
import torch
@@ -190,6 +191,51 @@ class TestImageEmbedsValidation:
with pytest.raises((RuntimeError, ValueError)):
io_handler.load_bytes(buffer.read())
def test_valid_numpy_tensor_accepted(self):
"""numpy .npy format should load and return correct tensor."""
io_handler = ImageEmbeddingMediaIO()
arr = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
buf = io.BytesIO()
np.save(buf, arr)
encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
result = io_handler.load_base64("", encoded)
assert isinstance(result, torch.Tensor)
assert result.shape == torch.Size([2, 3])
assert result.dtype == torch.float32
assert torch.allclose(result, torch.from_numpy(arr))
def test_numpy_int32_tensor_accepted(self):
"""numpy int32 arrays should round-trip correctly."""
io_handler = ImageEmbeddingMediaIO()
arr = np.arange(280, dtype=np.int32)
buf = io.BytesIO()
np.save(buf, arr)
encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
result = io_handler.load_base64("", encoded)
assert result.dtype == torch.int32
assert result.shape == torch.Size([280])
assert (result == torch.from_numpy(arr)).all()
def test_load_file_numpy_tensor_accepted(self, tmp_path):
"""numpy .npy files should load correctly via load_file."""
io_handler = ImageEmbeddingMediaIO()
arr = np.array([[1.5, 2.5], [3.5, 4.5]], dtype=np.float32)
npy_path = tmp_path / "image_embeds.npy"
np.save(npy_path, arr)
result = io_handler.load_file(npy_path)
assert isinstance(result, torch.Tensor)
assert result.shape == torch.Size([2, 2])
assert result.dtype == torch.float32
assert torch.allclose(result, torch.from_numpy(arr))
class TestAudioEmbedsValidation:
"""Test sparse tensor validation in audio embeddings (Chat API)."""