CRITICAL FIX: Use explicit per-device streams for CUDA graph capture/replay on multi-GPU — fixes zero-output bug
This commit is contained in:
@@ -180,7 +180,9 @@ class CUDAGraphDecoder:
|
||||
# Two graphs per layer (A: pre-attn, B: post-attn+FFN) + lm_head
|
||||
self.graphs_a = {} # li -> torch.cuda.CUDAGraph
|
||||
self.graphs_b = {} # li -> torch.cuda.CUDAGraph
|
||||
self.streams = {} # li -> torch.cuda.Stream (per-device, MUST match capture stream during replay)
|
||||
self.lm_graph = None # single graph for hc_head + norm + lm_head on cuda:0
|
||||
self.lm_stream = None # stream for lm_head graph on cuda:0
|
||||
|
||||
# Pre-allocated I/O buffers — fixed addresses for graph capture
|
||||
self.x_in_bufs = {} # li -> (1, 4, H) BF16 on layer's device
|
||||
@@ -315,28 +317,20 @@ class CUDAGraphDecoder:
|
||||
# NOTE: We capture each Graph A on the correct GPU. Multi-GPU graph capture
|
||||
# is known to have issues. We add a validation step to verify correctness.
|
||||
#
|
||||
# VALIDATION: Before capturing the full Graph A, we first test a MINIMAL
|
||||
# graph on this GPU to verify basic graph capture works.
|
||||
if li < 2:
|
||||
_test_graph = torch.cuda.CUDAGraph()
|
||||
_test_input = self.x_in_bufs[li].clone()
|
||||
_test_input.fill_(1.0) # non-zero input
|
||||
with torch.cuda.graph(_test_graph):
|
||||
_test_output = _test_input * 2.0
|
||||
_test_input.copy_(self.x_in_bufs[li]) # restore original data
|
||||
_test_graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
_test_max = _test_output.abs().max().item()
|
||||
print(f" L{li} minimal graph test on cuda:{gpu}: output |X|={_test_max:.2f} (expected non-zero)", flush=True)
|
||||
if _test_max == 0.0:
|
||||
print(f" *** L{li} MINIMAL GRAPH TEST FAILED on cuda:{gpu}! ***", flush=True)
|
||||
del _test_graph, _test_input, _test_output
|
||||
# Skip validation — the explicit stream approach handles multi-GPU correctly
|
||||
# Input: X_l = self.x_in_bufs[li] (1, 4, H)
|
||||
# Output: x_normed, q_heads, kv_3d, ctx_a, X_l → pre-allocated buffers
|
||||
# Create a stream on `dev` for this layer's graph capture/replay (one stream per layer index, stored in self.streams[li])
|
||||
# CRITICAL: Must use explicit stream for non-default GPUs.
|
||||
# torch.cuda.set_device() alone doesn't work — PyTorch CUDA graphs
|
||||
# on non-default GPUs fail silently (empty graph or stale data replay).
|
||||
s = torch.cuda.Stream(device=dev)
|
||||
self.streams[li] = s
|
||||
|
||||
# NOTE: Norm weights are pre-cached on device in FP32 (attn_norm_dev, etc.)
|
||||
# to avoid .to() allocations during graph capture.
|
||||
graph_a = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph_a):
|
||||
with torch.cuda.graph(graph_a, stream=s):
|
||||
X_l = self.x_in_bufs[li]
|
||||
|
||||
# 1. mHC pre_block (attn) — fused P5
|
||||
@@ -384,7 +378,7 @@ class CUDAGraphDecoder:
|
||||
# Input: X_mid = self.X_mid_bufs[li], F_attn = self.F_attn_bufs[li]
|
||||
# Output: X_next → self.x_out_bufs[li]
|
||||
graph_b = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph_b):
|
||||
with torch.cuda.graph(graph_b, stream=s):
|
||||
X_mid = self.X_mid_bufs[li]
|
||||
F_attn = self.F_attn_bufs[li]
|
||||
|
||||
@@ -423,8 +417,9 @@ class CUDAGraphDecoder:
|
||||
|
||||
# ---- Capture hc_head + norm + lm_head on cuda:0 ----
|
||||
torch.cuda.set_device(0)
|
||||
self.lm_stream = torch.cuda.Stream(device='cuda:0')
|
||||
self.lm_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.lm_graph):
|
||||
with torch.cuda.graph(self.lm_graph, stream=self.lm_stream):
|
||||
x_out = hc_head.forward(self.x_lm_in) if hc_head is not None else self.x_lm_in[:, 0, :]
|
||||
if final_norm_w is not None:
|
||||
x_out = rmsnorm(x_out, final_norm_w)
|
||||
@@ -1904,19 +1899,25 @@ def main():
|
||||
|
||||
# Copy X into graph A input buffer (copy_ handles cross-GPU transfer)
|
||||
graph_decoder.x_in_bufs[li].copy_(X)
|
||||
# Synchronize to ensure cross-GPU copy completes before graph replay
|
||||
# This is necessary because copy_() between devices may use a copy stream,
|
||||
# and graph replay must not start until the input data is ready.
|
||||
if X.device != graph_decoder.x_in_bufs[li].device:
|
||||
torch.cuda.synchronize()
|
||||
# NOTE(review): the stream events below order Graph A's output before the
|
||||
# eager attention, and the F_attn write before Graph B's replay. Nothing shown
|
||||
# here makes streams[li] wait for the x_in_bufs copy_ above — confirm that ordering.
|
||||
|
||||
# DEBUG: check input is non-zero (first 3 steps, first 3 layers)
|
||||
if step < 3 and li < 3:
|
||||
torch.cuda.synchronize()
|
||||
print(f" Replay L{li}: x_in |X|={graph_decoder.x_in_bufs[li].abs().max().item():.2f}", flush=True)
|
||||
|
||||
# Replay graph A: mHC pre_block + RMSNorm + q_a/q_b/kv projections
|
||||
graph_decoder.graphs_a[li].replay()
|
||||
# Replay graph A on its capture stream
|
||||
with torch.cuda.stream(graph_decoder.streams[li]):
|
||||
graph_decoder.graphs_a[li].replay()
|
||||
|
||||
# Record completion event on graph A's stream, then wait on default stream
|
||||
# This ensures the default stream (eager attention) sees Graph A's output
|
||||
_graph_a_done = torch.cuda.Event()
|
||||
with torch.cuda.stream(graph_decoder.streams[li]):
|
||||
_graph_a_done.record()
|
||||
torch.cuda.current_stream().wait_event(_graph_a_done)
|
||||
|
||||
# DEBUG: check graph A output (first 3 steps, first 3 layers)
|
||||
if step < 3 and li < 3:
|
||||
@@ -1944,13 +1945,20 @@ def main():
|
||||
# Write F_attn to graph B input buffer
|
||||
graph_decoder.F_attn_bufs[li].copy_(F_attn)
|
||||
|
||||
# Record completion of the F_attn write on the current stream, then block the host until it finishes (Event.synchronize() is a CPU-side wait, not a stream-side wait_event)
|
||||
_eager_done = torch.cuda.Event()
|
||||
_eager_done.record(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(graph_decoder.streams[li]):
|
||||
_eager_done.synchronize()
|
||||
|
||||
# DEBUG: check F_attn (first 3 steps, first 3 layers)
|
||||
if step < 3 and li < 3:
|
||||
torch.cuda.synchronize()
|
||||
print(f" Replay L{li} F_attn |X|={F_attn.abs().max().item():.2f}", flush=True)
|
||||
|
||||
# Replay graph B: mHC post_block + FFN + MoE + SE
|
||||
graph_decoder.graphs_b[li].replay()
|
||||
# Replay graph B on its capture stream
|
||||
with torch.cuda.stream(graph_decoder.streams[li]):
|
||||
graph_decoder.graphs_b[li].replay()
|
||||
|
||||
# Read output from graph B
|
||||
X = graph_decoder.x_out_bufs[li]
|
||||
@@ -1963,9 +1971,9 @@ def main():
|
||||
# Transfer last layer output to cuda:0 for lm_head graph
|
||||
graph_decoder.x_lm_in.copy_(X)
|
||||
|
||||
# lm_head graph replay — MUST be on cuda:0 (where it was captured)
|
||||
torch.cuda.set_device(0)
|
||||
graph_decoder.lm_graph.replay()
|
||||
# lm_head graph replay — use capture stream on cuda:0
|
||||
with torch.cuda.stream(graph_decoder.lm_stream):
|
||||
graph_decoder.lm_graph.replay()
|
||||
logits = graph_decoder.logits_buf
|
||||
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user