"""Debug: diagnose wo_a grouped GEMM issue step by step.""" import torch import torch.nn.functional as F import sys, os sys.path.insert(0, "/root/nvfp4-megamoe-kernel") from cutedsl.wo_a_grouped_linear import CuTeDSLNvfp4WoA from cutedsl.bridge import quantize_weight_to_nvfp4, quantize_to_nvfp4, quantize_activation_nvfp4 torch.cuda.set_device(0) torch.manual_seed(42) # Small dimensions G, HPG, HD, OR = 2, 4, 128, 64 GI = HPG * HD # 512 T = 4 DEVICE = "cuda:0" o = torch.randn(T, G*HPG, HD, dtype=torch.bfloat16, device=DEVICE) * 2.0 w = torch.randn(G*OR, GI, dtype=torch.bfloat16, device=DEVICE) * 0.1 # Reference: per-group BF16 matmul o_g = o.reshape(T, G, GI) z_ref = torch.empty(T, G, OR, dtype=torch.bfloat16, device=DEVICE) for g in range(G): z_ref[:, g, :] = o_g[:, g, :] @ w[g*OR:(g+1)*OR, :].T print(f"z_ref amax={z_ref.amax():.4f} shape={z_ref.shape}") print(f"z_ref[0, 0, :8] = {z_ref[0, 0, :8]}") # Step 1: verify weight quantization per-group print("\n=== Weight quant ===") for g in range(G): w_g = w[g*OR:(g+1)*OR, :].T # (GI, OR) w_fp4, w_sf, w_gs = quantize_weight_to_nvfp4(w_g) print(f" Group {g}: w_g shape={w_g.shape} w_fp4 shape={w_fp4.shape} w_sf shape={w_sf.shape} gs={w_gs:.6f}") # Step 2: test runner directly (bypass custom op) runner = CuTeDSLNvfp4WoA( n_local_groups=G, heads_per_group=HPG, head_dim=HD, o_lora_rank=OR, max_num_tokens=8, device=DEVICE, ) runner.set_bf16_weight(w) runner.finalize_weights() runner._ensure_initialized() # Compute activation gs with torch.no_grad(): _, _, gs = quantize_to_nvfp4(o_g[:, 0, :]) # use first group's activation print(f"\nActivation gs from sample: {gs:.6f}") print(f"Runner gs: {runner._activation_global_scale:.6f}") runner._activation_global_scale = gs # use the right one # Call _run_impl directly with torch.no_grad(): z_out = runner._run_impl(o) print(f"\nz_out shape={z_out.shape} amax={z_out.amax():.4f}") print(f"z_out[0, 0, :8] = {z_out[0, 0, :8]}") # Per-group comparison for g in range(G): cos = F.cosine_similarity(z_ref[:, g, :].flatten().unsqueeze(0).float(), z_out[:, g, :].flatten().unsqueeze(0).float()).item() print(f" Group {g}: cosine={cos:.6f} ref_amax={z_ref[:, g, :].amax():.4f} out_amax={z_out[:, g, :].amax():.4f}") cos = F.cosine_similarity(z_ref.flatten().unsqueeze(0).float(), z_out.flatten().unsqueeze(0).float()).item() print(f"\nOverall cosine={cos:.6f}")