From fee5a97ebbb0c7d0a7a8fb3783dc52c1abb8122a Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 18:50:45 +0000 Subject: [PATCH] fix: cosine_similarity dim for M>0 --- test_m1_gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_m1_gemm.py b/test_m1_gemm.py index 1457ac28..1375595f 100644 --- a/test_m1_gemm.py +++ b/test_m1_gemm.py @@ -78,7 +78,7 @@ nvfp4_out = cutlass_nvfp4_blockscaled_gemm( print(f"NVFP4 out: amax={nvfp4_out.abs().max():.4e} mean={nvfp4_out.mean():.4e}") # Cosine similarity -cos = torch.nn.functional.cosine_similarity(nvfp4_out.float().unsqueeze(0), quant_ref.float().unsqueeze(0)).item() +cos = torch.nn.functional.cosine_similarity(nvfp4_out.float(), quant_ref.float(), dim=-1).mean().item() mse = (nvfp4_out.float() - quant_ref.float()).pow(2).mean().item() print(f"cosine={cos:.6f} mse={mse:.4e}") @@ -97,5 +97,5 @@ sf2_exp = x2_sf.to(torch.float32).repeat_interleave(16, dim=-1) x2_recon = (x2_deq * sf2_exp).to(torch.bfloat16) qr2 = torch.nn.functional.linear(x2_recon, w_recon.T) nv2 = cutlass_nvfp4_blockscaled_gemm(x2_fp4, x2_sf, w_fp4_final, w_sf_final, M2, N, K, alpha=1.0) -cos2 = torch.nn.functional.cosine_similarity(nv2.float().unsqueeze(0), qr2.float().unsqueeze(0)).item() +cos2 = torch.nn.functional.cosine_similarity(nv2.float(), qr2.float(), dim=-1).mean().item() print(f"M=128: cosine={cos2:.6f}")