From 755f9ad567426f4fd0bfcf149122d684215a7336 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 15 May 2026 17:55:11 +0000 Subject: [PATCH] debug: fix per_expert_alpha ref + clean up BF16 reference scaling --- src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 24 ++++++++-------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index 9e053269..8bfa6e3a 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -405,7 +405,7 @@ def nvfp4_mega_moe_full( print(f"[GEMM-DEBUG] l1_sf[e0] first 8: {l1_sf[e0].to(torch.float32).flatten()[:8].tolist()}") print(f"[GEMM-DEBUG] l1_global_sf[e0]: {l1_global_sf[e0].tolist()}") print(f"[GEMM-DEBUG] l1_alpha (igs): {l1_alpha:.6e}") - print(f"[GEMM-DEBUG] per_expert_alpha[{e0}]: {float(per_expert_alpha[e0]):.6e}") + print(f"[GEMM-DEBUG] per_expert_alpha[{e0}]: {float(l1_alpha * l1_global_sf[e0]):.6e}") # Dequantize activation x_u8 = x_fp4[s0].view(torch.uint8) @@ -416,8 +416,6 @@ def nvfp4_mega_moe_full( x_mags = _E2M1_MAGNITUDES.to(device=x_u8.device)[(x_nib & 0x07)] x_deq = x_signs * x_mags # (K,) = (7168,) sf_exp = x_sf[s0].to(torch.float32).repeat_interleave(16, dim=-1) # (K,) - igs = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale - x_bf16 = (x_deq * sf_exp * igs).to(torch.bfloat16) # Dequantize L1 weight for expert e0 w_u8 = l1_w[e0].view(torch.uint8) wlo = (w_u8 & 0x0F).long() @@ -426,24 +424,20 @@ def nvfp4_mega_moe_full( w_signs = (w_nib >> 3).float() * -2 + 1 w_mags = _E2M1_MAGNITUDES.to(device=w_u8.device)[(w_nib & 0x07)] w_deq = w_signs * w_mags # (K, N) = (7168, 6144) - # Weight SF: (sf_k, N) = (448, 6144). Each SF covers 16 FP4 values (8 bytes). - # repeat_interleave(16) on dim 0 gives (7168, 6144) — but wait, - # sf_k=448, and K=7168, so 448*16=7168. This is per-FP4-value already. w_sf_exp = l1_sf[e0].to(torch.float32).repeat_interleave(16, dim=0) # (K, N) + # Full dequant: x = e2m1 * block_sf * igs, w = e2m1 * block_sf * gs gs = l1_global_sf[e0] + igs = l1_alpha # already the input global scale + x_full = (x_deq * sf_exp * igs).to(torch.bfloat16) # (K,) + w_full = (w_deq * w_sf_exp).to(torch.bfloat16) # (K, N) without gs + ref_out = torch.nn.functional.linear(x_full.unsqueeze(0), w_full.T).squeeze(0) # (N,) + # Apply per-half global scale (gate_gs for first half, up_gs for second half) + gn = ref_out.shape[0] // 2 if gs.dim() == 0: - w_bf16 = (w_deq * w_sf_exp * gs.item()).to(torch.bfloat16) + ref_out = ref_out * gs.item() else: - w_bf16 = (w_deq * w_sf_exp).to(torch.bfloat16) - # GEMM computes: (w * w_sf) @ (x * x_sf) * alpha * gs - # We compute: (x * x_sf * igs) @ (w * w_sf * gs_per_half) - # which equals igs * gs * (x_sf * w_sf) @ (x * w) = same as GEMM - ref_out = torch.nn.functional.linear(x_bf16.unsqueeze(0), w_bf16.T).squeeze(0) - if gs.dim() > 0: - gn = ref_out.shape[0] // 2 ref_out[:gn] = ref_out[:gn] * gs[0].item() ref_out[gn:] = ref_out[gn:] * gs[1].item() - # DON'T multiply by igs again — already in x_bf16 nvfp4_mega_moe_full._ref_l1 = (s0, e0, ref_out) print(f"[BF16-REF-L1] expert={e0} amax={ref_out.abs().max():.4e} mean={ref_out.mean():.4e}") except Exception as ex: