Serialize tensors using int8 views (#16866)
Signed-off-by: Staszek Pasko <staszek@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -47,6 +47,10 @@ def test_encode_decode():
|
||||
torch.rand((1, 10), dtype=torch.float32),
|
||||
torch.rand((3, 5, 4000), dtype=torch.float64),
|
||||
torch.tensor(1984), # test scalar too
|
||||
# Make sure to test bf16 which numpy doesn't support.
|
||||
torch.rand((3, 5, 1000), dtype=torch.bfloat16),
|
||||
torch.tensor([float("-inf"), float("inf")] * 1024,
|
||||
dtype=torch.bfloat16),
|
||||
],
|
||||
numpy_array=np.arange(512),
|
||||
unrecognized=UnrecognizedType(33),
|
||||
@@ -64,7 +68,7 @@ def test_encode_decode():
|
||||
# There should be the main buffer + 4 large tensor buffers
|
||||
# + 1 large numpy array. "large" is <= 512 bytes.
|
||||
# The two small tensors are encoded inline.
|
||||
assert len(encoded) == 6
|
||||
assert len(encoded) == 8
|
||||
|
||||
decoded: MyType = decoder.decode(encoded)
|
||||
|
||||
@@ -76,7 +80,7 @@ def test_encode_decode():
|
||||
|
||||
encoded2 = encoder.encode_into(obj, preallocated)
|
||||
|
||||
assert len(encoded2) == 6
|
||||
assert len(encoded2) == 8
|
||||
assert encoded2[0] is preallocated
|
||||
|
||||
decoded2: MyType = decoder.decode(encoded2)
|
||||
@@ -114,15 +118,15 @@ def test_multimodal_kwargs():
|
||||
|
||||
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
|
||||
|
||||
# expected total encoding length, should be 44536, +-20 for minor changes
|
||||
assert total_len >= 44516 and total_len <= 44556
|
||||
# expected total encoding length, should be 44559, +-20 for minor changes
|
||||
assert total_len >= 44539 and total_len <= 44579
|
||||
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
|
||||
assert all(nested_equal(d[k], decoded[k]) for k in d)
|
||||
|
||||
|
||||
def test_multimodal_items_by_modality():
|
||||
e1 = MultiModalFieldElem("audio", "a0", torch.zeros(1000,
|
||||
dtype=torch.int16),
|
||||
e1 = MultiModalFieldElem("audio", "a0",
|
||||
torch.zeros(1000, dtype=torch.bfloat16),
|
||||
MultiModalBatchedField())
|
||||
e2 = MultiModalFieldElem(
|
||||
"video",
|
||||
|
||||
Reference in New Issue
Block a user