From 2cbf7a43e9ef4b80bcb07fd0f638768af932fb54 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 6 Jun 2026 07:51:22 +0000 Subject: [PATCH] Add sync after cross-GPU copy before graph replay; remove misleading zero-input verification --- single_shot_inference.py | 29 ++++++++--------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 04dbf2fd..4e6e6d8d 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -357,17 +357,9 @@ class CUDAGraphDecoder: self.graphs_a[li] = graph_a - # Verify Graph A capture: replay immediately and check output is non-zero - # This catches issues like wrong device, stale data, or broken kernel arguments - if li < 3 or (li + 1) % 20 == 0: - torch.cuda.set_device(gpu) - graph_a.replay() - torch.cuda.synchronize() - xn_max = self.x_normed_bufs[li].abs().max().item() - qh_max = self.q_heads_bufs[li].abs().max().item() - print(f" L{li} GraphA verify: x_normed |X|={xn_max:.4f} q_heads |X|={qh_max:.4f}", flush=True) - if xn_max == 0.0: - print(f" *** L{li} GraphA VERIFY FAILED: x_normed is all zeros! ***", flush=True) + # Note: We don't verify here because x_in_bufs[li] was zero-initialized. + # The actual replay path populates x_in_bufs via copy_() before replay, + # so the graph replay works correctly with real data. # ======== Graph B: post-attention + FFN compute ======== # Input: X_mid = self.X_mid_bufs[li], F_attn = self.F_attn_bufs[li] @@ -407,16 +399,6 @@ class CUDAGraphDecoder: self.graphs_b[li] = graph_b - # Verify Graph B capture: replay immediately and check output is non-zero - if li < 3 or (li + 1) % 20 == 0: - torch.cuda.set_device(gpu) - graph_b.replay() - torch.cuda.synchronize() - xo_max = self.x_out_bufs[li].abs().max().item() - print(f" L{li} GraphB verify: x_out |X|={xo_max:.4f}", flush=True) - if xo_max == 0.0: - print(f" *** L{li} GraphB VERIFY FAILED: x_out is all zeros! ***", flush=True) - if (li + 1) % 10 == 0: print(f" Captured {li+1}/{self.n_layers} layer A/B graphs", flush=True) @@ -1903,6 +1885,11 @@ def main(): # Copy X into graph A input buffer (copy_ handles cross-GPU transfer) graph_decoder.x_in_bufs[li].copy_(X) + # Synchronize to ensure cross-GPU copy completes before graph replay + # This is necessary because copy_() between devices may use a copy stream, + # and graph replay must not start until the input data is ready. + if X.device != graph_decoder.x_in_bufs[li].device: + torch.cuda.synchronize() # DEBUG: check input is non-zero (first 3 steps, first 3 layers) if step < 3 and li < 3: