fix: compute cosine_similarity along dim=-1 and average the result so it works for M>0

This commit is contained in:
2026-05-15 18:50:45 +00:00
parent f9330a1777
commit fee5a97ebb

View File

@@ -78,7 +78,7 @@ nvfp4_out = cutlass_nvfp4_blockscaled_gemm(
print(f"NVFP4 out: amax={nvfp4_out.abs().max():.4e} mean={nvfp4_out.mean():.4e}")
# Cosine similarity
cos = torch.nn.functional.cosine_similarity(nvfp4_out.float().unsqueeze(0), quant_ref.float().unsqueeze(0)).item()
cos = torch.nn.functional.cosine_similarity(nvfp4_out.float(), quant_ref.float(), dim=-1).mean().item()
mse = (nvfp4_out.float() - quant_ref.float()).pow(2).mean().item()
print(f"cosine={cos:.6f} mse={mse:.4e}")
@@ -97,5 +97,5 @@ sf2_exp = x2_sf.to(torch.float32).repeat_interleave(16, dim=-1)
x2_recon = (x2_deq * sf2_exp).to(torch.bfloat16)
qr2 = torch.nn.functional.linear(x2_recon, w_recon.T)
nv2 = cutlass_nvfp4_blockscaled_gemm(x2_fp4, x2_sf, w_fp4_final, w_sf_final, M2, N, K, alpha=1.0)
cos2 = torch.nn.functional.cosine_similarity(nv2.float().unsqueeze(0), qr2.float().unsqueeze(0)).item()
cos2 = torch.nn.functional.cosine_similarity(nv2.float(), qr2.float(), dim=-1).mean().item()
print(f"M=128: cosine={cos2:.6f}")