"""Minimal CUDA graph test: verify graph capture and replay work on each GPU (written for an 8x B200 node)."""
|
||
|
|
import torch
|
||
|
|
|
||
|
|
def test_basic_graph(num_gpus=None):
    """Capture a tiny CUDA graph on each GPU and check that replay really runs it.

    Work recorded inside ``torch.cuda.graph`` is not executed during capture,
    so the output buffer ``y`` still holds its initial zeros afterwards. A
    single replay with a zeroed input would therefore "pass" even if replay
    were a no-op. The graph is replayed twice instead:

    1. with ``x == 1``: ``y`` must become 2.0 everywhere (proves the kernels ran);
    2. after zeroing ``x`` in place: ``y`` must become 0.0 (proves replay reads
       the current contents of the captured input buffer).

    Args:
        num_gpus: Number of devices to test, starting at ``cuda:0``. Defaults
            to ``torch.cuda.device_count()``.

    Returns:
        dict mapping GPU index -> max |y| after the final (zero-input) replay.
    """
    if num_gpus is None:
        num_gpus = torch.cuda.device_count()

    results = {}
    for gpu in range(num_gpus):
        torch.cuda.set_device(gpu)
        device = f'cuda:{gpu}'

        # Create input and output tensors
        x = torch.ones(1, 4, 7168, dtype=torch.bfloat16, device=device)
        y = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)

        # Capture graph (nothing is computed yet; y is still all zeros)
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            y.copy_(x * 2.0)

        # First replay with x == 1 — y should be 1.0 * 2.0 = 2.0 everywhere.
        g.replay()
        torch.cuda.synchronize()
        first_ok = bool(torch.all(y == 2.0).item())

        # Reset input in place and replay — y should be 0.0 * 2.0 = 0.0.
        x.zero_()
        g.replay()
        torch.cuda.synchronize()

        y_max = y.abs().max().item()
        results[gpu] = y_max
        if not first_ok:
            status = "WRONG (first replay did not produce 2.0; graph did not execute)"
        elif y_max == 0.0:
            status = "OK"
        else:
            status = f"WRONG (expected 0.0, got {y_max})"
        print(f" GPU {gpu}: y_max={y_max:.2f} — {status}")

    return results
|
||
|
|
|
||
|
|
def test_graph_with_updated_input(num_gpus=None):
    """Check that graph replay reads the current data in a captured input buffer.

    The graph computing ``y_buf = x_buf * 2.0`` is captured while ``x_buf``
    holds 1.0. ``x_buf`` is then overwritten in place with 3.0, and replay must
    produce 6.0 in *every* element. Both the maximum |y| and the minimum y are
    checked, so a partial or sign-flipped write cannot slip through.

    Args:
        num_gpus: Number of devices to test, starting at ``cuda:0``. Defaults
            to ``torch.cuda.device_count()``.

    Returns:
        dict mapping GPU index -> max |y_buf| after replay.
    """
    if num_gpus is None:
        num_gpus = torch.cuda.device_count()

    expected = 6.0
    results = {}
    for gpu in range(num_gpus):
        torch.cuda.set_device(gpu)
        device = f'cuda:{gpu}'

        # Create input and output tensors (pre-allocated, reused by the graph)
        x_buf = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
        y_buf = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)

        # Fill input with data for capture
        x_buf.fill_(1.0)

        # Capture graph
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            y_buf.copy_(x_buf * 2.0)

        # Now update input with DIFFERENT data, in place (same storage)
        x_buf.fill_(3.0)

        # Replay graph — y should be 3.0 * 2.0 = 6.0
        g.replay()
        torch.cuda.synchronize()

        y_max = y_buf.abs().max().item()
        y_min = y_buf.min().item()
        results[gpu] = y_max
        ok = abs(y_max - expected) < 0.1 and abs(y_min - expected) < 0.1
        status = "OK" if ok else f"WRONG (expected 6.0, got min={y_min}, max|y|={y_max})"
        print(f" GPU {gpu}: y_max={y_max:.2f} — {status}")

    return results
|
||
|
|
|
||
|
|
def test_cross_gpu_copy_then_graph(num_gpus=None):
    """Copy data from ``cuda:0`` into a captured input buffer, then replay.

    For each destination GPU ``1 .. num_gpus-1``, a graph computing
    ``y_buf = x_buf * 2.0`` is captured while ``x_buf`` holds 1.0. ``x_buf`` is
    then overwritten in place by a cross-device copy of a 5.0-filled tensor
    living on ``cuda:0``, and replay must yield 10.0 in every element (both the
    maximum |y| and the minimum y are checked).

    Args:
        num_gpus: Number of devices to consider. ``cuda:0`` is the copy source
            and is not tested as a destination. Defaults to
            ``torch.cuda.device_count()``.

    Returns:
        dict mapping destination GPU index -> max |y_buf| after replay. Empty
        when fewer than two GPUs are available.
    """
    if num_gpus is None:
        num_gpus = torch.cuda.device_count()

    results = {}
    if num_gpus < 2:
        print(" skipped: needs at least 2 GPUs")
        return results

    expected = 10.0
    # Source data on cuda:0 is the same for every destination; build it once.
    src = torch.full((1, 4, 7168), 5.0, dtype=torch.bfloat16, device='cuda:0')

    for gpu in range(1, num_gpus):  # Skip GPU 0 (source)
        torch.cuda.set_device(gpu)
        device = f'cuda:{gpu}'

        # Input/output buffers on cuda:{gpu}
        x_buf = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
        y_buf = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)

        # Fill with data for capture
        x_buf.fill_(1.0)

        # Capture graph
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            y_buf.copy_(x_buf * 2.0)

        # Copy data from cuda:0 into the captured input buffer (in place)
        x_buf.copy_(src)
        torch.cuda.synchronize()

        # Replay — y should be 5.0 * 2.0 = 10.0
        g.replay()
        torch.cuda.synchronize()

        y_max = y_buf.abs().max().item()
        y_min = y_buf.min().item()
        results[gpu] = y_max
        ok = abs(y_max - expected) < 0.1 and abs(y_min - expected) < 0.1
        status = "OK" if ok else f"WRONG (expected 10.0, got min={y_min}, max|y|={y_max})"
        print(f" cuda:0→cuda:{gpu}: y_max={y_max:.2f} — {status}")

    return results
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # Each entry pairs a section banner with the test that runs under it;
    # the sections execute in order, followed by a final completion marker.
    _SECTIONS = (
        ("=== Test 1: Basic graph on each GPU ===", test_basic_graph),
        ("\n=== Test 2: Graph replay with updated input ===", test_graph_with_updated_input),
        ("\n=== Test 3: Cross-GPU copy then graph replay ===", test_cross_gpu_copy_then_graph),
    )
    for banner, run_test in _SECTIONS:
        print(banner)
        run_test()
    print("\nDone.")
|