Add minimal CUDA graph multi-GPU test to isolate zero-output bug

This commit is contained in:
2026-06-06 08:13:18 +00:00
parent 86275851d4
commit 26042e3f01

View File

@@ -0,0 +1,114 @@
"""Minimal CUDA graph test: verify graph capture works on all 8 B200 GPUs."""
import torch
def test_basic_graph():
"""Test basic CUDA graph on each GPU."""
results = {}
for gpu in range(8):
torch.cuda.set_device(gpu)
device = f'cuda:{gpu}'
# Create input and output tensors
x = torch.ones(1, 4, 7168, dtype=torch.bfloat16, device=device)
y = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
# Capture graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y.copy_(x * 2.0)
# Reset input
x.zero_()
# Replay graph — y should be 0.0 * 2.0 = 0.0 since x is now zero
g.replay()
torch.cuda.synchronize()
y_max = y.abs().max().item()
results[gpu] = y_max
status = "OK" if y_max == 0.0 else f"WRONG (expected 0.0, got {y_max})"
print(f" GPU {gpu}: y_max={y_max:.2f}{status}")
return results
def test_graph_with_updated_input():
"""Test that graph replay uses current data in input buffer."""
results = {}
for gpu in range(8):
torch.cuda.set_device(gpu)
device = f'cuda:{gpu}'
# Create input and output tensors (pre-allocated)
x_buf = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
y_buf = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
# Fill input with data for capture
x_buf.fill_(1.0)
# Capture graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_buf.copy_(x_buf * 2.0)
# Now update input with DIFFERENT data
x_buf.fill_(3.0)
# Replay graph — y should be 3.0 * 2.0 = 6.0
g.replay()
torch.cuda.synchronize()
y_max = y_buf.abs().max().item()
results[gpu] = y_max
status = "OK" if abs(y_max - 6.0) < 0.1 else f"WRONG (expected 6.0, got {y_max})"
print(f" GPU {gpu}: y_max={y_max:.2f}{status}")
return results
def test_cross_gpu_copy_then_graph():
"""Test cross-GPU copy followed by graph replay."""
results = {}
for gpu in range(1, 8): # Skip GPU 0 (source)
torch.cuda.set_device(gpu)
device = f'cuda:{gpu}'
# Source data on cuda:0
src = torch.full((1, 4, 7168), 5.0, dtype=torch.bfloat16, device='cuda:0')
# Input/output buffers on cuda:{gpu}
x_buf = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
y_buf = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
# Fill with data for capture
x_buf.fill_(1.0)
# Capture graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
y_buf.copy_(x_buf * 2.0)
# Copy data from cuda:0 to input buffer
x_buf.copy_(src)
torch.cuda.synchronize()
# Replay — y should be 5.0 * 2.0 = 10.0
g.replay()
torch.cuda.synchronize()
y_max = y_buf.abs().max().item()
results[gpu] = y_max
status = "OK" if abs(y_max - 10.0) < 0.1 else f"WRONG (expected 10.0, got {y_max})"
print(f" cuda:0→cuda:{gpu}: y_max={y_max:.2f}{status}")
return results
if __name__ == "__main__":
print("=== Test 1: Basic graph on each GPU ===")
test_basic_graph()
print("\n=== Test 2: Graph replay with updated input ===")
test_graph_with_updated_input()
print("\n=== Test 3: Cross-GPU copy then graph replay ===")
test_cross_gpu_copy_then_graph()
print("\nDone.")