diff --git a/tests/unit/test_cuda_graph_stream.py b/tests/unit/test_cuda_graph_stream.py new file mode 100644 index 00000000..c0527b97 --- /dev/null +++ b/tests/unit/test_cuda_graph_stream.py @@ -0,0 +1,78 @@ +"""Minimal CUDA graph test with explicit stream management.""" +import torch + +def test_explicit_stream(): + """Test CUDA graph with explicit per-device streams.""" + results = {} + for gpu in range(8): + device = f'cuda:{gpu}' + + # Create a dedicated stream for this device + s = torch.cuda.Stream(device=device) + + # Create tensors on the correct device + x = torch.ones(1, 4, 7168, dtype=torch.bfloat16, device=device) + y = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device) + + # Capture on the explicit stream + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, stream=s): + y.copy_(x * 2.0) + + # Update input + x.fill_(3.0) + + # Replay on the SAME stream + with torch.cuda.stream(s): + g.replay() + + torch.cuda.synchronize() + y_max = y.abs().max().item() + expected = 6.0 + status = "OK" if abs(y_max - expected) < 0.1 else f"WRONG (expected {expected}, got {y_max})" + results[gpu] = y_max + print(f" GPU {gpu}: y_max={y_max:.2f} — {status}") + + return results + +def test_set_device_before_each_op(): + """Test with explicit set_device before each operation.""" + results = {} + for gpu in range(8): + torch.cuda.set_device(gpu) + device = f'cuda:{gpu}' + + x = torch.ones(1, 4, 7168, dtype=torch.bfloat16, device=device) + y = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device) + + # Use default stream on the current device + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + # Explicitly set device INSIDE the graph capture + torch.cuda.set_device(gpu) + y.copy_(x * 2.0) + + # Update input + x.fill_(3.0) + + # Replay + torch.cuda.set_device(gpu) + g.replay() + torch.cuda.synchronize() + + y_max = y.abs().max().item() + expected = 6.0 + status = "OK" if abs(y_max - expected) < 0.1 else f"WRONG (expected {expected}, got {y_max})" + results[gpu] = y_max + print(f" GPU {gpu}: y_max={y_max:.2f} — {status}") + + return results + +if __name__ == "__main__": + print("=== Test with explicit stream ===") + test_explicit_stream() + + print("\n=== Test with set_device inside capture ===") + test_set_device_before_each_op() + + print("\nDone.")