"""Debug: measure NVFP4 quantization error for a grouped GEMM, stage by stage.

The script emulates the NVFP4 grouped GEMM in PyTorch (it does not launch the
GEMM kernel itself) and prints cosine similarities for three stages:

1. Per-group weight quantize -> dequantize vs. the original BF16 weight.
2. Activation quantize -> dequantize vs. the original BF16 activation.
3. Full pipeline: dequantized activation @ dequantized weight per group,
   compared with the plain per-group BF16 matmul.
"""
import os
import sys

import torch
import torch.nn.functional as F

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cutedsl.bridge import quantize_weight_to_nvfp4, quantize_to_nvfp4  # noqa: E402

torch.cuda.set_device(0)
torch.manual_seed(42)

G, HPG, HD, OR = 2, 4, 128, 64
GI = HPG * HD  # 512: reduction width of each group's matmul
T = 4
DEVICE = "cuda:0"
NVFP4_BLOCK = 16  # consecutive elements that share one scale factor

# E2M1 code -> value. Codes 8-15 are the negated copies of codes 0-7.
E2M1_LUT = torch.tensor(
    [0., 0.5, 1., 1.5, 2., 3., 4., 6., -0., -0.5, -1., -1.5, -2., -3., -4., -6.],
    dtype=torch.float32,
    device=DEVICE,
)


def unpack_e2m1(packed_fp4, rows, cols):
    """Decode packed FP4 codes into a float32 ``(rows, cols)`` tensor.

    Each byte holds two codes along the last dimension: the low nibble goes to
    the even column and the high nibble to the following odd column.
    """
    codes = packed_fp4.view(torch.uint8)
    values = torch.empty(rows, cols, dtype=torch.float32, device=DEVICE)
    values[:, 0::2] = E2M1_LUT[(codes & 0x0F).long()]
    values[:, 1::2] = E2M1_LUT[((codes >> 4) & 0x0F).long()]
    return values


def dequantize_nvfp4(packed_fp4, scale_factors, global_scale, rows, cols, block_dim):
    """Reconstruct a BF16 ``(rows, cols)`` tensor from its NVFP4 components.

    Args:
        packed_fp4: two E2M1 codes per byte, packed along the last dimension.
        scale_factors: one scale per ``NVFP4_BLOCK`` consecutive elements along
            ``block_dim`` (0 for the weights, 1 for the activations in this
            script). After expansion, anything beyond ``rows``/``cols`` is
            sliced off.
        global_scale: tensor-wide scale multiplied into every element.
        rows, cols: logical shape of the dequantized tensor.
        block_dim: dimension along which the scale factors are blocked.

    Returns:
        ``e2m1_value * block_scale * global_scale`` cast to bfloat16.
    """
    values = unpack_e2m1(packed_fp4, rows, cols)
    scales = scale_factors.float().repeat_interleave(NVFP4_BLOCK, dim=block_dim)[:rows, :cols]
    return (values * scales * global_scale).to(torch.bfloat16)


def cosine(a, b):
    """Cosine similarity of two tensors viewed as flat float32 vectors."""
    return F.cosine_similarity(
        a.flatten().unsqueeze(0).float(),
        b.flatten().unsqueeze(0).float(),
    ).item()


o = torch.randn(T, G*HPG, HD, dtype=torch.bfloat16, device=DEVICE) * 2.0
w = torch.randn(G*OR, GI, dtype=torch.bfloat16, device=DEVICE) * 0.1
o_g = o.reshape(T, G, GI)  # (T, G, GI): one GI-wide activation row per group

# Reference: per-group BF16 matmul
z_ref = torch.empty(T, G, OR, dtype=torch.bfloat16, device=DEVICE)
for g in range(G):
    z_ref[:, g, :] = o_g[:, g, :] @ w[g*OR:(g+1)*OR, :].T

# Test: quantize/dequantize each weight group and compare
print("=== Weight quantization test ===")
w_dequant_groups = []  # reused by the full-pipeline test below
for g in range(G):
    w_gt = w[g*OR:(g+1)*OR, :].T  # (GI, OR) for quantize_weight_to_nvfp4
    w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_gt)
    w_dequant = dequantize_nvfp4(w_fp4, w_sf, w_gs, *w_gt.shape, block_dim=0)
    w_dequant_groups.append(w_dequant)
    cos = cosine(w_gt, w_dequant)
    print(f" Group {g}: weight quant cos={cos:.6f} w_gt amax={w_gt.amax():.4f} w_dequant amax={w_dequant.amax():.4f}")

# Test: activation quantization
print("\n=== Activation quantization test ===")
o_flat = o_g.reshape(T * G, GI)
x_fp4, x_sf, gs = quantize_to_nvfp4(o_flat)
x_dequant = dequantize_nvfp4(x_fp4, x_sf, gs, T * G, GI, block_dim=1)
cos = cosine(o_flat, x_dequant)
print(f" Activation quant cos={cos:.6f} gs={gs:.6f}")

# Test: the FULL pipeline — quantize weight and activation, then BF16 matmul
print("\n=== Full pipeline (quantize → dequantize → BF16 matmul) ===")
z_qdq = torch.empty(T, G, OR, dtype=torch.bfloat16, device=DEVICE)
for g, w_dequant in enumerate(w_dequant_groups):
    # o_g[:, g, :] is a strided view (row stride G*GI). Make it contiguous so the
    # quantizer sees the same dense layout as o_flat in the activation test.
    act = o_g[:, g, :].contiguous()  # (T, GI)
    a_fp4, a_sf, a_gs = quantize_to_nvfp4(act)
    a_dequant = dequantize_nvfp4(a_fp4, a_sf, a_gs, T, GI, block_dim=1)
    z_qdq[:, g, :] = a_dequant @ w_dequant
cos = cosine(z_ref, z_qdq)
print(f" QDQ vs BF16: cosine={cos:.6f}")