2025-07-18 11:32:22 +08:00
|
|
|
import torch
|
|
|
|
|
from typing import Iterable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a normalized dissimilarity between two tensors.

    Both inputs are converted to float64 and compared through
    ``1 - 2 * sum(x * y) / sum(x * x + y * y)``; the result is 0 when the
    two tensors are elementwise equal.

    Args:
        x: First tensor.
        y: Second tensor; it is combined elementwise with ``x``.

    Returns:
        The difference as a float64 tensor produced by full reductions, or
        the Python float ``0.0`` when the combined sum of squares is zero.
    """
    x64 = x.double()
    y64 = y.double()
    energy = (x64 * x64 + y64 * y64).sum()
    # A zero sum of squares means every element of x and y is 0, so report
    # "no difference" instead of dividing by zero.
    if energy == 0:
        return 0.0
    cross = (x64 * y64).sum()
    return 1 - 2 * cross / energy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def count_bytes(*tensors):
    """Return the total number of bytes taken by the elements of tensors.

    Args:
        *tensors: Tensors, ``None`` placeholders, or tuples/lists (possibly
            nested) of those.

    Returns:
        The sum of ``numel() * element_size()`` over every leaf that is not
        ``None``; ``None`` entries contribute nothing.
    """
    def size_of(item):
        # Containers are flattened recursively through count_bytes itself.
        if isinstance(item, (tuple, list)):
            return count_bytes(*item)
        if item is None:
            return 0
        return item.numel() * item.element_size()

    return sum(size_of(item) for item in tensors)
|