diff --git a/deep_gemm/mega/__init__.py b/deep_gemm/mega/__init__.py index 84e3a7b..5c89b3d 100644 --- a/deep_gemm/mega/__init__.py +++ b/deep_gemm/mega/__init__.py @@ -383,14 +383,7 @@ def fp8_nvfp4_mega_moe(y: torch.Tensor, print(f"[CRASH] num_tokens={y.shape[0]} hidden={y.shape[1]}", flush=True) print(f"[CRASH] l1_w_sf shape={l1_w_sf.shape} strides={l1_w_sf.stride()}", flush=True) print(f"[CRASH] l2_w_sf shape={l2_w_sf.shape} strides={l2_w_sf.stride()}", flush=True) - # Check activation data for anomalies - x = sym_buffer.x[:y.shape[0]] - print(f"[CRASH] x stats: min={x.min().item()} max={x.max().item()} nonzero={torch.count_nonzero(x).item()}", flush=True) - x_sf = sym_buffer.x_sf[:y.shape[0]] - print(f"[CRASH] x_sf stats: min={x_sf.min().item()} max={x_sf.max().item()} zeros={(x_sf==0).sum().item()}", flush=True) - # Check weight data for anomalies - print(f"[CRASH] l1_w stats: min={l1_w.min().item()} max={l1_w.max().item()}", flush=True) - print(f"[CRASH] l1_w_sf stats: min={l1_w_sf.min().item()} max={l1_w_sf.max().item()}", flush=True) - print(f"[CRASH] l2_w stats: min={l2_w.min().item()} max={l2_w.max().item()}", flush=True) - print(f"[CRASH] l2_w_sf stats: min={l2_w_sf.min().item()} max={l2_w_sf.max().item()}", flush=True) + print(f"[CRASH] x shape={sym_buffer.x.shape} x_sf shape={sym_buffer.x_sf.shape}", flush=True) + print(f"[CRASH] recipe={recipe} activation={activation}", flush=True) + # Don't touch GPU tensors — CUDA is broken after 715 raise