"""Minimal CUDA graph test with explicit stream management.

Each test captures ``y.copy_(x * 2.0)`` into a CUDA graph on every tested
GPU, refills the static input ``x`` with 3.0, replays the graph and checks
that ``y`` now holds 6.0.
"""
import torch

# Value every element of y should hold after replay (x is refilled with 3.0
# and the captured op computes x * 2.0) and the accepted absolute error.
_EXPECTED = 6.0
_TOLERANCE = 0.1
# Shape and dtype of the test tensors.
_SHAPE = (1, 4, 7168)
_DTYPE = torch.bfloat16


def _gpu_indices(num_gpus=None):
    """Return ``range(num_gpus)``, or every visible device when it is None."""
    if num_gpus is None:
        num_gpus = torch.cuda.device_count()
    return range(num_gpus)


def _record_result(results, gpu, y):
    """Store max(|y|) in ``results[gpu]`` and print an OK/WRONG line.

    The caller must synchronize y's device first so the value read
    reflects the graph replay.
    """
    y_max = y.abs().max().item()
    if abs(y_max - _EXPECTED) < _TOLERANCE:
        status = "OK"
    else:
        status = f"WRONG (expected {_EXPECTED}, got {y_max})"
    results[gpu] = y_max
    print(f" GPU {gpu}: y_max={y_max:.2f} — {status}")


def test_explicit_stream(num_gpus=None):
    """Test CUDA graph with explicit per-device streams.

    Captures and replays on a dedicated ``torch.cuda.Stream`` per device.
    The current device is never changed.

    Args:
        num_gpus: number of GPUs to test (indices ``0..num_gpus-1``). None
            (the default) tests every device reported by
            ``torch.cuda.device_count()``.

    Returns:
        dict mapping GPU index to the observed ``max(|y|)`` after replay.
    """
    results = {}
    for gpu in _gpu_indices(num_gpus):
        device = f'cuda:{gpu}'
        # Create a dedicated stream for this device
        s = torch.cuda.Stream(device=device)
        # Create tensors on the correct device
        x = torch.ones(_SHAPE, dtype=_DTYPE, device=device)
        y = torch.zeros(_SHAPE, dtype=_DTYPE, device=device)
        # Capture on the explicit stream
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, stream=s):
            y.copy_(x * 2.0)
        # Update the static input. This is queued on the device's current
        # stream, not on s.
        x.fill_(3.0)
        # Work on different streams is unordered unless explicitly
        # synchronized. Make s wait for the init/fill_ queued above so the
        # replay reads x == 3.
        s.wait_stream(torch.cuda.current_stream(device))
        # Replay on the SAME stream
        with torch.cuda.stream(s):
            g.replay()
        # synchronize() without an argument waits only on the current
        # device, which this function never sets. Name the device so the
        # replay on s has finished before y is read.
        torch.cuda.synchronize(device)
        _record_result(results, gpu, y)
    return results


def test_set_device_before_each_op(num_gpus=None):
    """Test with explicit set_device before each operation.

    Captures without passing ``stream=`` to ``torch.cuda.graph`` and calls
    ``torch.cuda.set_device(gpu)`` before allocation, inside capture and
    before replay. The current device is left set to the last tested GPU.

    Args:
        num_gpus: number of GPUs to test (indices ``0..num_gpus-1``). None
            (the default) tests every device reported by
            ``torch.cuda.device_count()``.

    Returns:
        dict mapping GPU index to the observed ``max(|y|)`` after replay.
    """
    results = {}
    for gpu in _gpu_indices(num_gpus):
        torch.cuda.set_device(gpu)
        device = f'cuda:{gpu}'
        x = torch.ones(_SHAPE, dtype=_DTYPE, device=device)
        y = torch.zeros(_SHAPE, dtype=_DTYPE, device=device)
        # NOTE(review): no `stream=` is passed, so torch.cuda.graph captures
        # on its own internal side stream, NOT this device's default stream.
        # In the PyTorch sources I know, that side stream is one class-level
        # stream created on whichever device is current when the first
        # torch.cuda.graph is constructed (cuda:0 here). For gpu > 0 the
        # capture may therefore not happen on `device` — confirm against the
        # installed version.
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            # Explicitly set device INSIDE the graph capture
            torch.cuda.set_device(gpu)
            y.copy_(x * 2.0)
        # Update input
        x.fill_(3.0)
        # Replay
        torch.cuda.set_device(gpu)
        g.replay()
        # Same as a bare synchronize() here (device was just set), but
        # explicit about which device must be idle before y is read.
        torch.cuda.synchronize(device)
        _record_result(results, gpu, y)
    return results


if __name__ == "__main__":
    print("=== Test with explicit stream ===")
    test_explicit_stream()
    print("\n=== Test with set_device inside capture ===")
    test_set_device_before_each_op()
    print("\nDone.")