fix: cosine_similarity dim for M>0
This commit is contained in:
@@ -78,7 +78,7 @@ nvfp4_out = cutlass_nvfp4_blockscaled_gemm(
|
||||
print(f"NVFP4 out: amax={nvfp4_out.abs().max():.4e} mean={nvfp4_out.mean():.4e}")
|
||||
|
||||
# Cosine similarity
|
||||
cos = torch.nn.functional.cosine_similarity(nvfp4_out.float().unsqueeze(0), quant_ref.float().unsqueeze(0)).item()
|
||||
cos = torch.nn.functional.cosine_similarity(nvfp4_out.float(), quant_ref.float(), dim=-1).mean().item()
|
||||
mse = (nvfp4_out.float() - quant_ref.float()).pow(2).mean().item()
|
||||
print(f"cosine={cos:.6f} mse={mse:.4e}")
|
||||
|
||||
@@ -97,5 +97,5 @@ sf2_exp = x2_sf.to(torch.float32).repeat_interleave(16, dim=-1)
|
||||
x2_recon = (x2_deq * sf2_exp).to(torch.bfloat16)
|
||||
qr2 = torch.nn.functional.linear(x2_recon, w_recon.T)
|
||||
nv2 = cutlass_nvfp4_blockscaled_gemm(x2_fp4, x2_sf, w_fp4_final, w_sf_final, M2, N, K, alpha=1.0)
|
||||
cos2 = torch.nn.functional.cosine_similarity(nv2.float().unsqueeze(0), qr2.float().unsqueeze(0)).item()
|
||||
cos2 = torch.nn.functional.cosine_similarity(nv2.float(), qr2.float(), dim=-1).mean().item()
|
||||
print(f"M=128: cosine={cos2:.6f}")
|
||||
|
||||
Reference in New Issue
Block a user