# tests/unit/test_cuda_graph_multi_gpu.py
"""Minimal CUDA graph test: verify graph capture and replay on every visible GPU.

Originally written for a node with 8 B200 GPUs; the device count is now taken
from ``torch.cuda.device_count()`` so the same checks run on smaller hosts.

Each case captures ``y = x * 2`` into a CUDA graph, changes the contents of
``x`` *after* capture, replays the graph, and checks that ``y`` reflects the
new contents of ``x``.

The ``run_*`` functions print one status line per GPU and return
``{gpu_index: y_max}``.  The ``test_*`` functions wrap them for pytest: they
raise ``AssertionError`` on any mismatch and return ``None`` (pytest
deprecates test functions that return a value).
"""
import sys

import torch

# Buffer shape and dtype shared by every case.
_SHAPE = (1, 4, 7168)
_DTYPE = torch.bfloat16

# The output buffer is pre-filled with this non-zero value so that a replay
# which silently does nothing is distinguishable from a replay that wrote 0.0.
_SENTINEL = -1.0

# (expected y_max, tolerance) per case; a tolerance of 0.0 means exact match.
_BASIC_EXPECTED = (0.0, 0.0)
_UPDATED_EXPECTED = (6.0, 0.1)
_CROSS_GPU_EXPECTED = (10.0, 0.1)


def _matches(y_max, expected, tol):
    """Return True when *y_max* is acceptable for *expected* under *tol*."""
    if tol == 0.0:
        return y_max == expected
    # NaN compares False here, so a NaN result is reported as a failure.
    return abs(y_max - expected) < tol


def _capture_doubling_graph(device):
    """Capture ``y_buf = x_buf * 2`` on *device*; return (graph, x_buf, y_buf).

    *device* must be the current CUDA device.  ``x_buf`` holds 1.0 during
    capture and ``y_buf`` holds the sentinel, so nothing in the buffers
    already equals the value a case expects after replay.
    """
    x_buf = torch.ones(_SHAPE, dtype=_DTYPE, device=device)
    y_buf = torch.full(_SHAPE, _SENTINEL, dtype=_DTYPE, device=device)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        y_buf.copy_(x_buf * 2.0)
    return graph, x_buf, y_buf


def _run_case(first_gpu, label, expected, tol, update_input):
    """Run one capture/update/replay case on GPUs ``first_gpu..count-1``.

    *update_input* is called with the captured input buffer after capture and
    before replay.  *label* is a format string taking ``gpu``.  Prints one
    status line per GPU and returns ``{gpu_index: y_max}``.
    """
    results = {}
    gpus = range(first_gpu, torch.cuda.device_count())
    if not gpus:
        print(" (no eligible CUDA devices visible; nothing to check)")
        return results

    for gpu in gpus:
        # Context manager restores the previous current device on exit, so
        # the case does not leak device state to whatever runs next.
        with torch.cuda.device(gpu):
            graph, x_buf, y_buf = _capture_doubling_graph(f"cuda:{gpu}")
            update_input(x_buf)
            graph.replay()
            torch.cuda.synchronize()
            y_max = y_buf.abs().max().item()

        results[gpu] = y_max
        if _matches(y_max, expected, tol):
            status = "OK"
        else:
            status = f"WRONG (expected {expected}, got {y_max})"
        print(f" {label.format(gpu=gpu)}: y_max={y_max:.2f} — {status}")

    return results


def _assert_expected(results, expected, tol, what):
    """Raise AssertionError listing every GPU whose y_max is unacceptable."""
    bad = {gpu: y for gpu, y in results.items() if not _matches(y, expected, tol)}
    if bad:
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise AssertionError(
            f"{what}: expected y_max={expected}, mismatching GPUs: {bad}"
        )


def run_basic_graph():
    """Zero the input after capture; replay should produce 0.0 * 2.0 = 0.0."""
    expected, tol = _BASIC_EXPECTED
    return _run_case(0, "GPU {gpu}", expected, tol, lambda x_buf: x_buf.zero_())


def run_graph_with_updated_input():
    """Fill the input with 3.0 after capture; replay should produce 6.0."""
    expected, tol = _UPDATED_EXPECTED
    return _run_case(
        0, "GPU {gpu}", expected, tol, lambda x_buf: x_buf.fill_(3.0)
    )


def run_cross_gpu_copy_then_graph():
    """Copy 5.0 from cuda:0 into the input after capture; expect 10.0.

    GPU 0 is the copy source and is skipped as a target, so this case needs
    at least two visible devices to check anything.
    """
    expected, tol = _CROSS_GPU_EXPECTED

    def load_from_gpu0(x_buf):
        src = torch.full(_SHAPE, 5.0, dtype=_DTYPE, device="cuda:0")
        x_buf.copy_(src)
        # Make sure the cross-device copy has landed before the replay.
        torch.cuda.synchronize()

    return _run_case(1, "cuda:0→cuda:{gpu}", expected, tol, load_from_gpu0)


def test_basic_graph():
    """Test basic CUDA graph on each GPU."""
    _assert_expected(run_basic_graph(), *_BASIC_EXPECTED, "basic graph replay")


def test_graph_with_updated_input():
    """Test that graph replay uses current data in input buffer."""
    _assert_expected(
        run_graph_with_updated_input(),
        *_UPDATED_EXPECTED,
        "graph replay with updated input",
    )


def test_cross_gpu_copy_then_graph():
    """Test cross-GPU copy followed by graph replay."""
    _assert_expected(
        run_cross_gpu_copy_then_graph(),
        *_CROSS_GPU_EXPECTED,
        "cross-GPU copy then graph replay",
    )


def main():
    """Run all three cases, report every failure, return a process exit code."""
    cases = (
        ("=== Test 1: Basic graph on each GPU ===", test_basic_graph),
        (
            "\n=== Test 2: Graph replay with updated input ===",
            test_graph_with_updated_input,
        ),
        (
            "\n=== Test 3: Cross-GPU copy then graph replay ===",
            test_cross_gpu_copy_then_graph,
        ),
    )
    failed = False
    for banner, case in cases:
        print(banner)
        try:
            case()
        except AssertionError as err:
            # Keep going so one bad case does not hide the others.
            failed = True
            print(f" FAILED: {err}")

    print("\nDone.")
    return 1 if failed else 0


if __name__ == "__main__":
    sys.exit(main())