From 2bb52c7caec2b2dc6e1a62f30d5441295e5a2f78 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 6 Jun 2026 07:40:19 +0000 Subject: [PATCH] =?UTF-8?q?Add=20per-layer=20graph=20capture=20verificatio?= =?UTF-8?q?n=20=E2=80=94=20replay=20immediately=20and=20check=20for=20zero?= =?UTF-8?q?s?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- single_shot_inference.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/single_shot_inference.py b/single_shot_inference.py index efe9fe06..04dbf2fd 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -357,6 +357,18 @@ class CUDAGraphDecoder: self.graphs_a[li] = graph_a + # Verify Graph A capture: replay immediately and check output is non-zero + # This catches issues like wrong device, stale data, or broken kernel arguments + if li < 3 or (li + 1) % 20 == 0: + torch.cuda.set_device(gpu) + graph_a.replay() + torch.cuda.synchronize() + xn_max = self.x_normed_bufs[li].abs().max().item() + qh_max = self.q_heads_bufs[li].abs().max().item() + print(f" L{li} GraphA verify: x_normed |X|={xn_max:.4f} q_heads |X|={qh_max:.4f}", flush=True) + if xn_max == 0.0: + print(f" *** L{li} GraphA VERIFY FAILED: x_normed is all zeros! ***", flush=True) + # ======== Graph B: post-attention + FFN compute ======== # Input: X_mid = self.X_mid_bufs[li], F_attn = self.F_attn_bufs[li] # Output: X_next → self.x_out_bufs[li] @@ -395,6 +407,16 @@ class CUDAGraphDecoder: self.graphs_b[li] = graph_b + # Verify Graph B capture: replay immediately and check output is non-zero + if li < 3 or (li + 1) % 20 == 0: + torch.cuda.set_device(gpu) + graph_b.replay() + torch.cuda.synchronize() + xo_max = self.x_out_bufs[li].abs().max().item() + print(f" L{li} GraphB verify: x_out |X|={xo_max:.4f}", flush=True) + if xo_max == 0.0: + print(f" *** L{li} GraphB VERIFY FAILED: x_out is all zeros! ***", flush=True) + if (li + 1) % 10 == 0: print(f" Captured {li+1}/{self.n_layers} layer A/B graphs", flush=True)