Multiple updates and refactorings (#280)
This commit is contained in:
@@ -5,6 +5,8 @@ from typing import Iterable
|
||||
def calc_diff(x: torch.Tensor, y: torch.Tensor):
|
||||
x, y = x.double(), y.double()
|
||||
denominator = (x * x + y * y).sum()
|
||||
if denominator == 0: # Which means that all elements in x and y are 0
|
||||
return 0.0
|
||||
sim = 2 * (x * y).sum() / denominator
|
||||
return 1 - sim
|
||||
|
||||
|
||||
Reference in New Issue
Block a user