From cc3e3da45c315d16cbb8f6e1de3a95ae37a79881 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 15:30:38 +0000 Subject: [PATCH] debug: check for zero/NaN/Inf in weight SF values --- deep_gemm/mega/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index bd459d9..9d954dd 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -346,6 +346,17 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor, ("l2_w", l2_w), ("l2_w_sf", l2_w_sf)]: print(f"[debug] {name}: dtype={t.dtype} shape={tuple(t.shape)} strides={t.stride()} contig={t.is_contiguous()}", flush=True) + # Sanity check: zero/NaN/Inf in weight SF → illegal instruction in MMA + for name, sf in [("l1_w_sf", l1_w_sf), ("l2_w_sf", l2_w_sf)]: + zero_pct = (sf == 0).float().mean().item() * 100 + if zero_pct > 50: + print(f"[WARN] {name}: {zero_pct:.1f}% zeros in SF! Possible div-by-zero", flush=True) + sf_u8 = sf.view(torch.uint8) + nan_count = (sf_u8 == 0x7F).sum().item() + inf_count = (sf_u8 == 0x7E).sum().item() + if nan_count > 0 or inf_count > 0: + print(f"[WARN] {name}: {nan_count} NaN bytes, {inf_count} Inf bytes in UE4M3 scales!", flush=True) + for name, t in [("sym_x", sym_buffer.x), ("sym_x_sf", sym_buffer.x_sf), ("sym_l1_acts", sym_buffer.l1_acts), ("sym_l1_acts_sf", sym_buffer.l1_acts_sf), ("sym_l2_acts", sym_buffer.l2_acts), ("sym_l2_acts_sf", sym_buffer.l2_acts_sf)]: