From 9115f83afb2f232950bfb3da19f837ce8b5e7ab6 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 16:05:55 +0000 Subject: [PATCH] debug: try/catch around mega_moe kernel with data diagnostics on crash --- deep_gemm/mega/__init__.py | 42 +++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index c0e81cb..84e3a7b 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -364,15 +364,33 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor, ("sym_l2_acts", sym_buffer.l2_acts), ("sym_l2_acts_sf", sym_buffer.l2_acts_sf)]: print(f"[debug] {name}: dtype={t.dtype} shape={tuple(t.shape)} contig={t.is_contiguous()}", flush=True) - _C.fp8_nvfp4_mega_moe( - y, - (l1_w, l1_w_sf), (l2_w, l2_w_sf), - cumulative_local_expert_recv_stats, - sym_buffer.buffer, - sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), - sym_buffer.num_max_tokens_per_rank, - sym_buffer.num_experts, sym_buffer.num_topk, - recipe, - activation, activation_clamp, - fast_math - ) + try: + _C.fp8_nvfp4_mega_moe( + y, + (l1_w, l1_w_sf), (l2_w, l2_w_sf), + cumulative_local_expert_recv_stats, + sym_buffer.buffer, + sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), + sym_buffer.num_max_tokens_per_rank, + sym_buffer.num_experts, sym_buffer.num_topk, + recipe, + activation, activation_clamp, + fast_math + ) + torch.cuda.synchronize() + except RuntimeError as e: + print(f"[CRASH] fp8_nvfp4_mega_moe FAILED: {e}", flush=True) + print(f"[CRASH] num_tokens={y.shape[0]} hidden={y.shape[1]}", flush=True) + print(f"[CRASH] l1_w_sf shape={l1_w_sf.shape} strides={l1_w_sf.stride()}", flush=True) + print(f"[CRASH] l2_w_sf shape={l2_w_sf.shape} strides={l2_w_sf.stride()}", flush=True) + # Check activation data for anomalies + x = sym_buffer.x[:y.shape[0]] + print(f"[CRASH] x stats: min={x.min().item()} max={x.max().item()} nonzero={torch.count_nonzero(x).item()}", flush=True) + x_sf = sym_buffer.x_sf[:y.shape[0]] + print(f"[CRASH] x_sf stats: min={x_sf.min().item()} max={x_sf.max().item()} zeros={(x_sf==0).sum().item()}", flush=True) + # Check weight data for anomalies + print(f"[CRASH] l1_w stats: min={l1_w.min().item()} max={l1_w.max().item()}", flush=True) + print(f"[CRASH] l1_w_sf stats: min={l1_w_sf.min().item()} max={l1_w_sf.max().item()}", flush=True) + print(f"[CRASH] l2_w stats: min={l2_w.min().item()} max={l2_w.max().item()}", flush=True) + print(f"[CRASH] l2_w_sf stats: min={l2_w_sf.min().item()} max={l2_w_sf.max().item()}", flush=True) + raise