[Core] Do not copy array during hashing (#19484)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
@@ -60,3 +60,15 @@ def test_hash_collision_array_shape():
|
||||
|
||||
hasher = MultiModalHasher
|
||||
assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)
|
||||
|
||||
|
||||
def test_hash_non_contiguous_array():
    """Check that hashing is independent of an array's memory layout.

    A transposed view (not C-contiguous) and its C-contiguous copy hold the
    same values, so hashing either one is expected to give the same result.
    """
    # Transposing the 4x6 array gives a view that is not C-contiguous.
    strided_view = np.arange(24).reshape(4, 6).T
    assert not strided_view.flags.c_contiguous

    # Same values, laid out contiguously in C order.
    packed_copy = np.ascontiguousarray(strided_view)
    assert packed_copy.flags.c_contiguous

    hasher = MultiModalHasher

    # Both layouts must be hashable and must hash identically.
    strided_hash = hasher.hash_kwargs(data=strided_view)
    packed_hash = hasher.hash_kwargs(data=packed_copy)
    assert strided_hash == packed_hash
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import pickle
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -23,11 +24,11 @@ A dictionary containing hashes for items in each modality.
|
||||
class MultiModalHasher:
|
||||
|
||||
@classmethod
|
||||
def serialize_item(cls, obj: object) -> bytes:
|
||||
def serialize_item(cls, obj: object) -> Union[bytes, memoryview]:
|
||||
# Simple cases
|
||||
if isinstance(obj, str):
|
||||
return obj.encode("utf-8")
|
||||
if isinstance(obj, bytes):
|
||||
if isinstance(obj, (bytes, memoryview)):
|
||||
return obj
|
||||
if isinstance(obj, (int, float)):
|
||||
return np.array(obj).tobytes()
|
||||
@@ -38,12 +39,13 @@ class MultiModalHasher:
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return cls.item_to_bytes("tensor", obj.numpy())
|
||||
if isinstance(obj, np.ndarray):
|
||||
return cls.item_to_bytes(
|
||||
"ndarray", {
|
||||
"dtype": obj.dtype.str,
|
||||
"shape": obj.shape,
|
||||
"data": obj.tobytes(),
|
||||
})
|
||||
# If the array is non-contiguous, we need to copy it first
|
||||
arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
|
||||
return cls.item_to_bytes("ndarray", {
|
||||
"dtype": obj.dtype.str,
|
||||
"shape": obj.shape,
|
||||
"data": arr_data,
|
||||
})
|
||||
|
||||
logger.warning(
|
||||
"No serialization method found for %s. "
|
||||
@@ -64,7 +66,7 @@ class MultiModalHasher:
|
||||
cls,
|
||||
key: str,
|
||||
obj: object,
|
||||
) -> Iterable[tuple[bytes, bytes]]:
|
||||
) -> Iterable[tuple[bytes, Union[bytes, memoryview]]]:
|
||||
# Recursive cases
|
||||
if isinstance(obj, (list, tuple)):
|
||||
for i, elem in enumerate(obj):
|
||||
@@ -73,7 +75,7 @@ class MultiModalHasher:
|
||||
for k, v in obj.items():
|
||||
yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
|
||||
else:
|
||||
key_bytes = cls.serialize_item(key)
|
||||
key_bytes = key.encode("utf-8")
|
||||
value_bytes = cls.serialize_item(obj)
|
||||
yield key_bytes, value_bytes
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ class MsgpackEncoder:
|
||||
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
|
||||
assert self.aux_buffers is not None
|
||||
# If the array is non-contiguous, we need to copy it first
|
||||
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
|
||||
arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
|
||||
if not obj.shape or obj.nbytes < self.size_threshold:
|
||||
# Encode small arrays and scalars inline. Using this extension type
|
||||
# ensures we can avoid copying when decoding.
|
||||
|
||||
Reference in New Issue
Block a user