79 lines
2.4 KiB
Python
79 lines
2.4 KiB
Python
"""Minimal CUDA graph test with explicit stream management."""
|
|
import torch
|
|
|
|
def test_explicit_stream():
    """Check CUDA graph capture/replay on a dedicated stream per visible GPU.

    For every visible device, a graph computing ``y = x * 2`` is captured on a
    side stream created for that device. The input is then overwritten with
    3.0 and the graph is replayed on the same stream, so a correct replay
    leaves 6.0 in ``y``. The current device is never changed here; the GPU is
    selected only through ``device=`` arguments and the explicit stream.

    Returns:
        dict: ``max(|y|)`` read back after replay, keyed by GPU index. Empty
        when no CUDA device is visible.
    """
    expected = 6.0
    results = {}

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print(" No CUDA devices visible; nothing to test.")
        return results

    for gpu in range(num_gpus):
        device = f'cuda:{gpu}'

        # Dedicated side stream that lives on this device.
        stream = torch.cuda.Stream(device=device)

        # Input/output buffers on the device under test.
        x = torch.ones(1, 4, 7168, dtype=torch.bfloat16, device=device)
        y = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)

        # Capture on the explicit stream.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            y.copy_(x * 2.0)

        # Update the input. This is queued on the device's ambient stream,
        # not on the side stream used for capture/replay.
        x.fill_(3.0)

        # Side streams are not implicitly ordered against the ambient stream,
        # so make the replay wait for the buffer initialisation and the fill
        # above; otherwise it may read a stale x.
        stream.wait_stream(torch.cuda.current_stream(device))

        # Replay on the SAME stream the graph was captured on.
        with torch.cuda.stream(stream):
            graph.replay()

        # A bare synchronize() only waits on the *current* device, which is
        # not necessarily this one; name the device so the replay has
        # finished before y is read back.
        torch.cuda.synchronize(device)

        y_max = y.abs().max().item()
        if abs(y_max - expected) < 0.1:
            status = "OK"
        else:
            status = f"WRONG (expected {expected}, got {y_max})"
        results[gpu] = y_max
        print(f" GPU {gpu}: y_max={y_max:.2f} — {status}")

    return results
|
|
|
|
def test_set_device_before_each_op():
    """Check CUDA graph capture/replay when the current device is set explicitly.

    For every visible device, the current device is switched to it before the
    buffers are created, again inside the capture region, and again before
    replay. The graph computes ``y = x * 2``; the input is overwritten with
    3.0 before replay, so a correct replay leaves 6.0 in ``y``. Capture and
    replay use the streams PyTorch selects by default (no explicit stream).

    The device that was current on entry is restored before returning, so the
    caller's global CUDA state is not left pointing at the last GPU.

    Returns:
        dict: ``max(|y|)`` read back after replay, keyed by GPU index. Empty
        when no CUDA device is visible.
    """
    expected = 6.0
    results = {}

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print(" No CUDA devices visible; nothing to test.")
        return results

    previous_device = torch.cuda.current_device()
    try:
        for gpu in range(num_gpus):
            torch.cuda.set_device(gpu)
            device = f'cuda:{gpu}'

            x = torch.ones(1, 4, 7168, dtype=torch.bfloat16, device=device)
            y = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)

            # No explicit stream: rely on the current device's defaults.
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                # Explicitly set device INSIDE the graph capture; this is the
                # scenario under test.
                torch.cuda.set_device(gpu)
                y.copy_(x * 2.0)

            # Update input
            x.fill_(3.0)

            # Replay with the device selected once more, then wait for it.
            # synchronize() targets the current device, which is `gpu` here.
            torch.cuda.set_device(gpu)
            graph.replay()
            torch.cuda.synchronize()

            y_max = y.abs().max().item()
            if abs(y_max - expected) < 0.1:
                status = "OK"
            else:
                status = f"WRONG (expected {expected}, got {y_max})"
            results[gpu] = y_max
            print(f" GPU {gpu}: y_max={y_max:.2f} — {status}")
    finally:
        # Undo the global device switch even if a device fails mid-loop.
        torch.cuda.set_device(previous_device)

    return results
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run each scenario under its own banner, in order.
    scenarios = (
        ("=== Test with explicit stream ===", test_explicit_stream),
        ("\n=== Test with set_device inside capture ===", test_set_device_before_each_op),
    )
    for banner, run_scenario in scenarios:
        print(banner)
        run_scenario()

    print("\nDone.")
|