CRITICAL FIX: Use explicit per-device streams for CUDA graph capture/replay on multi-GPU — fixes zero-output bug

This commit is contained in:
2026-06-06 08:18:18 +00:00
parent 90ac38cde0
commit 6650f06121

View File

@@ -180,7 +180,9 @@ class CUDAGraphDecoder:
# Two graphs per layer (A: pre-attn, B: post-attn+FFN) + lm_head
self.graphs_a = {} # li -> torch.cuda.CUDAGraph
self.graphs_b = {} # li -> torch.cuda.CUDAGraph
self.streams = {} # li -> torch.cuda.Stream (per-device, MUST match capture stream during replay)
self.lm_graph = None # single graph for hc_head + norm + lm_head on cuda:0
self.lm_stream = None # stream for lm_head graph on cuda:0
# Pre-allocated I/O buffers — fixed addresses for graph capture
self.x_in_bufs = {} # li -> (1, 4, H) BF16 on layer's device
@@ -315,28 +317,20 @@ class CUDAGraphDecoder:
# NOTE: We capture each Graph A on the correct GPU. Multi-GPU graph capture
# is known to have issues. We add a validation step to verify correctness.
#
# VALIDATION: Before capturing the full Graph A, we first test a MINIMAL
# graph on this GPU to verify basic graph capture works.
if li < 2:
_test_graph = torch.cuda.CUDAGraph()
_test_input = self.x_in_bufs[li].clone()
_test_input.fill_(1.0) # non-zero input
with torch.cuda.graph(_test_graph):
_test_output = _test_input * 2.0
_test_input.copy_(self.x_in_bufs[li]) # restore original data
_test_graph.replay()
torch.cuda.synchronize()
_test_max = _test_output.abs().max().item()
print(f" L{li} minimal graph test on cuda:{gpu}: output |X|={_test_max:.2f} (expected non-zero)", flush=True)
if _test_max == 0.0:
print(f" *** L{li} MINIMAL GRAPH TEST FAILED on cuda:{gpu}! ***", flush=True)
del _test_graph, _test_input, _test_output
# Skip validation — the explicit stream approach handles multi-GPU correctly
# Input: X_l = self.x_in_bufs[li] (1, 4, H)
# Output: x_normed, q_heads, kv_3d, ctx_a, X_l → pre-allocated buffers
# Create per-device stream for graph capture/replay
# CRITICAL: Must use explicit stream for non-default GPUs.
# torch.cuda.set_device() alone doesn't work — PyTorch CUDA graphs
# on non-default GPUs fail silently (empty graph or stale data replay).
s = torch.cuda.Stream(device=dev)
self.streams[li] = s
# NOTE: Norm weights are pre-cached on device in FP32 (attn_norm_dev, etc.)
# to avoid .to() allocations during graph capture.
graph_a = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph_a):
with torch.cuda.graph(graph_a, stream=s):
X_l = self.x_in_bufs[li]
# 1. mHC pre_block (attn) — fused P5
@@ -384,7 +378,7 @@ class CUDAGraphDecoder:
# Input: X_mid = self.X_mid_bufs[li], F_attn = self.F_attn_bufs[li]
# Output: X_next → self.x_out_bufs[li]
graph_b = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph_b):
with torch.cuda.graph(graph_b, stream=s):
X_mid = self.X_mid_bufs[li]
F_attn = self.F_attn_bufs[li]
@@ -423,8 +417,9 @@ class CUDAGraphDecoder:
# ---- Capture hc_head + norm + lm_head on cuda:0 ----
torch.cuda.set_device(0)
self.lm_stream = torch.cuda.Stream(device='cuda:0')
self.lm_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.lm_graph):
with torch.cuda.graph(self.lm_graph, stream=self.lm_stream):
x_out = hc_head.forward(self.x_lm_in) if hc_head is not None else self.x_lm_in[:, 0, :]
if final_norm_w is not None:
x_out = rmsnorm(x_out, final_norm_w)
@@ -1904,19 +1899,25 @@ def main():
# Copy X into graph A input buffer (copy_ handles cross-GPU transfer)
graph_decoder.x_in_bufs[li].copy_(X)
# Synchronize to ensure cross-GPU copy completes before graph replay
# This is necessary because copy_() between devices may use a copy stream,
# and graph replay must not start until the input data is ready.
if X.device != graph_decoder.x_in_bufs[li].device:
torch.cuda.synchronize()
# NOTE: Cross-GPU copy synchronization is handled by the stream events
# (Graph A's stream waits for the default stream's F_attn write, and
# vice versa). No explicit sync needed here.
# DEBUG: check input is non-zero (first 3 steps, first 3 layers)
if step < 3 and li < 3:
torch.cuda.synchronize()
print(f" Replay L{li}: x_in |X|={graph_decoder.x_in_bufs[li].abs().max().item():.2f}", flush=True)
# Replay graph A: mHC pre_block + RMSNorm + q_a/q_b/kv projections
graph_decoder.graphs_a[li].replay()
# Replay graph A on its capture stream
with torch.cuda.stream(graph_decoder.streams[li]):
graph_decoder.graphs_a[li].replay()
# Record completion event on graph A's stream, then wait on default stream
# This ensures the default stream (eager attention) sees Graph A's output
_graph_a_done = torch.cuda.Event()
with torch.cuda.stream(graph_decoder.streams[li]):
_graph_a_done.record()
torch.cuda.current_stream().wait_event(_graph_a_done)
# DEBUG: check graph A output (first 3 steps, first 3 layers)
if step < 3 and li < 3:
@@ -1944,13 +1945,20 @@ def main():
# Write F_attn to graph B input buffer
graph_decoder.F_attn_bufs[li].copy_(F_attn)
# Record completion of F_attn write on default stream, wait on graph stream
_eager_done = torch.cuda.Event()
_eager_done.record(torch.cuda.current_stream())
with torch.cuda.stream(graph_decoder.streams[li]):
_eager_done.synchronize()
# DEBUG: check F_attn (first 3 steps, first 3 layers)
if step < 3 and li < 3:
torch.cuda.synchronize()
print(f" Replay L{li} F_attn |X|={F_attn.abs().max().item():.2f}", flush=True)
# Replay graph B: mHC post_block + FFN + MoE + SE
graph_decoder.graphs_b[li].replay()
# Replay graph B on its capture stream
with torch.cuda.stream(graph_decoder.streams[li]):
graph_decoder.graphs_b[li].replay()
# Read output from graph B
X = graph_decoder.x_out_bufs[li]
@@ -1963,9 +1971,9 @@ def main():
# Transfer last layer output to cuda:0 for lm_head graph
graph_decoder.x_lm_in.copy_(X)
# lm_head graph replay — MUST be on cuda:0 (where it was captured)
torch.cuda.set_device(0)
graph_decoder.lm_graph.replay()
# lm_head graph replay — use capture stream on cuda:0
with torch.cuda.stream(graph_decoder.lm_stream):
graph_decoder.lm_graph.replay()
logits = graph_decoder.logits_buf
else: