Files
nvfp4-megamoe-kernel/tests/unit/test_cuda_graph_stream.py

79 lines
2.4 KiB
Python

"""Minimal CUDA graph test with explicit stream management."""
import torch
def test_explicit_stream(max_gpus=8):
    """Capture and replay a CUDA graph on a dedicated stream of each GPU.

    For every visible device (at most ``max_gpus``) the graph
    ``y.copy_(x * 2.0)`` is captured on a per-device side stream, ``x`` is
    then overwritten with 3.0, and the graph is replayed on the same stream.
    A correct replay reads the updated ``x`` and leaves 6.0 in ``y``.

    Args:
        max_gpus: Upper bound on the number of devices exercised. The
            default of 8 matches the original hard-coded device count.

    Returns:
        Dict mapping GPU index to the observed ``max(|y|)`` after replay.
        Empty when no CUDA device is visible.

    NOTE(review): pytest warns about (and recent releases reportedly fail)
    test functions that return a non-None value; the return is kept because
    it is part of this function's existing interface -- confirm which
    matters more for this suite.
    """
    results = {}
    # Clamp to the devices actually present so hosts with fewer than
    # `max_gpus` GPUs (or none at all) do not raise on an invalid ordinal.
    num_gpus = min(max_gpus, torch.cuda.device_count())
    for gpu in range(num_gpus):
        device = f'cuda:{gpu}'
        # Create a dedicated stream for this device
        s = torch.cuda.Stream(device=device)
        # Create tensors on the correct device
        x = torch.ones(1, 4, 7168, dtype=torch.bfloat16, device=device)
        y = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
        # Capture on the explicit stream
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, stream=s):
            y.copy_(x * 2.0)
        # Update input (enqueued on this device's current stream, not on `s`)
        x.fill_(3.0)
        # Order the replay after the fill above: the two run on different
        # streams, so without this the graph may read a stale `x`.
        s.wait_stream(torch.cuda.current_stream(device))
        # Replay on the SAME stream
        with torch.cuda.stream(s):
            g.replay()
        # Wait on the device that ran the replay. A bare synchronize() only
        # covers the *current* device, which this function never switches.
        torch.cuda.synchronize(device)
        y_max = y.abs().max().item()
        expected = 6.0
        status = "OK" if abs(y_max - expected) < 0.1 else f"WRONG (expected {expected}, got {y_max})"
        results[gpu] = y_max
        print(f" GPU {gpu}: y_max={y_max:.2f} {status}")
    return results
def test_set_device_before_each_op(max_gpus=8):
    """Capture and replay a CUDA graph per GPU, calling set_device around each step.

    For every visible device (at most ``max_gpus``) the current device is
    set explicitly before allocation, inside the capture, and before the
    replay. The captured work is ``y.copy_(x * 2.0)``; ``x`` is overwritten
    with 3.0 before the replay, so a correct replay leaves 6.0 in ``y``.

    Args:
        max_gpus: Upper bound on the number of devices exercised. The
            default of 8 matches the original hard-coded device count.

    Returns:
        Dict mapping GPU index to the observed ``max(|y|)`` after replay.
        Empty when no CUDA device is visible.

    NOTE(review): pytest warns about (and recent releases reportedly fail)
    test functions that return a non-None value; the return is kept because
    it is part of this function's existing interface -- confirm which
    matters more for this suite.
    """
    results = {}
    # Clamp to the devices actually present so hosts with fewer than
    # `max_gpus` GPUs (or none at all) do not raise on an invalid ordinal.
    num_gpus = min(max_gpus, torch.cuda.device_count())
    if num_gpus <= 0:
        return results
    # set_device() mutates process-wide state; remember the caller's device
    # so later tests in the same session are not left on the last GPU.
    previous_device = torch.cuda.current_device()
    try:
        for gpu in range(num_gpus):
            torch.cuda.set_device(gpu)
            device = f'cuda:{gpu}'
            x = torch.ones(1, 4, 7168, dtype=torch.bfloat16, device=device)
            y = torch.zeros(1, 4, 7168, dtype=torch.bfloat16, device=device)
            # No stream is passed here, so torch.cuda.graph captures on its
            # own internal side stream rather than a caller-chosen one.
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                # Explicitly set device INSIDE the graph capture
                torch.cuda.set_device(gpu)
                y.copy_(x * 2.0)
            # Update input
            x.fill_(3.0)
            # Replay
            torch.cuda.set_device(gpu)
            g.replay()
            # Current device is `gpu` here, so this waits on the right device.
            torch.cuda.synchronize()
            y_max = y.abs().max().item()
            expected = 6.0
            status = "OK" if abs(y_max - expected) < 0.1 else f"WRONG (expected {expected}, got {y_max})"
            results[gpu] = y_max
            print(f" GPU {gpu}: y_max={y_max:.2f} {status}")
    finally:
        torch.cuda.set_device(previous_device)
    return results
if __name__ == "__main__":
    # Script entry point: run each diagnostic in order, printing its banner
    # first, then a final completion marker. Return values are ignored; the
    # per-GPU lines printed by each function are the report.
    for banner, run_diagnostic in (
        ("=== Test with explicit stream ===", test_explicit_stream),
        ("\n=== Test with set_device inside capture ===", test_set_device_before_each_op),
    ):
        print(banner)
        run_diagnostic()
    print("\nDone.")