[Bugfix] Fix incorrect original shape in hashing (#23672)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Cyrus Leung
2025-08-27 03:01:25 +08:00
committed by GitHub
parent 98aa16ff41
commit 9715f7bb0f
2 changed files with 12 additions and 5 deletions

View File

@@ -45,10 +45,11 @@ def test_hash_collision_image_transpose():
assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
def test_hash_collision_tensor_shape():
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_hash_collision_tensor_shape(dtype):
# The hash should be different though the data is the same when flattened
arr1 = torch.zeros((5, 10, 20, 3))
arr2 = torch.zeros((10, 20, 5, 3))
arr1 = torch.zeros((5, 10, 20, 3), dtype=dtype)
arr2 = torch.zeros((10, 20, 5, 3), dtype=dtype)
hasher = MultiModalHasher
assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)