debug: fix per_expert_alpha ref + clean up BF16 reference scaling

This commit is contained in:
2026-05-15 17:55:11 +00:00
parent de8acc7965
commit 755f9ad567

View File

@@ -405,7 +405,7 @@ def nvfp4_mega_moe_full(
print(f"[GEMM-DEBUG] l1_sf[e0] first 8: {l1_sf[e0].to(torch.float32).flatten()[:8].tolist()}")
print(f"[GEMM-DEBUG] l1_global_sf[e0]: {l1_global_sf[e0].tolist()}")
print(f"[GEMM-DEBUG] l1_alpha (igs): {l1_alpha:.6e}")
print(f"[GEMM-DEBUG] per_expert_alpha[{e0}]: {float(per_expert_alpha[e0]):.6e}")
print(f"[GEMM-DEBUG] per_expert_alpha[{e0}]: {float(l1_alpha * l1_global_sf[e0]):.6e}")
# Dequantize activation
x_u8 = x_fp4[s0].view(torch.uint8)
@@ -416,8 +416,6 @@ def nvfp4_mega_moe_full(
x_mags = _E2M1_MAGNITUDES.to(device=x_u8.device)[(x_nib & 0x07)]
x_deq = x_signs * x_mags # (K,) = (7168,)
sf_exp = x_sf[s0].to(torch.float32).repeat_interleave(16, dim=-1) # (K,)
igs = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale
x_bf16 = (x_deq * sf_exp * igs).to(torch.bfloat16)
# Dequantize L1 weight for expert e0
w_u8 = l1_w[e0].view(torch.uint8)
wlo = (w_u8 & 0x0F).long()
@@ -426,24 +424,20 @@ def nvfp4_mega_moe_full(
w_signs = (w_nib >> 3).float() * -2 + 1
w_mags = _E2M1_MAGNITUDES.to(device=w_u8.device)[(w_nib & 0x07)]
w_deq = w_signs * w_mags # (K, N) = (7168, 6144)
# Weight SF: (sf_k, N) = (448, 6144). Each SF covers 16 FP4 values (8 bytes).
# repeat_interleave(16) on dim 0 gives (7168, 6144) — but wait,
# sf_k=448, and K=7168, so 448*16=7168. This is per-FP4-value already.
w_sf_exp = l1_sf[e0].to(torch.float32).repeat_interleave(16, dim=0) # (K, N)
# Full dequant: x = e2m1 * block_sf * igs, w = e2m1 * block_sf * gs
gs = l1_global_sf[e0]
igs = l1_alpha # already the input global scale
x_full = (x_deq * sf_exp * igs).to(torch.bfloat16) # (K,)
w_full = (w_deq * w_sf_exp).to(torch.bfloat16) # (K, N) without gs
ref_out = torch.nn.functional.linear(x_full.unsqueeze(0), w_full.T).squeeze(0) # (N,)
# Apply per-half global scale (gate_gs for first half, up_gs for second half)
gn = ref_out.shape[0] // 2
if gs.dim() == 0:
w_bf16 = (w_deq * w_sf_exp * gs.item()).to(torch.bfloat16)
ref_out = ref_out * gs.item()
else:
w_bf16 = (w_deq * w_sf_exp).to(torch.bfloat16)
# GEMM computes: (w * w_sf) @ (x * x_sf) * alpha * gs
# We compute: (x * x_sf * igs) @ (w * w_sf * gs_per_half)
# which equals igs * gs * (x_sf * w_sf) @ (x * w) = same as GEMM
ref_out = torch.nn.functional.linear(x_bf16.unsqueeze(0), w_bf16.T).squeeze(0)
if gs.dim() > 0:
gn = ref_out.shape[0] // 2
ref_out[:gn] = ref_out[:gn] * gs[0].item()
ref_out[gn:] = ref_out[gn:] * gs[1].item()
# DON'T multiply by igs again — already in x_bf16
nvfp4_mega_moe_full._ref_l1 = (s0, e0, ref_out)
print(f"[BF16-REF-L1] expert={e0} amax={ref_out.abs().max():.4e} mean={ref_out.mean():.4e}")
except Exception as ex: