debug: fix per_expert_alpha ref + clean up BF16 reference scaling
This commit is contained in:
@@ -405,7 +405,7 @@ def nvfp4_mega_moe_full(
|
||||
print(f"[GEMM-DEBUG] l1_sf[e0] first 8: {l1_sf[e0].to(torch.float32).flatten()[:8].tolist()}")
|
||||
print(f"[GEMM-DEBUG] l1_global_sf[e0]: {l1_global_sf[e0].tolist()}")
|
||||
print(f"[GEMM-DEBUG] l1_alpha (igs): {l1_alpha:.6e}")
|
||||
print(f"[GEMM-DEBUG] per_expert_alpha[{e0}]: {float(per_expert_alpha[e0]):.6e}")
|
||||
print(f"[GEMM-DEBUG] per_expert_alpha[{e0}]: {float(l1_alpha * l1_global_sf[e0]):.6e}")
|
||||
|
||||
# Dequantize activation
|
||||
x_u8 = x_fp4[s0].view(torch.uint8)
|
||||
@@ -416,8 +416,6 @@ def nvfp4_mega_moe_full(
|
||||
x_mags = _E2M1_MAGNITUDES.to(device=x_u8.device)[(x_nib & 0x07)]
|
||||
x_deq = x_signs * x_mags # (K,) = (7168,)
|
||||
sf_exp = x_sf[s0].to(torch.float32).repeat_interleave(16, dim=-1) # (K,)
|
||||
igs = float(l1_global_scale) if not isinstance(l1_global_scale, float) else l1_global_scale
|
||||
x_bf16 = (x_deq * sf_exp * igs).to(torch.bfloat16)
|
||||
# Dequantize L1 weight for expert e0
|
||||
w_u8 = l1_w[e0].view(torch.uint8)
|
||||
wlo = (w_u8 & 0x0F).long()
|
||||
@@ -426,24 +424,20 @@ def nvfp4_mega_moe_full(
|
||||
w_signs = (w_nib >> 3).float() * -2 + 1
|
||||
w_mags = _E2M1_MAGNITUDES.to(device=w_u8.device)[(w_nib & 0x07)]
|
||||
w_deq = w_signs * w_mags # (K, N) = (7168, 6144)
|
||||
# Weight SF: (sf_k, N) = (448, 6144). Each SF covers 16 FP4 values (8 bytes).
|
||||
# repeat_interleave(16) on dim 0 expands it to (K, N) = (7168, 6144):
|
||||
# sf_k = 448 and K = 7168, so 448 * 16 = 7168, i.e. one scale per FP4 value.
|
||||
w_sf_exp = l1_sf[e0].to(torch.float32).repeat_interleave(16, dim=0) # (K, N)
|
||||
# Full dequant: x = e2m1 * block_sf * igs, w = e2m1 * block_sf * gs
|
||||
gs = l1_global_sf[e0]
|
||||
igs = l1_alpha # already the input global scale
|
||||
x_full = (x_deq * sf_exp * igs).to(torch.bfloat16) # (K,)
|
||||
w_full = (w_deq * w_sf_exp).to(torch.bfloat16) # (K, N) without gs
|
||||
ref_out = torch.nn.functional.linear(x_full.unsqueeze(0), w_full.T).squeeze(0) # (N,)
|
||||
# Apply per-half global scale (gate_gs for first half, up_gs for second half)
|
||||
gn = ref_out.shape[0] // 2
|
||||
if gs.dim() == 0:
|
||||
w_bf16 = (w_deq * w_sf_exp * gs.item()).to(torch.bfloat16)
|
||||
ref_out = ref_out * gs.item()
|
||||
else:
|
||||
w_bf16 = (w_deq * w_sf_exp).to(torch.bfloat16)
|
||||
# GEMM computes: (w * w_sf) @ (x * x_sf) * alpha * gs
|
||||
# We compute: (x * x_sf * igs) @ (w * w_sf * gs_per_half)
|
||||
# which equals igs * gs * ((x * x_sf) @ (w * w_sf)), i.e. the same as the GEMM
|
||||
ref_out = torch.nn.functional.linear(x_bf16.unsqueeze(0), w_bf16.T).squeeze(0)
|
||||
if gs.dim() > 0:
|
||||
gn = ref_out.shape[0] // 2
|
||||
ref_out[:gn] = ref_out[:gn] * gs[0].item()
|
||||
ref_out[gn:] = ref_out[gn:] * gs[1].item()
|
||||
# DON'T multiply by igs again — already in x_bf16
|
||||
nvfp4_mega_moe_full._ref_l1 = (s0, e0, ref_out)
|
||||
print(f"[BF16-REF-L1] expert={e0} amax={ref_out.abs().max():.4e} mean={ref_out.mean():.4e}")
|
||||
except Exception as ex:
|
||||
|
||||
Reference in New Issue
Block a user