"""Minimal CUDA graph test: verify graph capture and replay on every GPU.

Originally written for an 8x B200 node. The number of GPUs now defaults to
``torch.cuda.device_count()`` and can be overridden per test via ``num_gpus``.
"""

from typing import Dict, List, Optional

import torch

# Shape and dtype of the input/output buffers used by every test.
_SHAPE = (1, 4, 7168)
_DTYPE = torch.bfloat16

# Written into the output buffer right before each replay. Capture does not
# execute kernels, and the warmup already wrote x * 2.0 eagerly, so without a
# sentinel a "correct-looking" output could be stale rather than produced by
# the replay under test.
_SENTINEL = -1.0


def _gpu_ids(num_gpus: Optional[int] = None) -> List[int]:
    """Return the device indices to test, ``0 .. num_gpus - 1``.

    When *num_gpus* is None, all devices reported by
    ``torch.cuda.device_count()`` are used (0 on a machine without CUDA).
    """
    if num_gpus is None:
        num_gpus = torch.cuda.device_count()
    return list(range(num_gpus))


def _capture_double(x: torch.Tensor, y: torch.Tensor) -> "torch.cuda.CUDAGraph":
    """Capture ``y.copy_(x * 2.0)`` into a new CUDA graph and return it.

    One eager warmup iteration runs on a side stream first, as recommended by
    the PyTorch CUDA-graphs guide, so lazy initialisation happens outside
    capture. Note that the warmup writes ``x * 2.0`` into *y* eagerly.
    Must be called with the current CUDA device set to the device of x/y.
    """
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        y.copy_(x * 2.0)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        y.copy_(x * 2.0)
    return graph


def _replay_and_check(
    graph: "torch.cuda.CUDAGraph",
    y: torch.Tensor,
    expected: float,
    tol: float,
    label: str,
) -> float:
    """Replay *graph* and verify every element of *y* equals *expected*.

    Prints one status line prefixed with *label* and returns ``max |y|``,
    the value the public tests store in their result dicts.
    """
    # Overwrite y so a replay that writes nothing cannot pass.
    y.fill_(_SENTINEL)
    graph.replay()
    torch.cuda.synchronize()

    y_max = y.abs().max().item()
    # Compare every element, not just the max magnitude, so partially written
    # or sign-flipped outputs are caught.
    err = (y.float() - expected).abs().max().item()
    if err <= tol:
        status = "OK"
    else:
        status = f"WRONG (expected {expected}, got {y_max}, max err {err})"
    print(f"{label}: y_max={y_max:.2f} — {status}")
    return y_max


def test_basic_graph(num_gpus: Optional[int] = None) -> Dict[int, float]:
    """Test basic CUDA graph on each GPU.

    Captures ``y = x * 2`` with x = 1, zeroes x, replays, and expects y == 0.

    Args:
        num_gpus: number of GPUs to test (defaults to all visible devices).

    Returns:
        Mapping of GPU index to ``max |y|`` after replay.
    """
    results = {}
    for gpu in _gpu_ids(num_gpus):
        # Context manager restores the previous current device afterwards.
        with torch.cuda.device(gpu):
            device = f"cuda:{gpu}"
            x = torch.ones(_SHAPE, dtype=_DTYPE, device=device)
            y = torch.zeros(_SHAPE, dtype=_DTYPE, device=device)
            g = _capture_double(x, y)

            # Reset input; replay must read the current contents of x.
            x.zero_()
            results[gpu] = _replay_and_check(g, y, 0.0, 0.0, f" GPU {gpu}")
    return results


def test_graph_with_updated_input(num_gpus: Optional[int] = None) -> Dict[int, float]:
    """Test that graph replay uses current data in input buffer.

    Captures with x = 1, refills x with 3, replays, and expects y == 6.

    Args:
        num_gpus: number of GPUs to test (defaults to all visible devices).

    Returns:
        Mapping of GPU index to ``max |y|`` after replay.
    """
    results = {}
    for gpu in _gpu_ids(num_gpus):
        with torch.cuda.device(gpu):
            device = f"cuda:{gpu}"
            # Pre-allocated static buffers, filled with capture-time data.
            x_buf = torch.full(_SHAPE, 1.0, dtype=_DTYPE, device=device)
            y_buf = torch.zeros(_SHAPE, dtype=_DTYPE, device=device)
            g = _capture_double(x_buf, y_buf)

            # Now update input with DIFFERENT data; replay should see it.
            x_buf.fill_(3.0)
            results[gpu] = _replay_and_check(g, y_buf, 6.0, 0.1, f" GPU {gpu}")
    return results


def test_cross_gpu_copy_then_graph(num_gpus: Optional[int] = None) -> Dict[int, float]:
    """Test cross-GPU copy followed by graph replay.

    For every GPU other than cuda:0, captures a graph, copies data (5.0) from
    cuda:0 into the captured input buffer, replays, and expects y == 10.

    Args:
        num_gpus: number of GPUs to consider (defaults to all visible
            devices); GPU 0 is used only as the copy source.

    Returns:
        Mapping of destination GPU index to ``max |y|`` after replay
        (empty when fewer than two GPUs are available).
    """
    results = {}
    targets = _gpu_ids(num_gpus)[1:]  # Skip GPU 0 (source)
    if not targets:
        return results

    # Source data on cuda:0, shared by every destination.
    src = torch.full(_SHAPE, 5.0, dtype=_DTYPE, device="cuda:0")
    for gpu in targets:
        with torch.cuda.device(gpu):
            device = f"cuda:{gpu}"
            x_buf = torch.full(_SHAPE, 1.0, dtype=_DTYPE, device=device)
            y_buf = torch.zeros(_SHAPE, dtype=_DTYPE, device=device)
            g = _capture_double(x_buf, y_buf)

            # Copy data from cuda:0 into the captured input buffer.
            x_buf.copy_(src)
            torch.cuda.synchronize()
            results[gpu] = _replay_and_check(
                g, y_buf, 10.0, 0.1, f" cuda:0→cuda:{gpu}"
            )
    return results


if __name__ == "__main__":
    if torch.cuda.device_count() == 0:
        print("No CUDA devices visible; nothing to test.")
    else:
        print("=== Test 1: Basic graph on each GPU ===")
        test_basic_graph()
        print("\n=== Test 2: Graph replay with updated input ===")
        test_graph_with_updated_input()
        print("\n=== Test 3: Cross-GPU copy then graph replay ===")
        test_cross_gpu_copy_then_graph()
        print("\nDone.")