Serialize tensors using int8 views (#16866)

Signed-off-by: Staszek Pasko <staszek@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Staszek Paśko
2025-04-19 19:28:34 +02:00
committed by GitHub
parent 682e0b6d2f
commit 87aaadef73
2 changed files with 48 additions and 13 deletions

View File

@@ -47,6 +47,10 @@ def test_encode_decode():
torch.rand((1, 10), dtype=torch.float32),
torch.rand((3, 5, 4000), dtype=torch.float64),
torch.tensor(1984), # test scalar too
# Make sure to test bf16 which numpy doesn't support.
torch.rand((3, 5, 1000), dtype=torch.bfloat16),
torch.tensor([float("-inf"), float("inf")] * 1024,
dtype=torch.bfloat16),
],
numpy_array=np.arange(512),
unrecognized=UnrecognizedType(33),
@@ -64,7 +68,7 @@ def test_encode_decode():
# There should be the main buffer + 4 large tensor buffers
# + 1 large numpy array. "large" is <= 512 bytes.
# The two small tensors are encoded inline.
assert len(encoded) == 6
assert len(encoded) == 8
decoded: MyType = decoder.decode(encoded)
@@ -76,7 +80,7 @@ def test_encode_decode():
encoded2 = encoder.encode_into(obj, preallocated)
assert len(encoded2) == 6
assert len(encoded2) == 8
assert encoded2[0] is preallocated
decoded2: MyType = decoder.decode(encoded2)
@@ -114,15 +118,15 @@ def test_multimodal_kwargs():
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length, should be 44536, +-20 for minor changes
assert total_len >= 44516 and total_len <= 44556
# expected total encoding length, should be 44559, +-20 for minor changes
assert total_len >= 44539 and total_len <= 44579
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
assert all(nested_equal(d[k], decoded[k]) for k in d)
def test_multimodal_items_by_modality():
e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
dtype=torch.int16),
e1 = MultiModalFieldElem("audio", "a0",
torch.zeros(1000, dtype=torch.bfloat16),
MultiModalBatchedField())
e2 = MultiModalFieldElem(
"video",