Multiple updates and refactorings (#280)

This commit is contained in:
Zhean Xu
2026-01-16 17:06:52 +08:00
committed by GitHub
parent 3ccf40c53a
commit 0f5f266202
55 changed files with 2706 additions and 891 deletions

View File

@@ -5,6 +5,8 @@ from typing import Iterable
def calc_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
if denominator == 0: # Which means that all elements in x and y are 0
return 0.0
sim = 2 * (x * y).sum() / denominator
return 1 - sim