Add minimal CUDA graph multi-GPU test to isolate zero-output bug
This commit is contained in:
114
tests/unit/test_cuda_graph_multi_gpu.py
Normal file
114
tests/unit/test_cuda_graph_multi_gpu.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Minimal CUDA graph test: verify graph capture works on all 8 B200 GPUs."""
|
||||
import torch
|
||||
|
||||
def test_basic_graph():
    """Capture and replay a trivial CUDA graph on every visible GPU.

    On each device the operation ``y.copy_(x * 2.0)`` is captured while ``x``
    holds ones.  ``x`` is then zeroed in place and the graph is replayed, so
    the replay is expected to write ``0.0 * 2.0 = 0.0`` into ``y``.

    One status line is printed per GPU so that every device is reported even
    if an earlier one misbehaves; the test fails afterwards if any GPU was
    wrong.  Nothing is returned, as pytest expects test functions to return
    ``None``.

    Raises:
        AssertionError: if any GPU left a non-zero value in ``y``.
    """
    # Probe the devices that actually exist instead of assuming eight, so the
    # loop does not crash when selecting a device that is not present.
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        print(" (no CUDA devices visible; nothing to check)")

    failures = {}
    for gpu in range(gpu_count):
        # The context manager restores the previously selected device on exit,
        # so this test does not leak device state into whatever runs next.
        with torch.cuda.device(gpu):
            device = f'cuda:{gpu}'

            # Create input and output tensors
            x = torch.ones(1, 4, 7168, dtype=torch.bfloat16, device=device)
            y = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)

            # Capture graph
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                y.copy_(x * 2.0)

            # Reset input
            x.zero_()

            # Put a non-zero sentinel in the output first.  y was allocated
            # as zeros, so without this an all-zero y could not distinguish
            # a correct replay from a replay that wrote nothing at all.
            y.fill_(1.0)

            # Replay graph — y should be 0.0 * 2.0 = 0.0 since x is now zero
            g.replay()
            torch.cuda.synchronize()

            y_max = y.abs().max().item()

        if y_max == 0.0:
            status = "OK"
        else:
            status = f"WRONG (expected 0.0, got {y_max})"
            failures[gpu] = y_max
        print(f" GPU {gpu}: y_max={y_max:.2f} — {status}")

    assert not failures, (
        f"CUDA graph replay produced non-zero output (gpu -> y_max): {failures}"
    )
|
||||
|
||||
def test_graph_with_updated_input():
    """Check that graph replay reads the current contents of its input buffer.

    On each visible GPU the operation ``y_buf.copy_(x_buf * 2.0)`` is captured
    while ``x_buf`` holds 1.0.  ``x_buf`` is then refilled in place with 3.0
    and the graph is replayed, so ``y_buf`` is expected to hold
    ``3.0 * 2.0 = 6.0``.

    One status line is printed per GPU so that every device is reported even
    if an earlier one misbehaves; the test fails afterwards if any GPU was
    wrong.  Nothing is returned, as pytest expects test functions to return
    ``None``.

    Raises:
        AssertionError: if any GPU's maximum output is not within 0.1 of 6.0.
    """
    # Probe the devices that actually exist instead of assuming eight, so the
    # loop does not crash when selecting a device that is not present.
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        print(" (no CUDA devices visible; nothing to check)")

    failures = {}
    for gpu in range(gpu_count):
        # The context manager restores the previously selected device on exit,
        # so this test does not leak device state into whatever runs next.
        with torch.cuda.device(gpu):
            device = f'cuda:{gpu}'

            # Create input and output tensors (pre-allocated)
            x_buf = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
            y_buf = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)

            # Fill input with data for capture
            x_buf.fill_(1.0)

            # Capture graph
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                y_buf.copy_(x_buf * 2.0)

            # Now update input with DIFFERENT data
            x_buf.fill_(3.0)

            # Replay graph — y should be 3.0 * 2.0 = 6.0
            g.replay()
            torch.cuda.synchronize()

            y_max = y_buf.abs().max().item()

        if abs(y_max - 6.0) < 0.1:
            status = "OK"
        else:
            status = f"WRONG (expected 6.0, got {y_max})"
            failures[gpu] = y_max
        print(f" GPU {gpu}: y_max={y_max:.2f} — {status}")

    assert not failures, (
        f"CUDA graph replay did not produce 6.0 (gpu -> y_max): {failures}"
    )
|
||||
|
||||
def test_cross_gpu_copy_then_graph():
    """Check graph replay after the input buffer is filled by a cross-GPU copy.

    For every visible GPU other than ``cuda:0``, the operation
    ``y_buf.copy_(x_buf * 2.0)`` is captured while ``x_buf`` holds 1.0.  The
    input buffer is then overwritten by copying a tensor of 5.0 from
    ``cuda:0``, and the graph is replayed, so ``y_buf`` is expected to hold
    ``5.0 * 2.0 = 10.0``.

    One status line is printed per destination GPU so that every device is
    reported even if an earlier one misbehaves; the test fails afterwards if
    any GPU was wrong.  Nothing is returned, as pytest expects test functions
    to return ``None``.

    Raises:
        AssertionError: if any GPU's maximum output is not within 0.1 of 10.0.
    """
    # Probe the devices that actually exist instead of assuming eight; a
    # cross-GPU copy needs cuda:0 plus at least one other device.
    gpu_count = torch.cuda.device_count()
    if gpu_count < 2:
        print(" (fewer than 2 CUDA devices visible; nothing to check)")

    failures = {}
    for gpu in range(1, gpu_count):  # Skip GPU 0 (source)
        # The context manager restores the previously selected device on exit,
        # so this test does not leak device state into whatever runs next.
        with torch.cuda.device(gpu):
            device = f'cuda:{gpu}'

            # Source data on cuda:0
            src = torch.full((1, 4, 7168), 5.0, dtype=torch.bfloat16, device='cuda:0')

            # Input/output buffers on cuda:{gpu}
            x_buf = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
            y_buf = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)

            # Fill with data for capture
            x_buf.fill_(1.0)

            # Capture graph
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                y_buf.copy_(x_buf * 2.0)

            # Copy data from cuda:0 to input buffer, and wait for it to land
            # before replaying.
            x_buf.copy_(src)
            torch.cuda.synchronize()

            # Replay — y should be 5.0 * 2.0 = 10.0
            g.replay()
            torch.cuda.synchronize()

            y_max = y_buf.abs().max().item()

        if abs(y_max - 10.0) < 0.1:
            status = "OK"
        else:
            status = f"WRONG (expected 10.0, got {y_max})"
            failures[gpu] = y_max
        print(f" cuda:0→cuda:{gpu}: y_max={y_max:.2f} — {status}")

    assert not failures, (
        f"CUDA graph replay after cross-GPU copy did not produce 10.0 "
        f"(gpu -> y_max): {failures}"
    )
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run each scenario in order under its own banner.
    scenarios = (
        ("=== Test 1: Basic graph on each GPU ===",
         test_basic_graph),
        ("\n=== Test 2: Graph replay with updated input ===",
         test_graph_with_updated_input),
        ("\n=== Test 3: Cross-GPU copy then graph replay ===",
         test_cross_gpu_copy_then_graph),
    )
    for banner, scenario in scenarios:
        print(banner)
        scenario()

    print("\nDone.")
|
||||
Reference in New Issue
Block a user