Add CUDA graph stream management test
This commit is contained in:
78
tests/unit/test_cuda_graph_stream.py
Normal file
78
tests/unit/test_cuda_graph_stream.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Minimal CUDA graph test with explicit stream management."""
|
||||
import torch
|
||||
|
||||
def test_explicit_stream():
|
||||
"""Test CUDA graph with explicit per-device streams."""
|
||||
results = {}
|
||||
for gpu in range(8):
|
||||
device = f'cuda:{gpu}'
|
||||
|
||||
# Create a dedicated stream for this device
|
||||
s = torch.cuda.Stream(device=device)
|
||||
|
||||
# Create tensors on the correct device
|
||||
x = torch.ones(1, 4, 7168, dtype=torch.bfloat16, device=device)
|
||||
y = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
|
||||
|
||||
# Capture on the explicit stream
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g, stream=s):
|
||||
y.copy_(x * 2.0)
|
||||
|
||||
# Update input
|
||||
x.fill_(3.0)
|
||||
|
||||
# Replay on the SAME stream
|
||||
with torch.cuda.stream(s):
|
||||
g.replay()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
y_max = y.abs().max().item()
|
||||
expected = 6.0
|
||||
status = "OK" if abs(y_max - expected) < 0.1 else f"WRONG (expected {expected}, got {y_max})"
|
||||
results[gpu] = y_max
|
||||
print(f" GPU {gpu}: y_max={y_max:.2f} — {status}")
|
||||
|
||||
return results
|
||||
|
||||
def test_set_device_before_each_op():
|
||||
"""Test with explicit set_device before each operation."""
|
||||
results = {}
|
||||
for gpu in range(8):
|
||||
torch.cuda.set_device(gpu)
|
||||
device = f'cuda:{gpu}'
|
||||
|
||||
x = torch.ones(1, 4, 7168, dtype=torch.bfloat16, device=device)
|
||||
y = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
|
||||
|
||||
# Use default stream on the current device
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
# Explicitly set device INSIDE the graph capture
|
||||
torch.cuda.set_device(gpu)
|
||||
y.copy_(x * 2.0)
|
||||
|
||||
# Update input
|
||||
x.fill_(3.0)
|
||||
|
||||
# Replay
|
||||
torch.cuda.set_device(gpu)
|
||||
g.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
y_max = y.abs().max().item()
|
||||
expected = 6.0
|
||||
status = "OK" if abs(y_max - expected) < 0.1 else f"WRONG (expected {expected}, got {y_max})"
|
||||
results[gpu] = y_max
|
||||
print(f" GPU {gpu}: y_max={y_max:.2f} — {status}")
|
||||
|
||||
return results
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=== Test with explicit stream ===")
|
||||
test_explicit_stream()
|
||||
|
||||
print("\n=== Test with set_device inside capture ===")
|
||||
test_set_device_before_each_op()
|
||||
|
||||
print("\nDone.")
|
||||
Reference in New Issue
Block a user